diff --git a/openspace/.env.example b/openspace/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..d0328013e04fd41277679eff92402393cdb31d16 --- /dev/null +++ b/openspace/.env.example @@ -0,0 +1,53 @@ +# ============================================ +# OpenSpace Environment Variables +# Copy this file to .env and fill in your keys +# ============================================ + +# ---- LLM API Keys ---- +# At least one LLM API key is required for OpenSpace to function. +# OpenSpace uses LiteLLM for model routing, so the key you need depends on your chosen model. +# See https://docs.litellm.ai/docs/providers for supported providers. + +# Anthropic (for anthropic/claude-* models) +# ANTHROPIC_API_KEY= + +# OpenAI (for openai/gpt-* models) +# OPENAI_API_KEY= + +# OpenRouter (for openrouter/* models, e.g. openrouter/anthropic/claude-sonnet-4.5) +OPENROUTER_API_KEY= + +# ── OpenSpace Cloud (optional) ────────────────────────────── +# Register at https://open-space.cloud to get your key. +# Enables cloud skill search & upload; local features work without it. + +OPENSPACE_API_KEY=sk_xxxxxxxxxxxxxxxx + +# ---- GUI Backend (Anthropic Computer Use) ---- +# Required only if using the GUI backend. Uses the same ANTHROPIC_API_KEY above. +# Optional backup key for rate limit fallback: +# ANTHROPIC_API_KEY_BACKUP= + +# ---- Web Backend (Deep Research) ---- +# Required only if using the Web backend for deep research. +# Uses OpenRouter API by default: +# OPENROUTER_API_KEY= + +# ---- Embedding (Optional) ---- +# For remote embedding API instead of local model. +# If not set, OpenSpace uses a local embedding model (BAAI/bge-small-en-v1.5). +# EMBEDDING_BASE_URL= +# EMBEDDING_API_KEY= +# EMBEDDING_MODEL= "openai/text-embedding-3-small" + +# ---- E2B Sandbox (Optional) ---- +# Required only if sandbox mode is enabled in security config. +# E2B_API_KEY= + +# ---- Local Server (Optional) ---- +# Override the default local server URL (default: http://127.0.0.1:5000) +# Useful for remote VM integration (e.g., OSWorld). +# LOCAL_SERVER_URL=http://127.0.0.1:5000 + +# ---- Debug (Optional) ---- +# OPENSPACE_DEBUG=true \ No newline at end of file diff --git a/openspace/__init__.py b/openspace/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d99383ed6b729d2b779e036c28b3e6640c4575c1 --- /dev/null +++ b/openspace/__init__.py @@ -0,0 +1,71 @@ +from importlib import import_module as _imp +from typing import Dict as _Dict, Any as _Any, TYPE_CHECKING as _TYPE_CHECKING + +if _TYPE_CHECKING: + from openspace.tool_layer import OpenSpace as OpenSpace, OpenSpaceConfig as OpenSpaceConfig + from openspace.agents import GroundingAgent as GroundingAgent + from openspace.llm import LLMClient as LLMClient + from openspace.recording import RecordingManager as RecordingManager + +__version__ = "0.1.0" + +__all__ = [ + # Version + "__version__", + + # Main API + "OpenSpace", + "OpenSpaceConfig", + + # Core Components + "GroundingAgent", + "GroundingClient", + "LLMClient", + "BaseTool", + "ToolResult", + "BackendType", + + # Recording System + "RecordingManager", + "RecordingViewer", +] + +# Map attribute → sub-module that provides it +_attr_to_module: _Dict[str, str] = { + # Main API + "OpenSpace": "openspace.tool_layer", + "OpenSpaceConfig": "openspace.tool_layer", + + # Core Components + "GroundingAgent": "openspace.agents", + "GroundingClient": "openspace.grounding.core.grounding_client", + "LLMClient": "openspace.llm", + "BaseTool": "openspace.grounding.core.tool.base", + "ToolResult": "openspace.grounding.core.types", + "BackendType": "openspace.grounding.core.types", + + # Recording System + "RecordingManager": "openspace.recording", + "RecordingViewer": "openspace.recording.viewer", +} + + +def __getattr__(name: str) -> _Any: + """Dynamically import sub-modules on first attribute access. + + This keeps the *initial* package import lightweight and avoids raising + `ModuleNotFoundError` for optional / heavy dependencies until the + corresponding functionality is explicitly used. + """ + if name not in _attr_to_module: + raise AttributeError(f"module 'openspace' has no attribute '{name}'") + + module_name = _attr_to_module[name] + module = _imp(module_name) + value = getattr(module, name) + globals()[name] = value + return value + + +def __dir__(): + return sorted(list(globals().keys()) + list(_attr_to_module.keys())) \ No newline at end of file diff --git a/openspace/__main__.py b/openspace/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..28ce3c283b21a59bef9e304de25d0517885f82f6 --- /dev/null +++ b/openspace/__main__.py @@ -0,0 +1,473 @@ +import asyncio +import argparse +import sys +import logging +from typing import Optional + +from openspace.tool_layer import OpenSpace, OpenSpaceConfig +from openspace.utils.logging import Logger +from openspace.utils.ui import create_ui, OpenSpaceUI +from openspace.utils.ui_integration import UIIntegration +from openspace.utils.cli_display import CLIDisplay +from openspace.utils.display import colorize + +logger = Logger.get_logger(__name__) + + +class UIManager: + def __init__(self, ui: Optional[OpenSpaceUI], ui_integration: Optional[UIIntegration]): + self.ui = ui + self.ui_integration = ui_integration + self._original_log_levels = {} + + async def start_live_display(self): + if not self.ui or not self.ui_integration: + return + + print() + print(colorize(" ▣ Starting real-time visualization...", 'c')) + print() + await asyncio.sleep(1) + + self._suppress_logs() + + await self.ui.start_live_display() + await self.ui_integration.start_monitoring(poll_interval=2.0) + + async def stop_live_display(self): + if not self.ui or not self.ui_integration: + return + + await self.ui_integration.stop_monitoring() + await self.ui.stop_live_display() + + self._restore_logs() + + def print_summary(self, result: dict): + if self.ui: + self.ui.print_summary(result) + else: + CLIDisplay.print_result_summary(result) + + def _suppress_logs(self): + log_names = ["openspace", "openspace.grounding", "openspace.agents"] + for name in log_names: + log = logging.getLogger(name) + self._original_log_levels[name] = log.level + log.setLevel(logging.CRITICAL) + + def _restore_logs(self): + for name, level in self._original_log_levels.items(): + logging.getLogger(name).setLevel(level) + self._original_log_levels.clear() + + +async def _execute_task(openspace: OpenSpace, query: str, ui_manager: UIManager): + await ui_manager.start_live_display() + result = await openspace.execute(query) + await ui_manager.stop_live_display() + ui_manager.print_summary(result) + return result + + +async def interactive_mode(openspace: OpenSpace, ui_manager: UIManager): + CLIDisplay.print_interactive_header() + + while True: + try: + prompt = colorize(">>> ", 'c', bold=True) + query = input(f"\n{prompt}").strip() + + if not query: + continue + + if query.lower() in ['exit', 'quit', 'q']: + print("\nExiting...") + break + + if query.lower() == 'status': + _print_status(openspace) + continue + + if query.lower() == 'help': + CLIDisplay.print_help() + continue + + CLIDisplay.print_task_header(query) + await _execute_task(openspace, query, ui_manager) + + except KeyboardInterrupt: + print("\n\nInterrupt signal detected, exiting...") + break + except Exception as e: + logger.error(f"Error: {e}", exc_info=True) + print(f"\nError: {e}") + + +async def single_query_mode(openspace: OpenSpace, query: str, ui_manager: UIManager): + CLIDisplay.print_task_header(query, title="▶ Single Query Execution") + await _execute_task(openspace, query, ui_manager) + + +def _print_status(openspace: OpenSpace): + """Print system status""" + from openspace.utils.display import Box, BoxStyle + + box = Box(width=70, style=BoxStyle.ROUNDED, color='bl') + print() + print(box.text_line(colorize("System Status", 'bl', bold=True), + align='center', indent=4, text_color='')) + print(box.separator_line(indent=4)) + + status_lines = [ + f"Initialized: {colorize('Yes' if openspace.is_initialized() else 'No', 'g' if openspace.is_initialized() else 'rd')}", + f"Running: {colorize('Yes' if openspace.is_running() else 'No', 'y' if openspace.is_running() else 'g')}", + f"Model: {colorize(openspace.config.llm_model, 'c')}", + ] + + if openspace.is_initialized(): + backends = openspace.list_backends() + status_lines.append(f"Backends: {colorize(', '.join(backends), 'c')}") + + sessions = openspace.list_sessions() + status_lines.append(f"Active Sessions: {colorize(str(len(sessions)), 'y')}") + + for line in status_lines: + print(box.text_line(f" {line}", indent=4, text_color='')) + + print(box.bottom_line(indent=4)) + print() + + +def _create_argument_parser() -> argparse.ArgumentParser: + """Create command-line argument parser""" + parser = argparse.ArgumentParser( + description='OpenSpace - Self-Evolving Skill Worker & Community', + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + # Subcommands + subparsers = parser.add_subparsers(dest='command', help='Available commands') + + # refresh-cache subcommand + cache_parser = subparsers.add_parser( + 'refresh-cache', + help='Refresh MCP tool cache (starts all servers once)' + ) + cache_parser.add_argument( + '--config', '-c', type=str, + help='MCP configuration file path' + ) + + # Basic arguments (for run mode) + parser.add_argument('--config', '-c', type=str, help='Configuration file path (JSON format)') + parser.add_argument('--query', '-q', type=str, help='Single query mode: execute query directly') + + # LLM arguments + parser.add_argument('--model', '-m', type=str, help='LLM model name') + + # Logging arguments + parser.add_argument('--log-level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], help='Log level') + + # Execution arguments + parser.add_argument('--max-iterations', type=int, help='Maximum iteration count') + parser.add_argument('--timeout', type=float, help='LLM API call timeout (seconds)') + + # UI arguments + parser.add_argument('--interactive', '-i', action='store_true', help='Force interactive mode') + parser.add_argument('--no-ui', action='store_true', help='Disable visualization UI') + parser.add_argument('--ui-compact', action='store_true', help='Use compact UI layout') + + return parser + + +async def refresh_mcp_cache(config_path: Optional[str] = None): + """Refresh MCP tool cache by starting servers one by one and saving tool metadata.""" + from openspace.grounding.backends.mcp import MCPProvider, get_tool_cache + from openspace.grounding.core.types import SessionConfig, BackendType + from openspace.config import load_config, get_config + + print("Refreshing MCP tool cache...") + print("Servers will be started one by one (start -> get tools -> close).") + print() + + # Load config + if config_path: + config = load_config(config_path) + else: + config = get_config() + + # Get MCP config + mcp_config = getattr(config, 'mcp', None) or {} + if hasattr(mcp_config, 'model_dump'): + mcp_config = mcp_config.model_dump() + + # Skip dependency checks for refresh-cache (servers are pre-validated) + mcp_config["check_dependencies"] = False + + # Create provider + provider = MCPProvider(config=mcp_config) + await provider.initialize() + + servers = provider.list_servers() + total = len(servers) + print(f"Found {total} MCP servers configured") + print() + + cache = get_tool_cache() + cache.set_server_order(servers) # Preserve config order when saving + total_tools = 0 + success_count = 0 + skipped_count = 0 + failed_servers = [] + + # Load existing cache to skip already processed servers + existing_cache = cache.get_all_tools() + + # Timeout for each server (in seconds) + SERVER_TIMEOUT = 60 + + # Process servers one by one + for i, server_name in enumerate(servers, 1): + # Skip if already cached (resume support) + if server_name in existing_cache: + cached_tools = existing_cache[server_name] + total_tools += len(cached_tools) + skipped_count += 1 + print(f"[{i}/{total}] {server_name}... ⏭ cached ({len(cached_tools)} tools)") + continue + + print(f"[{i}/{total}] {server_name}...", end=" ", flush=True) + session_id = f"mcp-{server_name}" + + try: + # Create session and get tools with timeout protection + async with asyncio.timeout(SERVER_TIMEOUT): + # Create session for this server + cfg = SessionConfig( + session_name=session_id, + backend_type=BackendType.MCP, + connection_params={"server": server_name}, + ) + session = await provider.create_session(cfg) + + # Get tools from this server + tools = await session.list_tools() + + # Convert to metadata format + tool_metadata = [] + for tool in tools: + tool_metadata.append({ + "name": tool.schema.name, + "description": tool.schema.description or "", + "parameters": tool.schema.parameters or {}, + }) + + # Save to cache (incremental) + cache.save_server(server_name, tool_metadata) + + # Close session immediately to free resources + await provider.close_session(session_id) + + total_tools += len(tools) + success_count += 1 + print(f"✓ {len(tools)} tools") + + except asyncio.TimeoutError: + error_msg = f"Timeout after {SERVER_TIMEOUT}s" + failed_servers.append((server_name, error_msg)) + print(f"✗ {error_msg}") + + # Save failed server info to cache + cache.save_failed_server(server_name, error_msg) + + # Try to close session if it was created + try: + await provider.close_session(session_id) + except Exception: + pass + + except Exception as e: + error_msg = str(e) + failed_servers.append((server_name, error_msg)) + print(f"✗ {error_msg[:50]}") + + # Save failed server info to cache + cache.save_failed_server(server_name, error_msg) + + # Try to close session if it was created + try: + await provider.close_session(session_id) + except Exception: + pass + + print() + print(f"{'='*50}") + print(f"✓ Collected {total_tools} tools from {success_count + skipped_count}/{total} servers") + if skipped_count > 0: + print(f" (skipped {skipped_count} cached, processed {success_count} new)") + print(f"✓ Cache saved to: {cache.cache_path}") + + if failed_servers: + print(f"✗ Failed servers ({len(failed_servers)}):") + for name, err in failed_servers[:10]: + print(f" - {name}: {err[:60]}") + if len(failed_servers) > 10: + print(f" ... and {len(failed_servers) - 10} more (see cache file for details)") + + print() + print("Done! Future list_tools() calls will use cache (no server startup).") + + +def _load_config(args) -> OpenSpaceConfig: + """Load configuration""" + cli_overrides = {} + if args.model: + cli_overrides['llm_model'] = args.model + if args.max_iterations is not None: + cli_overrides['grounding_max_iterations'] = args.max_iterations + if args.timeout is not None: + cli_overrides['llm_timeout'] = args.timeout + if args.log_level: + cli_overrides['log_level'] = args.log_level + + try: + # Load from config file if provided + if args.config: + import json + with open(args.config, 'r', encoding='utf-8') as f: + config_dict = json.load(f) + + # Apply CLI overrides + config_dict.update(cli_overrides) + config = OpenSpaceConfig(**config_dict) + + print(f"✓ Loaded from config file: {args.config}") + else: + # Use default config + CLI overrides + config = OpenSpaceConfig(**cli_overrides) + print("✓ Using default configuration") + + if cli_overrides: + print(f"✓ CLI overrides: {', '.join(cli_overrides.keys())}") + + if args.log_level: + Logger.set_level(args.log_level) + + return config + + except Exception as e: + logger.error(f"Failed to load configuration: {e}") + sys.exit(1) + + +def _setup_ui(args) -> tuple[Optional[OpenSpaceUI], Optional[UIIntegration]]: + if args.no_ui: + CLIDisplay.print_banner() + return None, None + + ui = create_ui(enable_live=True, compact=args.ui_compact) + ui.print_banner() + ui_integration = UIIntegration(ui) + return ui, ui_integration + + +async def _initialize_openspace(config: OpenSpaceConfig, args) -> OpenSpace: + openspace = OpenSpace(config) + + init_steps = [("Initializing OpenSpace...", "loading")] + CLIDisplay.print_initialization_progress(init_steps, show_header=False) + + if not args.config: + original_log_level = Logger.get_logger("openspace").level + for log_name in ["openspace", "openspace.grounding", "openspace.agents"]: + Logger.get_logger(log_name).setLevel(logging.WARNING) + + await openspace.initialize() + + # Restore log level + if not args.config: + for log_name in ["openspace", "openspace.grounding", "openspace.agents"]: + Logger.get_logger(log_name).setLevel(original_log_level) + + # Print initialization results + backends = openspace.list_backends() + init_steps = [ + ("LLM Client", "ok"), + (f"Grounding Backends ({len(backends)} available)", "ok"), + ("Grounding Agent", "ok"), + ] + + if config.enable_recording: + init_steps.append(("Recording Manager", "ok")) + + CLIDisplay.print_initialization_progress(init_steps, show_header=True) + + return openspace + + +async def main(): + parser = _create_argument_parser() + args = parser.parse_args() + + # Handle subcommands + if args.command == 'refresh-cache': + await refresh_mcp_cache(args.config) + return 0 + + # Load configuration + config = _load_config(args) + + # Setup UI + ui, ui_integration = _setup_ui(args) + + # Print configuration + CLIDisplay.print_configuration(config) + + openspace = None + + try: + # Initialize OpenSpace + openspace = await _initialize_openspace(config, args) + + # Connect UI (if enabled) + if ui_integration: + ui_integration.attach_llm_client(openspace._llm_client) + ui_integration.attach_grounding_client(openspace._grounding_client) + CLIDisplay.print_system_ready() + + ui_manager = UIManager(ui, ui_integration) + + # Run appropriate mode + if args.query: + await single_query_mode(openspace, args.query, ui_manager) + else: + await interactive_mode(openspace, ui_manager) + + except KeyboardInterrupt: + print("\n\nInterrupt signal detected") + except Exception as e: + logger.error(f"Error: {e}", exc_info=True) + print(f"\nError: {e}") + return 1 + finally: + if openspace: + print("\nCleaning up resources...") + await openspace.cleanup() + + print("\nGoodbye!") + return 0 + + +def run_main(): + """Run main function""" + try: + exit_code = asyncio.run(main()) + sys.exit(exit_code) + except KeyboardInterrupt: + print("\n\nProgram interrupted") + sys.exit(0) + + +if __name__ == "__main__": + run_main() \ No newline at end of file diff --git a/openspace/agents/__init__.py b/openspace/agents/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6710acf672c6d5a0d425c513c0e9bb563caa9378 --- /dev/null +++ b/openspace/agents/__init__.py @@ -0,0 +1,9 @@ +from openspace.agents.base import BaseAgent, AgentStatus, AgentRegistry +from openspace.agents.grounding_agent import GroundingAgent + +__all__ = [ + "BaseAgent", + "AgentStatus", + "AgentRegistry", + "GroundingAgent", +] \ No newline at end of file diff --git a/openspace/agents/base.py b/openspace/agents/base.py new file mode 100644 index 0000000000000000000000000000000000000000..024783617816690b1d0dbc7b676bca23e4264598 --- /dev/null +++ b/openspace/agents/base.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Dict, List, Optional, Type, Any + +from openspace.utils.logging import Logger + +if TYPE_CHECKING: + from openspace.llm import LLMClient + from openspace.grounding.core.grounding_client import GroundingClient + from openspace.recording import RecordingManager + +logger = Logger.get_logger(__name__) + + +class BaseAgent(ABC): + def __init__( + self, + name: str, + backend_scope: Optional[List[str]] = None, + llm_client: Optional[LLMClient] = None, + grounding_client: Optional[GroundingClient] = None, + recording_manager: Optional[RecordingManager] = None, + ) -> None: + """ + Initialize the BaseAgent. + + Args: + name: Unique name for the agent + backend_scope: List of backend types this agent can access (e.g., ["gui", "shell", "mcp", "web", "system"]) + llm_client: LLM client for agent reasoning (optional, can be set later) + grounding_client: Reference to GroundingClient for tool execution + recording_manager: RecordingManager for recording execution + """ + self._name = name + self._grounding_client: Optional[GroundingClient] = grounding_client + self._backend_scope = backend_scope or [] + self._llm_client = llm_client + self._recording_manager: Optional[RecordingManager] = recording_manager + self._step = 0 + self._status = AgentStatus.ACTIVE + + self._register_self() + logger.info(f"Initialized {self.__class__.__name__}: {name}") + + @property + def name(self) -> str: + return self._name + + @property + def grounding_client(self) -> Optional[GroundingClient]: + """Get the grounding client.""" + return self._grounding_client + + @property + def backend_scope(self) -> List[str]: + return self._backend_scope + + @property + def llm_client(self) -> Optional[LLMClient]: + return self._llm_client + + @llm_client.setter + def llm_client(self, client: LLMClient) -> None: + self._llm_client = client + + @property + def recording_manager(self) -> Optional[RecordingManager]: + """Get the recording manager.""" + return self._recording_manager + + @property + def step(self) -> int: + return self._step + + @property + def status(self) -> str: + return self._status + + @abstractmethod + async def process(self, context: Dict[str, Any]) -> Dict[str, Any]: + pass + + @abstractmethod + def construct_messages(self, context: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Construct messages for LLM reasoning. + Context must contain 'instruction' key. + """ + pass + + async def get_llm_response( + self, + messages: List[Dict[str, Any]], + tools: Optional[List] = None, + **kwargs + ) -> Dict[str, Any]: + if not self._llm_client: + raise ValueError(f"LLM client not initialized for agent {self.name}") + + try: + response = await self._llm_client.complete( + messages=messages, + tools=tools, + **kwargs + ) + return response + except Exception as e: + logger.error(f"{self.name}: LLM call failed: {e}", exc_info=True) + raise + + def response_to_dict(self, response: str) -> Dict[str, Any]: + try: + if response.strip().startswith("```json") or response.strip().startswith("```"): + lines = response.strip().split('\n') + if lines and lines[0].startswith('```'): + lines = lines[1:] + end_idx = len(lines) + for i, line in enumerate(lines): + if line.strip() == '```': + end_idx = i + break + response = '\n'.join(lines[:end_idx]) + + return json.loads(response) + except json.JSONDecodeError as e: + # If parsing fails, try to find and extract just the JSON object/array + if "Extra data" in str(e): + try: + decoder = json.JSONDecoder() + obj, idx = decoder.raw_decode(response) + logger.warning( + f"{self.name}: Successfully extracted JSON but found extra text after position {idx}. " + f"Extra text: {response[idx:idx+100]}..." + ) + return obj + except Exception as e2: + logger.error(f"{self.name}: Failed to extract JSON even with raw_decode: {e2}") + + logger.error(f"{self.name}: Failed to parse response: {e}") + logger.error(f"{self.name}: Response content: {response[:500]}") + return {"error": "Failed to parse response", "raw": response} + + def increment_step(self) -> None: + self._step += 1 + + @classmethod + def _register_self(cls) -> None: + """Register the agent class in the registry upon instantiation.""" + # Get the actual instance class, not BaseAgent + if cls.__name__ != "BaseAgent" and cls.__name__ not in AgentRegistry._registry: + AgentRegistry.register(cls.__name__, cls) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(name={self.name}, step={self.step}, status={self.status})>" + + +class AgentStatus: + """Constants for agent status.""" + ACTIVE = "active" + IDLE = "idle" + WAITING = "waiting" + + +class AgentRegistry: + """ + Registry for managing agent classes. + Allows dynamic registration and retrieval of agent types. + """ + + _registry: Dict[str, Type[BaseAgent]] = {} + + @classmethod + def register(cls, name: str, agent_cls: Type[BaseAgent]) -> None: + if name in cls._registry: + logger.warning(f"Agent class '{name}' already registered, overwriting") + cls._registry[name] = agent_cls + logger.debug(f"Registered agent class: {name}") + + @classmethod + def get_cls(cls, name: str) -> Type[BaseAgent]: + if name not in cls._registry: + raise ValueError(f"No agent class registered under '{name}'") + return cls._registry[name] + + @classmethod + def list_registered(cls) -> List[str]: + return list(cls._registry.keys()) + + @classmethod + def clear(cls) -> None: + cls._registry.clear() + logger.debug("Agent registry cleared") \ No newline at end of file diff --git a/openspace/agents/grounding_agent.py b/openspace/agents/grounding_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..4e36cba89fa1e09b10800327f66dcd787d9a49e2 --- /dev/null +++ b/openspace/agents/grounding_agent.py @@ -0,0 +1,1212 @@ +from __future__ import annotations + +import copy +import json +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from openspace.agents.base import BaseAgent +from openspace.grounding.core.types import BackendType, ToolResult +from openspace.platforms.screenshot import ScreenshotClient +from openspace.prompts import GroundingAgentPrompts +from openspace.utils.logging import Logger + +if TYPE_CHECKING: + from openspace.llm import LLMClient + from openspace.grounding.core.grounding_client import GroundingClient + from openspace.recording import RecordingManager + from openspace.skill_engine import SkillRegistry + +logger = Logger.get_logger(__name__) + + +class GroundingAgent(BaseAgent): + def __init__( + self, + name: str = "GroundingAgent", + backend_scope: Optional[List[str]] = None, + llm_client: Optional[LLMClient] = None, + grounding_client: Optional[GroundingClient] = None, + recording_manager: Optional[RecordingManager] = None, + system_prompt: Optional[str] = None, + max_iterations: int = 15, + visual_analysis_timeout: float = 30.0, + tool_retrieval_llm: Optional[LLMClient] = None, + visual_analysis_model: Optional[str] = None, + ) -> None: + """ + Initialize the Grounding Agent. + + Args: + name: Agent name + backend_scope: List of backends this agent can access (None = all available) + llm_client: LLM client for reasoning + grounding_client: GroundingClient for tool execution + recording_manager: RecordingManager for recording execution + system_prompt: Custom system prompt + max_iterations: Maximum LLM reasoning iterations for self-correction + visual_analysis_timeout: Timeout for visual analysis LLM calls in seconds + tool_retrieval_llm: LLM client for tool retrieval filter (None = use llm_client) + visual_analysis_model: Model name for visual analysis (None = use llm_client.model) + """ + super().__init__( + name=name, + backend_scope=backend_scope or ["gui", "shell", "mcp", "web", "system"], + llm_client=llm_client, + grounding_client=grounding_client, + recording_manager=recording_manager + ) + + self._system_prompt = system_prompt or self._default_system_prompt() + self._max_iterations = max_iterations + self._visual_analysis_timeout = visual_analysis_timeout + self._tool_retrieval_llm = tool_retrieval_llm + self._visual_analysis_model = visual_analysis_model + + # Skill context injection (set externally before process()) + self._skill_context: Optional[str] = None + self._active_skill_ids: List[str] = [] + + # Skill registry for mid-iteration retrieve_skill tool + self._skill_registry: Optional["SkillRegistry"] = None + + # Tools from the last execution (available for post-execution analysis) + self._last_tools: List = [] + + logger.info(f"Grounding Agent initialized: {name}") + logger.info(f"Backend scope: {self._backend_scope}") + logger.info(f"Max iterations: {self._max_iterations}") + logger.info(f"Visual analysis timeout: {self._visual_analysis_timeout}s") + if tool_retrieval_llm: + logger.info(f"Tool retrieval model: {tool_retrieval_llm.model}") + if visual_analysis_model: + logger.info(f"Visual analysis model: {visual_analysis_model}") + + def set_skill_context( + self, + context: str, + skill_ids: Optional[List[str]] = None, + ) -> None: + """Inject skill guidance into the agent's system prompt. + + Called by ``OpenSpace.execute()`` before ``process()`` when skills + are matched. The context is a formatted string built by + ``SkillRegistry.build_context_injection()``. + + Args: + context: Formatted skill content for system prompt injection. + skill_ids: skill_id values of injected skills. + """ + self._skill_context = context if context else None + self._active_skill_ids = skill_ids or [] + if self._skill_context: + logger.info(f"Skill context set: {', '.join(self._active_skill_ids) or '(unnamed)'}") + + def clear_skill_context(self) -> None: + """Remove skill guidance (used before fallback execution).""" + if self._skill_context: + logger.info(f"Skill context cleared (was: {', '.join(self._active_skill_ids)})") + self._skill_context = None + self._active_skill_ids = [] + + @property + def has_skill_context(self) -> bool: + return self._skill_context is not None + + def set_skill_registry(self, registry: Optional["SkillRegistry"]) -> None: + """Attach a SkillRegistry so the agent can offer ``retrieve_skill`` as a tool.""" + self._skill_registry = registry + if registry: + count = len(registry.list_skills()) + logger.info(f"Skill registry attached ({count} skill(s) available for mid-iteration retrieval)") + + _MAX_SINGLE_CONTENT_CHARS = 30_000 + + @classmethod + def _cap_message_content(cls, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Truncate oversized individual message contents in-place. + + Targets tool-result messages and assistant messages that can + carry enormous file contents (read_file on large CSVs/scripts). + System messages and the first user instruction are never touched. + """ + cap = cls._MAX_SINGLE_CONTENT_CHARS + trimmed = 0 + for msg in messages: + content = msg.get("content") + if not isinstance(content, str) or len(content) <= cap: + continue + if msg.get("role") == "system": + continue + original_len = len(content) + msg["content"] = ( + content[: cap // 2] + + f"\n\n... [truncated {original_len - cap:,} chars] ...\n\n" + + content[-(cap // 2):] + ) + trimmed += 1 + if trimmed: + logger.info(f"Capped {trimmed} oversized message(s) to {cap:,} chars each") + return messages + + def _truncate_messages( + self, + messages: List[Dict[str, Any]], + keep_recent: int = 8, + max_tokens_estimate: int = 120000 + ) -> List[Dict[str, Any]]: + # First: cap any single oversized message to prevent one huge + # tool-result from dominating the context window. + messages = self._cap_message_content(messages) + + if len(messages) <= keep_recent + 2: # +2 for system and initial user + return messages + + total_text = json.dumps(messages, ensure_ascii=False) + estimated_tokens = len(total_text) // 4 + + if estimated_tokens < max_tokens_estimate: + return messages + + logger.info(f"Truncating message history: {len(messages)} messages, " + f"~{estimated_tokens:,} tokens -> keeping recent {keep_recent} rounds") + + system_messages = [] + user_instruction = None + conversation_messages = [] + + for msg in messages: + role = msg.get("role") + if role == "system": + system_messages.append(msg) + elif role == "user" and user_instruction is None: + user_instruction = msg + else: + conversation_messages.append(msg) + + recent_messages = conversation_messages[-(keep_recent * 2):] if conversation_messages else [] + + truncated = system_messages.copy() + if user_instruction: + truncated.append(user_instruction) + truncated.extend(recent_messages) + + logger.info(f"After truncation: {len(truncated)} messages, " + f"~{len(json.dumps(truncated, ensure_ascii=False))//4:,} tokens (estimated)") + + return truncated + + async def process(self, context: Dict[str, Any]) -> Dict[str, Any]: + """ + Process a task execution request with multi-round iteration control. + """ + instruction = context.get("instruction", "") + if not instruction: + logger.error("Grounding Agent: No instruction provided") + return {"error": "No instruction provided", "status": "error"} + + # Store current instruction for visual analysis context + self._current_instruction = instruction + + logger.info(f"Grounding Agent: Processing instruction at step {self.step}") + + # Exist workspace files check + workspace_info = await self._check_workspace_artifacts(context) + if workspace_info["has_files"]: + context["workspace_artifacts"] = workspace_info + logger.info(f"Workspace has {len(workspace_info['files'])} existing files: {workspace_info['files']}") + + # Get available tools (auto-search with cap) + tools = await self._get_available_tools(instruction) + self._last_tools = tools # expose for post-execution analysis + + # Get search debug info (similarity scores, LLM selections) + search_debug_info = None + if self.grounding_client: + search_debug_info = self.grounding_client.get_last_search_debug_info() + + # Build retrieved tools list for return value + retrieved_tools_list = [] + for tool in tools: + tool_info = { + "name": getattr(tool, "name", str(tool)), + "description": getattr(tool, "description", ""), + } + # Prefer runtime_info.backend + # over backend_type (may be NOT_SET for cached RemoteTools) + runtime_info = getattr(tool, "_runtime_info", None) + if runtime_info and hasattr(runtime_info, "backend"): + tool_info["backend"] = runtime_info.backend.value if hasattr(runtime_info.backend, "value") else str(runtime_info.backend) + tool_info["server_name"] = runtime_info.server_name + elif hasattr(tool, "backend_type"): + tool_info["backend"] = tool.backend_type.value if hasattr(tool.backend_type, "value") else str(tool.backend_type) + + # Add similarity score if available + if search_debug_info and search_debug_info.get("tool_scores"): + for score_info in search_debug_info["tool_scores"]: + if score_info["name"] == tool_info["name"]: + tool_info["similarity_score"] = score_info["score"] + break + + retrieved_tools_list.append(tool_info) + + # Record retrieved tools + if self._recording_manager: + from openspace.recording import RecordingManager + await RecordingManager.record_retrieved_tools( + task_instruction=instruction, + tools=tools, + search_debug_info=search_debug_info, + ) + + # Initialize iteration state + max_iterations = context.get("max_iterations", self._max_iterations) + current_iteration = 0 + all_tool_results = [] + iteration_contexts = [] + consecutive_empty_responses = 0 # Track consecutive empty LLM responses + MAX_CONSECUTIVE_EMPTY = 5 # Exit after this many empty responses + + # Build initial messages + messages = self.construct_messages(context) + + # Record initial conversation setup once (system prompts + user instruction + tool definitions) + from openspace.recording import RecordingManager + await RecordingManager.record_conversation_setup( + setup_messages=copy.deepcopy(messages), + tools=tools, + ) + + try: + while current_iteration < max_iterations: + current_iteration += 1 + logger.info(f"Grounding Agent: Iteration {current_iteration}/{max_iterations}") + + # Strip skill context after the first iteration to save prompt tokens. + # Skills only need to guide the first LLM call; subsequent iterations + # already have the plan and tool results in context. + if current_iteration == 2 and self._skill_context: + skill_ctx = self._skill_context + messages = [ + m for m in messages + if not (m.get("role") == "system" and m.get("content") == skill_ctx) + ] + logger.info("Skill context removed from messages after first iteration") + + # Cap oversized individual messages every iteration to prevent + # a single huge tool result from ballooning all subsequent calls. + if current_iteration >= 2: + messages = self._cap_message_content(messages) + + # Truncate message history to prevent context length issues + # Start truncating after 5 iterations to keep context manageable + if current_iteration >= 5: + messages = self._truncate_messages( + messages, + keep_recent=8, + max_tokens_estimate=120000 + ) + + messages_input_snapshot = copy.deepcopy(messages) + + # [DISABLED] Iteration summary generation + # Tool results (including visual analysis) are already in context, + # LLM can make decisions directly without separate summary. + # To re-enable, uncomment below and pass iteration_summary_prompt to complete() + # iteration_summary_prompt = GroundingAgentPrompts.iteration_summary( + # instruction=instruction, + # iteration=current_iteration, + # max_iterations=max_iterations + # ) if context.get("auto_execute", True) else None + + # Call LLMClient for single round + # LLM will decide whether to call tools or finish with + llm_response = await self._llm_client.complete( + messages=messages, + tools=tools if context.get("auto_execute", True) else None, + execute_tools=context.get("auto_execute", True), + summary_prompt=None, # Disabled + tool_result_callback=self._visual_analysis_callback + ) + + # Update messages with LLM response + messages = llm_response["messages"] + + # Collect tool results + tool_results_this_iteration = llm_response.get("tool_results", []) + if tool_results_this_iteration: + all_tool_results.extend(tool_results_this_iteration) + + # [DISABLED] Iteration summary logging + # llm_summary = llm_response.get("iteration_summary") + # if llm_summary: + # logger.info(f"Iteration {current_iteration} summary: {llm_summary[:150]}...") + + assistant_message = llm_response.get("message", {}) + assistant_content = assistant_message.get("content", "") + + has_tool_calls = llm_response.get('has_tool_calls', False) + logger.info(f"Iteration {current_iteration} - Has tool calls: {has_tool_calls}, " + f"Tool results: {len(tool_results_this_iteration)}, " + f"Content length: {len(assistant_content)} chars") + + if len(assistant_content) > 0: + logger.info(f"Iteration {current_iteration} - Assistant content preview: {repr(assistant_content[:300])}") + consecutive_empty_responses = 0 # Reset counter on valid response + else: + if not has_tool_calls: + consecutive_empty_responses += 1 + logger.warning(f"Iteration {current_iteration} - NO tool calls and NO content " + f"(empty response {consecutive_empty_responses}/{MAX_CONSECUTIVE_EMPTY})") + + if consecutive_empty_responses >= MAX_CONSECUTIVE_EMPTY: + logger.error(f"Exiting due to {MAX_CONSECUTIVE_EMPTY} consecutive empty LLM responses. " + "This may indicate API issues, rate limiting, or context too long.") + break + else: + consecutive_empty_responses = 0 # Reset if we have tool calls + + # Snapshot messages after LLM call (accumulated context) + messages_output_snapshot = copy.deepcopy(messages) + + # Delta messages: only the messages produced in this iteration + # (avoids repeating system prompts / initial user instruction each time) + delta_messages = messages[len(messages_input_snapshot):] + + # Response metadata (lightweight; full content lives in delta_messages) + response_metadata = { + "has_tool_calls": has_tool_calls, + "tool_calls_count": len(tool_results_this_iteration), + } + iteration_context = { + "iteration": current_iteration, + "messages_input": messages_input_snapshot, + "messages_output": messages_output_snapshot, + "response_metadata": response_metadata, + } + iteration_contexts.append(iteration_context) + + # Real-time save to conversations.jsonl (delta only, no redundancy) + await RecordingManager.record_iteration_context( + iteration=current_iteration, + delta_messages=copy.deepcopy(delta_messages), + response_metadata=response_metadata, + ) + + # Check for completion token in assistant content + # [DISABLED] Also check in iteration summary when enabled + # is_complete = ( + # GroundingAgentPrompts.TASK_COMPLETE in assistant_content or + # (llm_summary and GroundingAgentPrompts.TASK_COMPLETE in llm_summary) + # ) + is_complete = GroundingAgentPrompts.TASK_COMPLETE in assistant_content + + if is_complete: + # Task is complete - LLM generated completion token + logger.info(f"Task completed at iteration {current_iteration} (found {GroundingAgentPrompts.TASK_COMPLETE})") + break + + else: + # LLM didn't generate , continue to next iteration + if tool_results_this_iteration: + logger.debug(f"Task in progress, LLM called {len(tool_results_this_iteration)} tools") + else: + logger.debug(f"Task in progress, LLM did not generate ") + + # Remove previous iteration guidance to avoid accumulation + messages = [ + msg for msg in messages + if not (msg.get("role") == "system" and "Iteration" in msg.get("content", "") and "complete" in msg.get("content", "")) + ] + + guidance_msg = { + "role": "system", + "content": f"Iteration {current_iteration} complete. " + f"Check if task is finished - if yes, output {GroundingAgentPrompts.TASK_COMPLETE}. " + f"If not, continue with next action." + } + messages.append(guidance_msg) + + # [DISABLED] Full iteration feedback with summary + # self._remove_previous_guidance(messages) + # feedback_msg = self._build_iteration_feedback( + # iteration=current_iteration, + # llm_summary=llm_summary, + # add_guidance=True + # ) + # if feedback_msg: + # messages.append(feedback_msg) + # logger.debug(f"Added iteration {current_iteration} feedback with guidance") + + continue + + # Build final result + result = await self._build_final_result( + instruction=instruction, + messages=messages, + all_tool_results=all_tool_results, + iterations=current_iteration, + max_iterations=max_iterations, + iteration_contexts=iteration_contexts, + retrieved_tools_list=retrieved_tools_list, + search_debug_info=search_debug_info, + ) + + # Record agent action to recording manager + if self._recording_manager: + await self._record_agent_execution(result, instruction) + + # Increment step + self.increment_step() + + logger.info(f"Grounding Agent: Execution completed with status: {result.get('status')}") + return result + + except Exception as e: + logger.error(f"Grounding Agent: Execution failed: {e}") + result = { + "error": str(e), + "status": "error", + "instruction": instruction, + "iteration": current_iteration + } + self.increment_step() + return result + + def _default_system_prompt(self) -> str: + """Default system prompt tailored to the agent's actual backend scope.""" + return GroundingAgentPrompts.build_system_prompt(self._backend_scope) + + def construct_messages( + self, + context: Dict[str, Any] + ) -> List[Dict[str, Any]]: + messages = [{"role": "system", "content": self._system_prompt}] + + # Get instruction from context + instruction = context.get("instruction", "") + if not instruction: + raise ValueError("context must contain 'instruction' field") + + # Add workspace directory + workspace_dir = context.get("workspace_dir") + if workspace_dir: + messages.append({ + "role": "system", + "content": GroundingAgentPrompts.workspace_directory(workspace_dir) + }) + + # Add workspace artifacts information + workspace_artifacts = context.get("workspace_artifacts") + if workspace_artifacts and workspace_artifacts.get("has_files"): + files = workspace_artifacts.get("files", []) + matching_files = workspace_artifacts.get("matching_files", []) + recent_files = workspace_artifacts.get("recent_files", []) + + if matching_files: + artifact_msg = GroundingAgentPrompts.workspace_matching_files(matching_files) + elif len(recent_files) >= 2: + artifact_msg = GroundingAgentPrompts.workspace_recent_files( + total_files=len(files), + recent_files=recent_files + ) + else: + artifact_msg = GroundingAgentPrompts.workspace_file_list(files) + + messages.append({ + "role": "system", + "content": artifact_msg + }) + + # Skill injection — only active (selected) skills, full content + if self._skill_context: + messages.append({ + "role": "system", + "content": self._skill_context + }) + logger.info(f"Injected active skill context ({len(self._active_skill_ids)} skill(s))") + + # User instruction + messages.append({"role": "user", "content": instruction}) + + return messages + + async def _get_available_tools(self, task_description: Optional[str]) -> List: + """ + Retrieve tools for the current execution phase. + + Both skill-augmented and normal modes use the same + ``get_tools_with_auto_search`` pipeline: + - Non-MCP tools (shell, gui, web, system) are always included. + - MCP tools are filtered by relevance only when their count + exceeds ``max_tools``. + + When skills are active, the shell backend is guaranteed to be in + scope (skills commonly reference ``shell_agent``). + + Falls back to returning all tools if anything fails. + """ + grounding_client = self.grounding_client + if not grounding_client: + return [] + + backends = [BackendType(name) for name in self._backend_scope] + + # Ensure shell backend is available when skills are active + # (skills commonly reference shell_agent, read_file, etc.) + if self.has_skill_context: + shell_bt = BackendType.SHELL + if shell_bt not in backends: + backends = list(backends) + [shell_bt] + logger.info("Added Shell backend to scope for skill file I/O") + + try: + retrieval_llm = self._tool_retrieval_llm or self._llm_client + tools = await grounding_client.get_tools_with_auto_search( + task_description=task_description, + backend=backends, + use_cache=True, + llm_callable=retrieval_llm, + ) + logger.info( + f"GroundingAgent selected {len(tools)} tools (auto-search) " + f"from {len(backends)} backends" + + (f" [skill-augmented]" if self.has_skill_context else "") + ) + except Exception as e: + logger.warning(f"Auto-search tools failed, falling back to full list: {e}") + tools = await self._load_all_tools(grounding_client) + + # Append retrieve_skill tool when skill registry is available + if self._skill_registry and self._skill_registry.list_skills(): + from openspace.skill_engine.retrieve_tool import RetrieveSkillTool + retrieve_llm = self._tool_retrieval_llm or self._llm_client + retrieve_tool = RetrieveSkillTool( + self._skill_registry, + backends=[b.value for b in backends], + llm_client=retrieve_llm, + skill_store=getattr(self, "_skill_store", None), + ) + retrieve_tool.bind_runtime_info( + backend=BackendType.SYSTEM, + session_name="internal", + ) + tools.append(retrieve_tool) + logger.info("Added retrieve_skill tool for mid-iteration skill retrieval") + + return tools + + async def _load_all_tools(self, grounding_client: "GroundingClient") -> List: + """Fallback: load all tools from all backends without search.""" + all_tools = [] + for backend_name in self._backend_scope: + try: + backend_type = BackendType(backend_name) + tools = await grounding_client.list_tools(backend=backend_type) + all_tools.extend(tools) + logger.debug(f"Retrieved {len(tools)} tools from backend: {backend_name}") + except Exception as e: + logger.debug(f"Could not get tools from {backend_name}: {e}") + + logger.info( + f"GroundingAgent fallback retrieved {len(all_tools)} tools " + f"from {len(self._backend_scope)} backends" + ) + return all_tools + + async def _visual_analysis_callback( + self, + result: ToolResult, + tool_name: str, + tool_call: Dict, + backend: str + ) -> ToolResult: + """ + Callback for LLMClient to handle visual analysis after tool execution. + """ + # 1. Check if LLM requested to skip visual analysis + skip_visual_analysis = False + try: + arguments = tool_call.function.arguments + if isinstance(arguments, str): + args = json.loads(arguments.strip() or "{}") + else: + args = arguments + + if isinstance(args, dict) and args.get("skip_visual_analysis"): + skip_visual_analysis = True + logger.info(f"Visual analysis skipped for {tool_name} (meta-parameter set by LLM)") + except Exception as e: + logger.debug(f"Could not parse tool arguments: {e}") + + # 2. If skip requested, return original result + if skip_visual_analysis: + return result + + # 3. Check if this backend needs visual analysis + if backend != "gui": + return result + + # 4. Check if tool has visual data + metadata = getattr(result, 'metadata', None) + has_screenshots = metadata and (metadata.get("screenshot") or metadata.get("screenshots")) + + # 5. If no visual data, try to capture a screenshot + if not has_screenshots: + try: + logger.info(f"No visual data from {tool_name}, capturing screenshot...") + screenshot_client = ScreenshotClient() + screenshot_bytes = await screenshot_client.capture() + + if screenshot_bytes: + # Add screenshot to result metadata + if metadata is None: + result.metadata = {} + metadata = result.metadata + metadata["screenshot"] = screenshot_bytes + has_screenshots = True + logger.info(f"Screenshot captured for visual analysis") + else: + logger.warning("Failed to capture screenshot") + except Exception as e: + logger.warning(f"Error capturing screenshot: {e}") + + # 6. If still no screenshots, return original result + if not has_screenshots: + logger.debug(f"No visual data available for {tool_name}") + return result + + # 7. Perform visual analysis + return await self._enhance_result_with_visual_context(result, tool_name) + + async def _enhance_result_with_visual_context( + self, + result: ToolResult, + tool_name: str + ) -> ToolResult: + """ + Enhance tool result with visual analysis for grounding agent workflows. + """ + import asyncio + import base64 + import litellm + + try: + metadata = getattr(result, 'metadata', None) + if not metadata: + return result + + # Collect all screenshots + screenshots_bytes = [] + + # Check for multiple screenshots first + if metadata.get("screenshots"): + screenshots_list = metadata["screenshots"] + if isinstance(screenshots_list, list): + screenshots_bytes = [s for s in screenshots_list if s] + # Fall back to single screenshot + elif metadata.get("screenshot"): + screenshots_bytes = [metadata["screenshot"]] + + if not screenshots_bytes: + return result + + # Select key screenshots if there are too many + selected_screenshots = self._select_key_screenshots(screenshots_bytes, max_count=3) + + # Convert to base64 + visual_b64_list = [] + for visual_data in selected_screenshots: + if isinstance(visual_data, bytes): + visual_b64_list.append(base64.b64encode(visual_data).decode('utf-8')) + else: + visual_b64_list.append(visual_data) # Already base64 + + # Build prompt based on number of screenshots + num_screenshots = len(visual_b64_list) + + prompt = GroundingAgentPrompts.visual_analysis( + tool_name=tool_name, + num_screenshots=num_screenshots, + task_description=getattr(self, '_current_instruction', '') + ) + + # Build content with text prompt + all images + content = [{"type": "text", "text": prompt}] + for visual_b64 in visual_b64_list: + content.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{visual_b64}" + } + }) + + # Use dedicated visual analysis model if configured, otherwise use main LLM model + visual_model = self._visual_analysis_model or (self._llm_client.model if self._llm_client else "openrouter/anthropic/claude-sonnet-4.5") + response = await asyncio.wait_for( + litellm.acompletion( + model=visual_model, + messages=[{ + "role": "user", + "content": content + }], + timeout=self._visual_analysis_timeout + ), + timeout=self._visual_analysis_timeout + 5 + ) + + analysis = response.choices[0].message.content.strip() + + # Inject visual analysis into content + original_content = result.content or "(no text output)" + enhanced_content = f"{original_content}\n\n**Visual content**: {analysis}" + + # Create enhanced result + enhanced_result = ToolResult( + status=result.status, + content=enhanced_content, + error=result.error, + metadata={**metadata, "visual_analyzed": True, "visual_analysis": analysis}, + execution_time=result.execution_time + ) + + logger.info(f"Enhanced {tool_name} result with visual analysis ({num_screenshots} screenshot(s))") + return enhanced_result + + except asyncio.TimeoutError: + logger.warning(f"Visual analysis timed out for {tool_name}, returning original result") + return result + except Exception as e: + logger.warning(f"Failed to analyze visual content for {tool_name}: {e}") + return result + + def _select_key_screenshots( + self, + screenshots: List[bytes], + max_count: int = 3 + ) -> List[bytes]: + """ + Select key screenshots if there are too many. + """ + if len(screenshots) <= max_count: + return screenshots + + selected_indices = set() + + # Always include last (final state) + selected_indices.add(len(screenshots) - 1) + + # If room, include first (initial state) + if max_count >= 2: + selected_indices.add(0) + + # Fill remaining slots with evenly spaced middle screenshots + remaining_slots = max_count - len(selected_indices) + if remaining_slots > 0: + # Calculate spacing + available_indices = [ + i for i in range(1, len(screenshots) - 1) + if i not in selected_indices + ] + + if available_indices: + step = max(1, len(available_indices) // (remaining_slots + 1)) + for i in range(remaining_slots): + idx = min((i + 1) * step, len(available_indices) - 1) + if idx < len(available_indices): + selected_indices.add(available_indices[idx]) + + # Return screenshots in original order + selected = [screenshots[i] for i in sorted(selected_indices)] + + logger.debug( + f"Selected {len(selected)} screenshots at indices {sorted(selected_indices)} " + f"from total of {len(screenshots)}" + ) + + return selected + + def _get_workspace_path(self, context: Dict[str, Any]) -> Optional[str]: + """ + Get workspace directory path from context. + """ + return context.get("workspace_dir") + + def _scan_workspace_files( + self, + workspace_path: str, + recent_threshold: int = 600 # seconds + ) -> Dict[str, Any]: + """ + Scan workspace directory and collect file information. + + Args: + workspace_path: Path to workspace directory + recent_threshold: Threshold in seconds for recent files + + Returns: + Dictionary with file information: + - files: List of all filenames + - file_details: Dict mapping filename to file info (size, modified, age_seconds) + - recent_files: List of recently modified filenames + """ + import os + import time + + result = { + "files": [], + "file_details": {}, + "recent_files": [] + } + + if not workspace_path or not os.path.exists(workspace_path): + return result + + # Recording system files to exclude from workspace scanning + excluded_files = {"metadata.json", "traj.jsonl"} + + try: + current_time = time.time() + + for filename in os.listdir(workspace_path): + filepath = os.path.join(workspace_path, filename) + if os.path.isfile(filepath) and filename not in excluded_files: + result["files"].append(filename) + + # Get file stats + stat = os.stat(filepath) + file_info = { + "size": stat.st_size, + "modified": stat.st_mtime, + "age_seconds": current_time - stat.st_mtime + } + result["file_details"][filename] = file_info + + # Track recently created/modified files + if file_info["age_seconds"] < recent_threshold: + result["recent_files"].append(filename) + + result["files"] = sorted(result["files"]) + + except Exception as e: + logger.debug(f"Error scanning workspace files: {e}") + + return result + + async def _check_workspace_artifacts(self, context: Dict[str, Any]) -> Dict[str, Any]: + """ + Check workspace directory for existing artifacts that might be relevant to the task. + Enhanced to detect if task might already be completed. + """ + import re + + workspace_info = {"has_files": False, "files": [], "file_details": {}, "recent_files": []} + + try: + # Get workspace path + workspace_path = self._get_workspace_path(context) + + # Scan workspace files + scan_result = self._scan_workspace_files(workspace_path, recent_threshold=600) + + if scan_result["files"]: + workspace_info["has_files"] = True + workspace_info["files"] = scan_result["files"] + workspace_info["file_details"] = scan_result["file_details"] + workspace_info["recent_files"] = scan_result["recent_files"] + + logger.info(f"Grounding Agent: Found {len(scan_result['files'])} existing files in workspace " + f"({len(scan_result['recent_files'])} recent)") + + # Check if instruction mentions specific filenames + instruction = context.get("instruction", "") + if instruction: + # Look for potential file references in instruction + potential_outputs = [] + # Match common file patterns: filename.ext, "filename", 'filename' + file_patterns = re.findall(r'["\']?([a-zA-Z0-9_\-]+\.[a-zA-Z0-9]+)["\']?', instruction) + for pattern in file_patterns: + if pattern in scan_result["files"]: + potential_outputs.append(pattern) + + if potential_outputs: + workspace_info["matching_files"] = potential_outputs + logger.info(f"Grounding Agent: Found {len(potential_outputs)} files matching task: {potential_outputs}") + + except Exception as e: + logger.debug(f"Could not check workspace artifacts: {e}") + + return workspace_info + + def _build_iteration_feedback( + self, + iteration: int, + llm_summary: Optional[str] = None, + add_guidance: bool = True + ) -> Optional[Dict[str, str]]: + """ + Build feedback message to add to next iteration. + """ + if not llm_summary: + return None + + feedback_content = GroundingAgentPrompts.iteration_feedback( + iteration=iteration, + llm_summary=llm_summary, + add_guidance=add_guidance + ) + + return { + "role": "system", + "content": feedback_content + } + + def _remove_previous_guidance(self, messages: List[Dict[str, Any]]) -> None: + """ + Remove guidance section from previous iteration feedback messages. + """ + for msg in messages: + if msg.get("role") == "system": + content = msg.get("content", "") + # Check if this is an iteration feedback message with guidance + if "## Iteration" in content and "Summary" in content and "---" in content: + # Remove everything from "---" onwards (the guidance part) + summary_only = content.split("---")[0].strip() + msg["content"] = summary_only + + async def _generate_final_summary( + self, + instruction: str, + messages: List[Dict], + iterations: int + ) -> tuple[str, bool, List[Dict]]: + """ + Generate final summary across all iterations for reporting to upper layer. + + Returns: + tuple[str, bool, List[Dict]]: (summary_text, success_flag, context_used) + - summary_text: The generated summary or error message + - success_flag: True if summary was generated successfully, False otherwise + - context_used: The cleaned messages used for generating summary + """ + final_summary_prompt = { + "role": "user", + "content": GroundingAgentPrompts.final_summary( + instruction=instruction, + iterations=iterations + ) + } + + clean_messages = [] + for msg in messages: + # Skip tool result messages + if msg.get("role") == "tool": + continue + # Copy message and remove tool_calls if present + clean_msg = msg.copy() + if "tool_calls" in clean_msg: + del clean_msg["tool_calls"] + clean_messages.append(clean_msg) + + clean_messages.append(final_summary_prompt) + + # Save context for return + context_for_return = copy.deepcopy(clean_messages) + + try: + # Call LLMClient to generate final summary (without tools) + summary_response = await self._llm_client.complete( + messages=clean_messages, + tools=None, + execute_tools=False + ) + + final_summary = summary_response.get("message", {}).get("content", "") + + if final_summary: + logger.info(f"Generated final summary: {final_summary[:200]}...") + return final_summary, True, context_for_return + else: + logger.warning("LLM returned empty final summary") + return f"Task completed after {iterations} iteration(s). Check execution history for details.", True, context_for_return + + except Exception as e: + logger.error(f"Error generating final summary: {e}") + return f"Task completed after {iterations} iteration(s), but failed to generate summary: {str(e)}", False, context_for_return + + + async def _build_final_result( + self, + instruction: str, + messages: List[Dict], + all_tool_results: List[Dict], + iterations: int, + max_iterations: int, + iteration_contexts: List[Dict] = None, + retrieved_tools_list: List[Dict] = None, + search_debug_info: Dict[str, Any] = None, + ) -> Dict[str, Any]: + """ + Build final execution result. + + Args: + instruction: Original instruction + messages: Complete conversation history (including all iteration summaries) + all_tool_results: All tool execution results + iterations: Number of iterations performed + max_iterations: Maximum allowed iterations + iteration_contexts: Context snapshots for each iteration + retrieved_tools_list: List of tools retrieved for this task + search_debug_info: Debug info from tool search (similarity scores, LLM selections) + """ + is_complete = self._check_task_completion(messages) + + tool_executions = self._format_tool_executions(all_tool_results) + + result = { + "instruction": instruction, + "step": self.step, + "iterations": iterations, + "tool_executions": tool_executions, + "messages": messages, + "iteration_contexts": iteration_contexts or [], + "retrieved_tools_list": retrieved_tools_list or [], + "search_debug_info": search_debug_info, + "active_skills": list(self._active_skill_ids), + "keep_session": True + } + + if is_complete: + logger.info("Task completed with marker") + # Use LLM's own completion response directly (no extra LLM call needed) + # LLM already generates a summary before outputting + last_response = self._extract_last_assistant_message(messages) + # Remove the token from response for cleaner output + result["response"] = last_response.replace(GroundingAgentPrompts.TASK_COMPLETE, "").strip() + result["status"] = "success" + + # [DISABLED] Extra LLM call to generate final summary + # final_summary, summary_success, final_summary_context = await self._generate_final_summary( + # instruction=instruction, + # messages=messages, + # iterations=iterations + # ) + # result["response"] = final_summary + # result["final_summary_context"] = final_summary_context + else: + result["response"] = self._extract_last_assistant_message(messages) + result["status"] = "incomplete" + result["warning"] = ( + f"Task reached max iterations ({max_iterations}) without completion. " + f"This may indicate the task needs more steps or clarification." + ) + + return result + + def _format_tool_executions(self, all_tool_results: List[Dict]) -> List[Dict]: + executions = [] + for tr in all_tool_results: + tool_result_obj = tr.get("result") + tool_call = tr.get("tool_call") + + status = "unknown" + if hasattr(tool_result_obj, 'status'): + status_obj = tool_result_obj.status + status = getattr(status_obj, 'value', status_obj) + + # Extract tool_name and arguments from tool_call object (litellm format) + tool_name = "unknown" + arguments = {} + if tool_call is not None: + if hasattr(tool_call, 'function'): + # tool_call is an object with .function attribute + tool_name = getattr(tool_call.function, 'name', 'unknown') + args_raw = getattr(tool_call.function, 'arguments', '{}') + if isinstance(args_raw, str): + try: + arguments = json.loads(args_raw) if args_raw.strip() else {} + except json.JSONDecodeError: + arguments = {} + else: + arguments = args_raw if isinstance(args_raw, dict) else {} + elif isinstance(tool_call, dict): + # Fallback: tool_call is a dict + func = tool_call.get("function", {}) + tool_name = func.get("name", "unknown") + args_raw = func.get("arguments", "{}") + if isinstance(args_raw, str): + try: + arguments = json.loads(args_raw) if args_raw.strip() else {} + except json.JSONDecodeError: + arguments = {} + else: + arguments = args_raw if isinstance(args_raw, dict) else {} + + executions.append({ + "tool_name": tool_name, + "arguments": arguments, + "backend": tr.get("backend"), + "server_name": tr.get("server_name"), + "status": status, + "content": tool_result_obj.content if hasattr(tool_result_obj, 'content') else None, + "error": tool_result_obj.error if hasattr(tool_result_obj, 'error') else None, + "execution_time": tool_result_obj.execution_time if hasattr(tool_result_obj, 'execution_time') else None, + "metadata": tool_result_obj.metadata if hasattr(tool_result_obj, 'metadata') else {}, + }) + return executions + + def _check_task_completion(self, messages: List[Dict]) -> bool: + for msg in reversed(messages): + if msg.get("role") == "assistant": + content = msg.get("content", "") + return GroundingAgentPrompts.TASK_COMPLETE in content + return False + + def _extract_last_assistant_message(self, messages: List[Dict]) -> str: + for msg in reversed(messages): + if msg.get("role") == "assistant": + return msg.get("content", "") + return "" + + async def _record_agent_execution( + self, + result: Dict[str, Any], + instruction: str + ) -> None: + """ + Record agent execution to recording manager. + + Args: + result: Execution result + instruction: Original instruction + """ + if not self._recording_manager: + return + + # Extract tool execution summary + tool_summary = [] + if result.get("tool_executions"): + for exec_info in result["tool_executions"]: + tool_summary.append({ + "tool": exec_info.get("tool_name", "unknown"), + "backend": exec_info.get("backend", "unknown"), + "status": exec_info.get("status", "unknown"), + }) + + await self._recording_manager.record_agent_action( + agent_name=self.name, + action_type="execute", + input_data={"instruction": instruction}, + reasoning={ + "response": result.get("response", ""), + "tools_selected": tool_summary, + }, + output_data={ + "status": result.get("status", "unknown"), + "iterations": result.get("iterations", 0), + "num_tool_executions": len(result.get("tool_executions", [])), + }, + metadata={ + "step": self.step, + "instruction": instruction, + } + ) \ No newline at end of file diff --git a/openspace/cloud/__init__.py b/openspace/cloud/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20d524b5cfcd03b8f631c5d266a0e236fbb35e1c --- /dev/null +++ b/openspace/cloud/__init__.py @@ -0,0 +1,31 @@ +"""Cloud platform integration. + +Provides: + - ``OpenSpaceClient`` — HTTP client for the cloud API + - ``get_openspace_auth`` — credential resolution + - ``SkillSearchEngine`` — hybrid BM25 + embedding search + - ``generate_embedding`` — OpenAI embedding generation +""" + +from openspace.cloud.auth import get_openspace_auth + + +def __getattr__(name: str): + if name == "OpenSpaceClient": + from openspace.cloud.client import OpenSpaceClient + return OpenSpaceClient + if name == "SkillSearchEngine": + from openspace.cloud.search import SkillSearchEngine + return SkillSearchEngine + if name == "generate_embedding": + from openspace.cloud.embedding import generate_embedding + return generate_embedding + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = [ + "OpenSpaceClient", + "get_openspace_auth", + "SkillSearchEngine", + "generate_embedding", +] diff --git a/openspace/cloud/auth.py b/openspace/cloud/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..a00f267e918dd353beb5010a61238c5b17c68b03 --- /dev/null +++ b/openspace/cloud/auth.py @@ -0,0 +1,102 @@ +"""OpenSpace cloud platform authentication. + +Resolution order for OPENSPACE_API_KEY: + 1. ``OPENSPACE_API_KEY`` env var + 2. Auto-detect from host agent config (MCP env block) + 3. Empty (caller treats as "not configured"). + +Base URL resolution: + 1. ``OPENSPACE_API_BASE`` env var + 2. Default: ``https://open-space.cloud/api/v1`` +""" + +from __future__ import annotations + +import logging +import os +from typing import Dict, Optional + +logger = logging.getLogger("openspace.cloud") + +OPENSPACE_DEFAULT_BASE = "https://open-space.cloud/api/v1" + + +def get_openspace_auth() -> tuple[Dict[str, str], str]: + """Resolve OpenSpace credentials and base URL. + + Returns: + ``(auth_headers, api_base)`` — headers dict ready for HTTP requests + and the API base URL. If no credentials are found, ``auth_headers`` + is empty. + """ + from openspace.host_detection import read_host_mcp_env + + auth_headers: Dict[str, str] = {} + api_base = OPENSPACE_DEFAULT_BASE + + # Tier 1: env vars + env_key = os.environ.get("OPENSPACE_API_KEY", "").strip() + env_base = os.environ.get("OPENSPACE_API_BASE", "").strip() + + if env_key: + auth_headers["X-API-Key"] = env_key + if env_base: + api_base = env_base.rstrip("/") + logger.info("OpenSpace auth: using OPENSPACE_API_KEY env var") + return auth_headers, api_base + + # Tier 2: host agent config MCP env block + mcp_env = read_host_mcp_env() + cfg_key = str(mcp_env.get("OPENSPACE_API_KEY", "")).strip() + cfg_base = str(mcp_env.get("OPENSPACE_API_BASE", "")).strip() + + if cfg_key: + auth_headers["X-API-Key"] = cfg_key + if cfg_base: + api_base = cfg_base.rstrip("/") + logger.info("OpenSpace auth: using OPENSPACE_API_KEY from host agent MCP env config") + return auth_headers, api_base + + return auth_headers, api_base + + +def get_api_base(cli_override: Optional[str] = None) -> str: + """Resolve OpenSpace API base URL (for CLI scripts). + + Priority: ``cli_override`` → env var → host agent config → default. + """ + from openspace.host_detection import read_host_mcp_env + + if cli_override: + return cli_override.rstrip("/") + env_base = os.environ.get("OPENSPACE_API_BASE", "").strip() + if env_base: + return env_base.rstrip("/") + mcp_env = read_host_mcp_env() + cfg_base = str(mcp_env.get("OPENSPACE_API_BASE", "")).strip() + if cfg_base: + return cfg_base.rstrip("/") + return OPENSPACE_DEFAULT_BASE + + +def get_auth_headers_or_exit() -> Dict[str, str]: + """Resolve auth headers for CLI scripts. Exits on failure.""" + import sys + from openspace.host_detection import read_host_mcp_env + + env_key = os.environ.get("OPENSPACE_API_KEY", "").strip() + if env_key: + return {"X-API-Key": env_key} + + mcp_env = read_host_mcp_env() + cfg_key = str(mcp_env.get("OPENSPACE_API_KEY", "")).strip() + if cfg_key: + return {"X-API-Key": cfg_key} + + print( + "ERROR: No OPENSPACE_API_KEY configured.\n" + " Register at https://open-space.cloud to obtain a key, then add it to\n" + " your host agent config in the OpenSpace MCP env block.", + file=sys.stderr, + ) + sys.exit(1) diff --git a/openspace/cloud/cli/__init__.py b/openspace/cloud/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/openspace/cloud/cli/download_skill.py b/openspace/cloud/cli/download_skill.py new file mode 100644 index 0000000000000000000000000000000000000000..67b80293227bf850bd0813aef08945f75a298c69 --- /dev/null +++ b/openspace/cloud/cli/download_skill.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +"""Download a skill from the OpenSpace cloud platform. + +Usage: + openspace-download-skill --skill-id "weather__imp_abc12345" --output-dir ./skills/ +""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +from openspace.cloud.auth import get_api_base, get_auth_headers_or_exit +from openspace.cloud.client import OpenSpaceClient, CloudError + + +def main() -> None: + parser = argparse.ArgumentParser( + prog="openspace-download-skill", + description="Download a skill from OpenSpace's cloud community", + ) + parser.add_argument("--skill-id", required=True, help="Cloud skill record ID") + parser.add_argument("--output-dir", required=True, help="Target directory for extraction") + parser.add_argument("--api-base", default=None, help="Override API base URL") + parser.add_argument("--force", action="store_true", help="Overwrite existing skill directory") + + args = parser.parse_args() + + api_base = get_api_base(args.api_base) + headers = get_auth_headers_or_exit() + output_base = Path(args.output_dir).resolve() + + print(f"Fetching skill: {args.skill_id} ...", file=sys.stderr) + + try: + client = OpenSpaceClient(headers, api_base) + result = client.import_skill(args.skill_id, output_base) + except CloudError as e: + print(f"ERROR: {e}", file=sys.stderr) + sys.exit(1) + + if result.get("status") == "already_exists" and not args.force: + print( + f"ERROR: Skill directory already exists: {result.get('local_path')}\n" + f" Use --force to overwrite.", + file=sys.stderr, + ) + sys.exit(1) + + files = result.get("files", []) + local_path = result.get("local_path", "") + print(f" Extracted {len(files)} file(s) to {local_path}", file=sys.stderr) + for f in files: + print(f" {f}", file=sys.stderr) + + print(json.dumps(result, indent=2, ensure_ascii=False)) + print(f"\nSkill downloaded to: {local_path}", file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/openspace/cloud/cli/upload_skill.py b/openspace/cloud/cli/upload_skill.py new file mode 100644 index 0000000000000000000000000000000000000000..d6a1e353a35e6702ce95e62e50cd8592fe804cf5 --- /dev/null +++ b/openspace/cloud/cli/upload_skill.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +"""Upload a skill to the OpenSpace cloud platform. + +Usage: + openspace-upload-skill --skill-dir ./my-skill --visibility public --origin imported + openspace-upload-skill --skill-dir ./my-skill --visibility private --origin fixed --parent-ids "parent_id" +""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +from openspace.cloud.auth import get_api_base, get_auth_headers_or_exit +from openspace.cloud.client import OpenSpaceClient, CloudError + + +def main() -> None: + parser = argparse.ArgumentParser( + prog="openspace-upload-skill", + description="Upload a skill to OpenSpace's cloud community", + ) + parser.add_argument("--skill-dir", required=True, help="Path to skill directory (must contain SKILL.md)") + parser.add_argument("--visibility", required=True, choices=["public", "private"]) + parser.add_argument("--origin", default="imported", choices=["imported", "captured", "derived", "fixed"]) + parser.add_argument("--parent-ids", default="", help="Comma-separated parent skill IDs") + parser.add_argument("--tags", default="", help="Comma-separated tags") + parser.add_argument("--created-by", default="", help="Creator display name") + parser.add_argument("--change-summary", default="", help="Change summary text") + parser.add_argument("--api-base", default=None, help="Override API base URL") + parser.add_argument("--dry-run", action="store_true", help="List files without uploading") + + args = parser.parse_args() + + skill_dir = Path(args.skill_dir).resolve() + if not skill_dir.is_dir(): + print(f"ERROR: Not a directory: {skill_dir}", file=sys.stderr) + sys.exit(1) + + api_base = get_api_base(args.api_base) + + if args.dry_run: + files = OpenSpaceClient._collect_files(skill_dir) + print(f"Dry run — would upload {len(files)} file(s):", file=sys.stderr) + for f in files: + print(f" {f.relative_to(skill_dir)}", file=sys.stderr) + sys.exit(0) + + headers = get_auth_headers_or_exit() + + parent_ids = [p.strip() for p in args.parent_ids.split(",") if p.strip()] + tags = [t.strip() for t in args.tags.split(",") if t.strip()] + + print(f"\n{'='*60}", file=sys.stderr) + print(f"Upload Skill: {skill_dir.name}", file=sys.stderr) + print(f" Visibility: {args.visibility}", file=sys.stderr) + print(f" Origin: {args.origin}", file=sys.stderr) + print(f" API Base: {api_base}", file=sys.stderr) + print(f"{'='*60}\n", file=sys.stderr) + + try: + client = OpenSpaceClient(headers, api_base) + result = client.upload_skill( + skill_dir, + visibility=args.visibility, + origin=args.origin, + parent_skill_ids=parent_ids, + tags=tags, + created_by=args.created_by, + change_summary=args.change_summary, + ) + except CloudError as e: + print(f"ERROR: {e}", file=sys.stderr) + sys.exit(1) + + print(f"\nUpload complete!", file=sys.stderr) + print(json.dumps(result, indent=2, ensure_ascii=False)) + + +if __name__ == "__main__": + main() diff --git a/openspace/cloud/client.py b/openspace/cloud/client.py new file mode 100644 index 0000000000000000000000000000000000000000..fb0792f41a65c1d17cf4cfa07471f20abd27f878 --- /dev/null +++ b/openspace/cloud/client.py @@ -0,0 +1,497 @@ +"""OpenSpace cloud platform HTTP client. + +All methods are **synchronous** (use ``urllib``). In async contexts +(MCP server), wrap calls with ``asyncio.to_thread()``. + +Provides both low-level HTTP operations and higher-level workflows: + - ``fetch_record`` / ``download_artifact`` / ``fetch_metadata`` + - ``stage_artifact`` / ``create_record`` + - ``upload_skill`` (stage → diff → create — full workflow) + - ``import_skill`` (fetch → download → extract — full workflow) +""" + +from __future__ import annotations + +import difflib +import io +import json +import logging +import os +import uuid +import urllib.error +import urllib.parse +import urllib.request +import zipfile +from pathlib import Path +from typing import Any, Dict, List, Optional + +logger = logging.getLogger("openspace.cloud") + +SKILL_FILENAME = "SKILL.md" +SKILL_ID_FILENAME = ".skill_id" + +_TEXT_EXTENSIONS = frozenset({ + ".md", ".txt", ".yaml", ".yml", ".json", ".py", ".sh", ".toml", +}) + + +class CloudError(Exception): + """Raised when a cloud API call fails.""" + + def __init__(self, message: str, status_code: int = 0, body: str = ""): + super().__init__(message) + self.status_code = status_code + self.body = body + + +class OpenSpaceClient: + """HTTP client for the OpenSpace cloud API. + + Args: + auth_headers: Pre-resolved auth headers (from ``get_openspace_auth``). + api_base: API base URL (e.g. ``https://open-space.cloud/api/v1``). + """ + + _DEFAULT_UA = "OpenSpace-Client/1.0" + + def __init__(self, auth_headers: Dict[str, str], api_base: str): + if not auth_headers: + raise CloudError( + "No OPENSPACE_API_KEY configured. " + "Register at https://open-space.cloud to obtain a key." + ) + self._headers = { + "User-Agent": self._DEFAULT_UA, + **auth_headers, + } + self._base = api_base.rstrip("/") + + def _request( + self, + method: str, + path: str, + *, + body: Optional[bytes] = None, + extra_headers: Optional[Dict[str, str]] = None, + timeout: int = 30, + ) -> tuple[int, bytes]: + """Execute HTTP request. Returns ``(status_code, response_body)``.""" + url = f"{self._base}{path}" + headers = {**self._headers} + if extra_headers: + headers.update(extra_headers) + + req = urllib.request.Request(url, data=body, headers=headers, method=method) + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + return resp.status, resp.read() + except urllib.error.HTTPError as e: + resp_body = e.read().decode("utf-8", errors="replace") + raise CloudError( + f"HTTP {e.code}: {resp_body[:500]}", + status_code=e.code, + body=resp_body, + ) + except urllib.error.URLError as e: + raise CloudError(f"Connection failed: {e.reason}") + + def _get_json(self, path: str, timeout: int = 30) -> Dict[str, Any]: + _, data = self._request("GET", path, timeout=timeout) + return json.loads(data.decode("utf-8")) + + def fetch_record(self, record_id: str) -> Dict[str, Any]: + """GET /records/{record_id} — fetch record metadata.""" + return self._get_json(f"/records/{urllib.parse.quote(record_id)}") + + def download_artifact(self, record_id: str) -> bytes: + """GET /records/{record_id}/download — download artifact zip bytes.""" + _, data = self._request( + "GET", + f"/records/{urllib.parse.quote(record_id)}/download", + timeout=120, + ) + return data + + def fetch_metadata( + self, + *, + include_embedding: bool = False, + limit: int = 200, + ) -> List[Dict[str, Any]]: + """GET /records/metadata — fetch all visible records with pagination.""" + all_items: List[Dict[str, Any]] = [] + cursor: Optional[str] = None + + while True: + params: Dict[str, str] = {"limit": str(limit)} + if include_embedding: + params["include_embedding"] = "true" + if cursor: + params["cursor"] = cursor + + path = f"/records/metadata?{urllib.parse.urlencode(params)}" + data = self._get_json(path, timeout=15) + + all_items.extend(data.get("items", [])) + + if not data.get("has_more"): + break + cursor = data.get("next_cursor") + if not cursor: + break + + return all_items + + def stage_artifact(self, skill_dir: Path) -> tuple[str, int]: + """POST /artifacts/stage — upload skill files. + + Returns ``(artifact_id, file_count)``. + """ + file_paths = self._collect_files(skill_dir) + if not file_paths: + raise CloudError("No files found in skill directory") + + boundary = f"----OpenSpaceUpload{os.urandom(8).hex()}" + body_parts: list[bytes] = [] + for fp in file_paths: + rel_path = str(fp.relative_to(skill_dir)) + body_parts.append(f"--{boundary}\r\n".encode()) + body_parts.append( + f'Content-Disposition: form-data; name="files"; ' + f'filename="{rel_path}"\r\n'.encode() + ) + ctype = "text/plain" if fp.suffix in _TEXT_EXTENSIONS else "application/octet-stream" + body_parts.append(f"Content-Type: {ctype}\r\n\r\n".encode()) + body_parts.append(fp.read_bytes()) + body_parts.append(b"\r\n") + body_parts.append(f"--{boundary}--\r\n".encode()) + + _, resp_data = self._request( + "POST", + "/artifacts/stage", + body=b"".join(body_parts), + extra_headers={"Content-Type": f"multipart/form-data; boundary={boundary}"}, + timeout=60, + ) + stage = json.loads(resp_data.decode("utf-8")) + artifact_id = stage.get("artifact_id") + if not artifact_id: + raise CloudError("No artifact_id in stage response") + file_count = stage.get("stats", {}).get("file_count", 0) + return artifact_id, file_count + + def create_record(self, payload: Dict[str, Any]) -> tuple[Dict[str, Any], int]: + """POST /records — create skill record with 409 conflict handling. + + Returns ``(response_data, status_code)``. + """ + body = json.dumps(payload).encode("utf-8") + try: + status, resp_data = self._request( + "POST", + "/records", + body=body, + extra_headers={"Content-Type": "application/json"}, + ) + return json.loads(resp_data.decode("utf-8")), status + except CloudError as e: + if e.status_code == 409: + return self._handle_409(e.body, payload) + raise + + def _handle_409( + self, body_text: str, payload: Dict[str, Any], + ) -> tuple[Dict[str, Any], int]: + """Handle 409 conflict responses.""" + try: + err_data = json.loads(body_text) + except json.JSONDecodeError: + raise CloudError(f"409 conflict: {body_text}", status_code=409, body=body_text) + + err_type = err_data.get("error", "") + + if err_type == "fingerprint_record_id_conflict": + existing_id = err_data.get("existing_record_id", "") + return { + "record_id": existing_id, + "status": "duplicate", + "existing_record_id": existing_id, + }, 409 + + if err_type == "record_id_fingerprint_conflict": + # Retry with a new UUID + name = payload.get("name", "skill") + payload["record_id"] = f"{name}__clo_{uuid.uuid4().hex[:8]}" + retry_body = json.dumps(payload).encode("utf-8") + status, resp_data = self._request( + "POST", + "/records", + body=retry_body, + extra_headers={"Content-Type": "application/json"}, + ) + return json.loads(resp_data.decode("utf-8")), status + + raise CloudError(f"409 conflict: {body_text}", status_code=409, body=body_text) + + def upload_skill( + self, + skill_dir: Path, + *, + visibility: str = "public", + origin: str = "imported", + parent_skill_ids: Optional[List[str]] = None, + tags: Optional[List[str]] = None, + created_by: str = "", + change_summary: str = "", + ) -> Dict[str, Any]: + """Upload a local skill to the cloud (stage → diff → create record). + + Returns a result dict with status, record_id, etc. + """ + from openspace.skill_engine.skill_utils import parse_frontmatter + + skill_path = Path(skill_dir) + skill_file = skill_path / SKILL_FILENAME + if not skill_file.exists(): + raise CloudError(f"SKILL.md not found in {skill_dir}") + + content = skill_file.read_text(encoding="utf-8") + fm = parse_frontmatter(content) + name = fm.get("name", skill_path.name) + description = fm.get("description", "") + + if not name: + raise CloudError("SKILL.md frontmatter missing 'name' field") + + parents = parent_skill_ids or [] + self._validate_origin_parents(origin, parents) + + api_visibility = "group_only" if visibility == "private" else "public" + + # Step 1: Stage + logger.info(f"upload_skill: staging files for '{name}'") + artifact_id, file_count = self.stage_artifact(skill_path) + logger.info(f"upload_skill: staged {file_count} file(s), artifact_id={artifact_id}") + + # Step 2: Content diff + content_diff = self._compute_content_diff(skill_path, api_visibility, parents) + + # Step 3: Create record + record_id = f"{name}__clo_{uuid.uuid4().hex[:8]}" + payload: Dict[str, Any] = { + "record_id": record_id, + "artifact_id": artifact_id, + # name/description are NOT sent — the server extracts them + # from SKILL.md YAML frontmatter (Task 4+F4 change). + "origin": origin, + "visibility": api_visibility, + "parent_skill_ids": parents, + "tags": tags or [], + "level": "workflow", + } + if created_by: + payload["created_by"] = created_by + if change_summary: + payload["change_summary"] = change_summary + if content_diff is not None: + payload["content_diff"] = content_diff + + record_data, status_code = self.create_record(payload) + action = "created" if status_code == 201 else "exists (idempotent)" + final_record_id = record_data.get("record_id", record_id) + + logger.info( + f"upload_skill: {name} [{final_record_id}] — {action} " + f"(visibility={api_visibility}, origin={origin})" + ) + + # Check for duplicate status from 409 handling + if record_data.get("status") == "duplicate": + return { + "status": "duplicate", + "message": f"Same content already exists as record '{record_data.get('existing_record_id', '')}'", + "existing_record_id": record_data.get("existing_record_id", ""), + } + + return { + "status": "success", + "action": action, + "record_id": final_record_id, + "name": name, + "description": description, + "visibility": api_visibility, + "origin": origin, + "parent_skill_ids": parents, + "artifact_id": artifact_id, + "file_count": file_count, + } + + def import_skill( + self, + skill_id: str, + target_dir: Path, + ) -> Dict[str, Any]: + """Download a cloud skill and extract to a local directory. + + Returns a result dict with status, local_path, files, etc. + """ + # 1. Fetch metadata + logger.info(f"import_skill: fetching metadata for {skill_id}") + record_data = self.fetch_record(skill_id) + skill_name = record_data.get("name", skill_id) + + skill_dir = target_dir / skill_name + + # Check if already exists locally + if skill_dir.exists() and (skill_dir / SKILL_FILENAME).exists(): + return { + "status": "already_exists", + "skill_id": skill_id, + "name": skill_name, + "local_path": str(skill_dir), + } + + # 2. Download artifact + logger.info(f"import_skill: downloading artifact for {skill_id}") + zip_data = self.download_artifact(skill_id) + + # 3. Extract + skill_dir.mkdir(parents=True, exist_ok=True) + extracted = self._extract_zip(zip_data, skill_dir) + + # 4. Write .skill_id sidecar + (skill_dir / SKILL_ID_FILENAME).write_text(skill_id + "\n", encoding="utf-8") + + logger.info( + f"import_skill: {skill_name} [{skill_id}] → {skill_dir} " + f"({len(extracted)} files)" + ) + + return { + "status": "success", + "skill_id": skill_id, + "name": skill_name, + "description": record_data.get("description", ""), + "local_path": str(skill_dir), + "files": extracted, + } + + @staticmethod + def _collect_files(skill_dir: Path) -> List[Path]: + """Collect all files in skill directory (skip .skill_id sidecar).""" + return [ + p for p in sorted(skill_dir.rglob("*")) + if p.is_file() and p.name != SKILL_ID_FILENAME + ] + + @staticmethod + def _collect_text_files(skill_dir: Path) -> Dict[str, str]: + """Collect text files as ``{relative_path: content}``.""" + files: Dict[str, str] = {} + for p in sorted(skill_dir.rglob("*")): + if p.is_file() and p.name != SKILL_ID_FILENAME: + rel = str(p.relative_to(skill_dir)) + try: + files[rel] = p.read_text(encoding="utf-8") + except (UnicodeDecodeError, OSError): + pass + return files + + @staticmethod + def _extract_zip(zip_data: bytes, target_dir: Path) -> List[str]: + """Extract zip bytes to target directory with path traversal protection.""" + extracted: List[str] = [] + try: + with zipfile.ZipFile(io.BytesIO(zip_data)) as zf: + for info in zf.infolist(): + if info.is_dir(): + continue + clean_name = Path(info.filename).as_posix() + if clean_name.startswith("..") or clean_name.startswith("/"): + continue + target_path = target_dir / clean_name + target_path.parent.mkdir(parents=True, exist_ok=True) + target_path.write_bytes(zf.read(info)) + extracted.append(clean_name) + except zipfile.BadZipFile: + raise CloudError("Downloaded artifact is not a valid zip file") + return extracted + + @staticmethod + def _extract_zip_text_files(zip_data: bytes) -> Dict[str, str]: + """Extract text files from zip as ``{filename: content}``.""" + files: Dict[str, str] = {} + try: + with zipfile.ZipFile(io.BytesIO(zip_data)) as zf: + for info in zf.infolist(): + if info.is_dir() or info.filename == SKILL_ID_FILENAME: + continue + try: + files[info.filename] = zf.read(info).decode("utf-8") + except (UnicodeDecodeError, KeyError): + pass + except zipfile.BadZipFile: + pass + return files + + @staticmethod + def _validate_origin_parents(origin: str, parents: List[str]) -> None: + if origin in ("imported", "captured") and parents: + raise CloudError(f"origin='{origin}' must not have parent_skill_ids") + if origin == "derived" and not parents: + raise CloudError("origin='derived' requires at least 1 parent_skill_id") + if origin == "fixed" and len(parents) != 1: + raise CloudError("origin='fixed' requires exactly 1 parent_skill_id") + + def _compute_content_diff( + self, + skill_dir: Path, + api_visibility: str, + parents: List[str], + ) -> Optional[str]: + """Compute content_diff for the upload. + + - public + single parent → diff vs ancestor + - public + no parent → add-all diff + - else → None + """ + if api_visibility != "public": + return None + + cur_files = self._collect_text_files(skill_dir) + + if len(parents) == 1: + try: + anc_zip = self.download_artifact(parents[0]) + anc_files = self._extract_zip_text_files(anc_zip) + diff = self._unified_diff(anc_files, cur_files) + if diff: + logger.info(f"Computed diff vs ancestor {parents[0]}") + return diff + except Exception as e: + logger.warning(f"Diff computation failed: {e}") + return None + + if not parents: + return self._unified_diff({}, cur_files) + + return None # multiple parents + + @staticmethod + def _unified_diff(old_files: Dict[str, str], new_files: Dict[str, str]) -> Optional[str]: + """Compute combined unified diff between two file snapshots.""" + all_names = sorted(set(old_files) | set(new_files)) + parts: List[str] = [] + for fname in all_names: + old = old_files.get(fname, "") + new = new_files.get(fname, "") + d = "".join(difflib.unified_diff( + old.splitlines(keepends=True), + new.splitlines(keepends=True), + fromfile=f"a/{fname}", + tofile=f"b/{fname}", + n=3, + )) + if d: + parts.append(d) + return "\n".join(parts) if parts else None diff --git a/openspace/cloud/embedding.py b/openspace/cloud/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..6606936712a3c2fe6ef146586e094214df42908c --- /dev/null +++ b/openspace/cloud/embedding.py @@ -0,0 +1,129 @@ +"""Embedding generation via OpenAI-compatible API.""" + +from __future__ import annotations + +import json +import logging +import math +import os +import urllib.request +from typing import List, Optional, Tuple + +logger = logging.getLogger("openspace.cloud") + +# Constants (duplicated here to avoid top-level import of skill_ranker) +SKILL_EMBEDDING_MODEL = "openai/text-embedding-3-small" +SKILL_EMBEDDING_MAX_CHARS = 12_000 +SKILL_EMBEDDING_DIMENSIONS = 1536 + +_OPENROUTER_BASE = "https://openrouter.ai/api/v1" +_OPENAI_BASE = "https://api.openai.com/v1" + + +def resolve_embedding_api() -> Tuple[Optional[str], str]: + """Resolve API key and base URL for embedding requests. + + Priority: + 1. ``OPENROUTER_API_KEY`` → OpenRouter base URL + 2. ``OPENAI_API_KEY`` + ``OPENAI_BASE_URL`` (default ``api.openai.com``) + 3. host-agent config (nanobot / openclaw) + + Returns: + ``(api_key, base_url)`` — *api_key* may be ``None`` when no key is found. + """ + or_key = os.environ.get("OPENROUTER_API_KEY") + if or_key: + return or_key, _OPENROUTER_BASE + + oa_key = os.environ.get("OPENAI_API_KEY") + if oa_key: + base = os.environ.get("OPENAI_BASE_URL", _OPENAI_BASE).rstrip("/") + return oa_key, base + + try: + from openspace.host_detection import get_openai_api_key + host_key = get_openai_api_key() + if host_key: + base = os.environ.get("OPENAI_BASE_URL", _OPENAI_BASE).rstrip("/") + return host_key, base + except Exception: + pass + + return None, _OPENAI_BASE + + +def cosine_similarity(a: List[float], b: List[float]) -> float: + """Compute cosine similarity between two vectors.""" + if len(a) != len(b) or not a: + return 0.0 + dot = sum(x * y for x, y in zip(a, b)) + norm_a = math.sqrt(sum(x * x for x in a)) + norm_b = math.sqrt(sum(x * x for x in b)) + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + + +def build_skill_embedding_text( + name: str, + description: str, + readme_body: str, + max_chars: int = SKILL_EMBEDDING_MAX_CHARS, +) -> str: + """Build text for skill embedding: ``name + description + SKILL.md body``. + + Unified strategy matching MCP search_skills and clawhub platform. + """ + header = "\n".join(filter(None, [name, description])) + raw = "\n\n".join(filter(None, [header, readme_body])) + if len(raw) <= max_chars: + return raw + return raw[:max_chars] + + +def generate_embedding(text: str, api_key: Optional[str] = None) -> Optional[List[float]]: + """Generate embedding using OpenAI-compatible API. + + When *api_key* is ``None``, credentials are resolved automatically via + :func:`resolve_embedding_api` (``OPENROUTER_API_KEY`` → ``OPENAI_API_KEY`` + → host-agent config). + + This is a **synchronous** call (uses urllib). In async contexts, + wrap with ``asyncio.to_thread()``. + + Args: + text: The text to embed. + api_key: Explicit API key. When provided, base URL is still resolved + from environment (``OPENROUTER_API_KEY`` presence determines + the endpoint). + + Returns: + Embedding vector, or None on failure. + """ + resolved_key, base_url = resolve_embedding_api() + if api_key is None: + api_key = resolved_key + if not api_key: + return None + + body = json.dumps({ + "model": SKILL_EMBEDDING_MODEL, + "input": text, + }).encode("utf-8") + + req = urllib.request.Request( + f"{base_url}/embeddings", + data=body, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + }, + method="POST", + ) + try: + with urllib.request.urlopen(req, timeout=15) as resp: + data = json.loads(resp.read().decode("utf-8")) + return data.get("data", [{}])[0].get("embedding") + except Exception as e: + logger.warning("Embedding generation failed: %s", e) + return None diff --git a/openspace/cloud/search.py b/openspace/cloud/search.py new file mode 100644 index 0000000000000000000000000000000000000000..678f71bbeec02b6ee10a5c9b16cac0c569106504 --- /dev/null +++ b/openspace/cloud/search.py @@ -0,0 +1,393 @@ +"""Hybrid skill search engine (BM25 + embedding + lexical boost). + +Implements the search pipeline: + Phase 1: BM25 rough-rank over all candidates + Phase 2: Vector scoring (embedding cosine similarity) + Phase 3: Hybrid score = vector_score + lexical_boost + Phase 4: Deduplication + limit + +Used by MCP ``search_skills`` tool, ``retrieve_skill`` agent tool, +and potentially other search interfaces. +""" + +from __future__ import annotations + +import asyncio +import logging +import re +from typing import Any, Dict, List, Optional + +logger = logging.getLogger("openspace.cloud") + + +def _check_safety(text: str) -> list[str]: + """Lazy wrapper — avoids importing skill_engine at module load time.""" + from openspace.skill_engine.skill_utils import check_skill_safety + return check_skill_safety(text) + + +def _is_safe(flags: list[str]) -> bool: + from openspace.skill_engine.skill_utils import is_skill_safe + return is_skill_safe(flags) + +_WORD_RE = re.compile(r"[a-z0-9]+") + + +def _tokenize(value: str) -> list[str]: + return _WORD_RE.findall(value.lower()) if value else [] + + +def _lexical_boost(query_tokens: list[str], name: str, slug: str) -> float: + """Compute lexical boost score based on exact/prefix token matching.""" + slug_tokens = _tokenize(slug) + name_tokens = _tokenize(name) + boost = 0.0 + + # Slug exact / prefix + if slug_tokens and all( + any(ct == qt for ct in slug_tokens) for qt in query_tokens + ): + boost += 1.4 + elif slug_tokens and all( + any(ct.startswith(qt) for ct in slug_tokens) for qt in query_tokens + ): + boost += 0.8 + + # Name exact / prefix + if name_tokens and all( + any(ct == qt for ct in name_tokens) for qt in query_tokens + ): + boost += 1.1 + elif name_tokens and all( + any(ct.startswith(qt) for ct in name_tokens) for qt in query_tokens + ): + boost += 0.6 + + return boost + + +class SkillSearchEngine: + """Hybrid BM25 + embedding search engine for skills. + + Usage:: + + engine = SkillSearchEngine() + results = engine.search( + query="weather forecast", + candidates=candidates, + query_embedding=[...], # optional + limit=20, + ) + """ + + def search( + self, + query: str, + candidates: List[Dict[str, Any]], + *, + query_embedding: Optional[List[float]] = None, + limit: int = 20, + ) -> List[Dict[str, Any]]: + """Run the full search pipeline on candidates. + + Each candidate dict should have at minimum: + - ``skill_id``, ``name``, ``description`` + - ``_embedding`` (optional): pre-computed embedding vector + - ``source``: "openspace-local" | "cloud" + + Args: + query: Search query text. + candidates: Candidate dicts to rank. + query_embedding: Pre-computed query embedding (if available). + limit: Max results to return. + + Returns: + Sorted list of result dicts (highest score first). + """ + q = query.strip() + if not q or not candidates: + return [] + + query_tokens = _tokenize(q) + if not query_tokens: + return [] + + # Phase 1: BM25 rough-rank + filtered = self._bm25_phase(q, candidates, limit) + + # Phase 2+3: Vector + lexical scoring + scored = self._score_phase(filtered, query_tokens, query_embedding) + + # Phase 4: Deduplicate and limit + return self._dedup_and_limit(scored, limit) + + def _bm25_phase( + self, + query: str, + candidates: List[Dict[str, Any]], + limit: int, + ) -> List[Dict[str, Any]]: + """BM25 rough-rank to keep top candidates for embedding stage.""" + from openspace.skill_engine.skill_ranker import SkillRanker, SkillCandidate + + ranker = SkillRanker(enable_cache=True) + bm25_candidates = [ + SkillCandidate( + skill_id=c.get("skill_id", ""), + name=c.get("name", ""), + description=c.get("description", ""), + body="", + metadata=c, + ) + for c in candidates + ] + ranked = ranker.bm25_only(query, bm25_candidates, top_k=min(limit * 3, len(candidates))) + + ranked_ids = {sc.skill_id for sc in ranked} + filtered = [c for c in candidates if c.get("skill_id") in ranked_ids] + + # If BM25 found nothing, fall back to all candidates + return filtered if filtered else candidates + + def _score_phase( + self, + candidates: List[Dict[str, Any]], + query_tokens: list[str], + query_embedding: Optional[List[float]], + ) -> List[Dict[str, Any]]: + """Compute hybrid score = vector_score + lexical_boost.""" + from openspace.cloud.embedding import cosine_similarity + + scored = [] + for c in candidates: + name = c.get("name", "") + slug = c.get("skill_id", name).split("__")[0].replace(":", "-") + + # Vector score + vector_score = 0.0 + if query_embedding: + skill_emb = c.get("_embedding") + if skill_emb and isinstance(skill_emb, list): + vector_score = cosine_similarity(query_embedding, skill_emb) + + # Lexical boost + lexical = _lexical_boost(query_tokens, name, slug) + + final_score = vector_score + lexical + + entry: Dict[str, Any] = { + "skill_id": c.get("skill_id", ""), + "name": name, + "description": c.get("description", ""), + "source": c.get("source", ""), + "score": round(final_score, 4), + } + if vector_score > 0: + entry["vector_score"] = round(vector_score, 4) + # Include optional fields + for key in ("path", "visibility", "created_by", "origin", "tags", "quality", "safety_flags"): + if c.get(key): + entry[key] = c[key] + scored.append(entry) + + scored.sort(key=lambda x: -x["score"]) + return scored + + @staticmethod + def _dedup_and_limit( + scored: List[Dict[str, Any]], + limit: int, + ) -> List[Dict[str, Any]]: + """Deduplicate by name and apply limit.""" + seen: set[str] = set() + deduped = [] + for item in scored: + name = item["name"] + if name in seen: + continue + seen.add(name) + deduped.append(item) + return deduped[:limit] + + +def build_local_candidates( + skills: list, + store: Any = None, +) -> List[Dict[str, Any]]: + """Build search candidate dicts from SkillRegistry skills. + + Args: + skills: List of ``SkillMeta`` from ``registry.list_skills()``. + store: Optional ``SkillStore`` instance for quality data enrichment. + + Returns: + List of candidate dicts ready for ``SkillSearchEngine.search()``. + """ + from openspace.cloud.embedding import build_skill_embedding_text + + candidates: List[Dict[str, Any]] = [] + for s in skills: + # Read SKILL.md body + readme_body = "" + try: + raw = s.path.read_text(encoding="utf-8") + m = re.match(r"^---\n.*?\n---\n?", raw, re.DOTALL) + readme_body = raw[m.end():].strip() if m else raw + except Exception: + pass + + embedding_text = build_skill_embedding_text(s.name, s.description, readme_body) + + # Safety check + flags = _check_safety(embedding_text) + if not _is_safe(flags): + logger.info(f"BLOCKED local skill {s.skill_id} — {flags}") + continue + + candidates.append({ + "skill_id": s.skill_id, + "name": s.name, + "description": s.description, + "source": "openspace-local", + "path": str(s.path), + "is_local": True, + "safety_flags": flags if flags else None, + "_embedding_text": embedding_text, + }) + + # Enrich with quality data + if store and candidates: + try: + all_records = store.load_all(active_only=True) + for c in candidates: + rec = all_records.get(c["skill_id"]) + if rec: + c["quality"] = { + "total_selections": rec.total_selections, + "completion_rate": round(rec.completion_rate, 3), + "effective_rate": round(rec.effective_rate, 3), + } + c["tags"] = rec.tags + except Exception as e: + logger.warning(f"Quality lookup failed: {e}") + + return candidates + + +def build_cloud_candidates( + items: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Build search candidate dicts from cloud metadata items. + + Args: + items: Items from ``OpenSpaceClient.fetch_metadata()``. + + Returns: + List of candidate dicts (with safety filtering applied). + """ + candidates: List[Dict[str, Any]] = [] + for item in items: + name = item.get("name", "") + desc = item.get("description", "") + tags = item.get("tags", []) + safety_text = f"{name}\n{desc}\n{' '.join(tags)}" + flags = _check_safety(safety_text) + if not _is_safe(flags): + continue + + c_entry: Dict[str, Any] = { + "skill_id": item.get("record_id", ""), + "name": name, + "description": desc, + "source": "cloud", + "visibility": item.get("visibility", "public"), + "is_local": False, + "created_by": item.get("created_by", ""), + "origin": item.get("origin", ""), + "tags": tags, + "safety_flags": flags if flags else None, + } + # Carry pre-computed embedding + platform_emb = item.get("embedding") + if platform_emb and isinstance(platform_emb, list): + c_entry["_embedding"] = platform_emb + candidates.append(c_entry) + + return candidates + + +async def hybrid_search_skills( + query: str, + local_skills: list = None, + store: Any = None, + source: str = "all", + limit: int = 20, +) -> List[Dict[str, Any]]: + """Shared cloud+local skill search with graceful fallback. + + Builds candidates, generates embeddings, runs ``SkillSearchEngine``. + Cloud is attempted when *source* includes it; failures are silently + skipped so the caller always gets local results at minimum. + + Args: + query: Free-text search query. + local_skills: ``SkillMeta`` list (from ``registry.list_skills()``). + store: Optional ``SkillStore`` for quality enrichment. + source: ``"all"`` | ``"local"`` | ``"cloud"``. + limit: Maximum results. + + Returns: + Ranked result dicts (same format as ``SkillSearchEngine.search()``). + """ + from openspace.cloud.embedding import generate_embedding + + q = query.strip() + if not q: + return [] + + candidates: List[Dict[str, Any]] = [] + + if source in ("all", "local") and local_skills: + candidates.extend(build_local_candidates(local_skills, store)) + + if source in ("all", "cloud"): + try: + from openspace.cloud.auth import get_openspace_auth + from openspace.cloud.client import OpenSpaceClient + + auth_headers, api_base = get_openspace_auth() + if auth_headers: + client = OpenSpaceClient(auth_headers, api_base) + try: + from openspace.cloud.embedding import resolve_embedding_api + has_emb = bool(resolve_embedding_api()[0]) + except Exception: + has_emb = False + items = await asyncio.to_thread( + client.fetch_metadata, include_embedding=has_emb, limit=200, + ) + candidates.extend(build_cloud_candidates(items)) + except Exception as e: + logger.warning(f"hybrid_search_skills: cloud unavailable: {e}") + + if not candidates: + return [] + + # query embedding (optional — key/URL resolved inside generate_embedding) + query_embedding: Optional[List[float]] = None + try: + query_embedding = await asyncio.to_thread(generate_embedding, q) + if query_embedding: + for c in candidates: + if not c.get("_embedding") and c.get("_embedding_text"): + emb = await asyncio.to_thread( + generate_embedding, c["_embedding_text"], + ) + if emb: + c["_embedding"] = emb + except Exception: + pass + + engine = SkillSearchEngine() + return engine.search(q, candidates, query_embedding=query_embedding, limit=limit) + diff --git a/openspace/config/README.md b/openspace/config/README.md new file mode 100644 index 0000000000000000000000000000000000000000..74327ff430e0abe1e602d82c896f44fe7d024222 --- /dev/null +++ b/openspace/config/README.md @@ -0,0 +1,115 @@ +# 🔧 Configuration Guide + +All configuration applies to both Path A (host agent) and Path B (standalone). Configure once before the first run. + +## 1. API Keys (`.env`) + +> [!NOTE] +> Create a `.env` file and add your API keys (refer to [`.env.example`](../../.env.example)). When used via host agent (Path A), LLM keys are auto-detected from your agent's config — `.env` is mainly needed for standalone mode. + +## 2. Environment Variables + +Set via `.env`, MCP config `env` block, or system environment. OpenSpace reads these at startup. + +| Variable | Required | Description | +|----------|----------|-------------| +| `OPENSPACE_HOST_SKILL_DIRS` | Path A only | Your agent's skill directories (comma-separated). Auto-registered on startup. | +| `OPENSPACE_WORKSPACE` | Recommended | OpenSpace project root. Used for recording logs and workspace resolution. | +| `OPENSPACE_API_KEY` | No | Cloud API key (`sk-xxx`). Register at https://open-space.cloud. | +| `OPENSPACE_MODEL` | No | LLM model override (default: auto-detected or `openrouter/anthropic/claude-sonnet-4.5`). | +| `OPENSPACE_MAX_ITERATIONS` | No | Max agent iterations per task (default: `20`). | +| `OPENSPACE_BACKEND_SCOPE` | No | Enabled backends, comma-separated (default: all — `shell,gui,mcp,web,system`). | + +### Advanced env overrides (rarely needed) + +| Variable | Description | +|----------|-------------| +| `OPENSPACE_LLM_API_KEY` | LLM API key (auto-detected from host agent in Path A) | +| `OPENSPACE_LLM_API_BASE` | LLM API base URL | +| `OPENSPACE_LLM_EXTRA_HEADERS` | Extra HTTP headers for LLM requests (JSON string) | +| `OPENSPACE_LLM_CONFIG` | Arbitrary litellm kwargs (JSON string) | +| `OPENSPACE_API_BASE` | Cloud API base URL (default `https://open-space.cloud/api/v1`) | +| `OPENSPACE_CONFIG_PATH` | Custom grounding config JSON (deep-merged with defaults) | +| `OPENSPACE_SHELL_CONDA_ENV` | Conda environment for shell backend | +| `OPENSPACE_SHELL_WORKING_DIR` | Working directory for shell backend | +| `OPENSPACE_MCP_SERVERS_JSON` | MCP server definitions (JSON string, merged into `mcpServers`) | +| `OPENSPACE_ENABLE_RECORDING` | Record execution traces (default: `true`) | +| `OPENSPACE_LOG_LEVEL` | `DEBUG` / `INFO` / `WARNING` / `ERROR` | + +## 3. MCP Servers (`config_mcp.json`) + +Register external MCP servers that OpenSpace connects to as a **client** (e.g. GitHub, Slack, databases): + +```bash +cp openspace/config/config_mcp.json.example openspace/config/config_mcp.json +``` + +```json +{ + "mcpServers": { + "github": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": { "GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_TOKEN}" } + } + } +} +``` + +## 4. Execution Mode: Local vs Server + +Shell and GUI backends support two execution modes, set via `"mode"` in `config_grounding.json`: + +| | Local Mode (`"local"`, default) | Server Mode (`"server"`) | +|---|---|---| +| **Setup** | Zero config | Start `local_server` first | +| **Use case** | Same-machine development | Remote VMs, sandboxing, multi-machine | +| **How** | `asyncio.subprocess` in-process | HTTP → Flask → subprocess | + +> [!TIP] +> **Use local mode** for most use cases. For server mode setup (how to enable, platform-specific deps, remote VM control), see [`../local_server/README.md`](../local_server/README.md). + +## 5. Config Files (`openspace/config/`) + +Layered system — later files override earlier ones: + +| File | Purpose | +|------|---------| +| `config_grounding.json` | Backend settings, smart tool retrieval, tool quality, skill discovery | +| `config_agents.json` | Agent definitions, backend scope, max iterations | +| `config_mcp.json` | MCP servers OpenSpace connects to as a client | +| `config_security.json` | Security policies, blocked commands, sandboxing | +| `config_dev.json` | Dev overrides — copy from `config_dev.json.example` (highest priority) | + +### Agent config (`config_agents.json`) + +```json +{ "agents": [{ "name": "GroundingAgent", "backend_scope": ["shell", "mcp", "web"], "max_iterations": 30 }] } +``` + +| Field | Description | Default | +|-------|-------------|---------| +| `backend_scope` | Enabled backends | `["gui", "shell", "mcp", "system", "web"]` | +| `max_iterations` | Max execution cycles | `20` | +| `visual_analysis_timeout` | Timeout for visual analysis (seconds) | `30.0` | + +### Backend & tool config (`config_grounding.json`) + +| Section | Key Fields | Description | +|---------|-----------|-------------| +| `shell` | `mode`, `timeout`, `conda_env`, `working_dir` | `"local"` (default) or `"server"`, command timeout (default: `60`s) | +| `gui` | `mode`, `timeout`, `driver_type`, `screenshot_on_error` | Local/server mode, automation driver (default: `pyautogui`) | +| `mcp` | `timeout`, `sandbox`, `eager_sessions` | Request timeout (`30`s), E2B sandbox, lazy/eager server init | +| `tool_search` | `search_mode`, `max_tools`, `enable_llm_filter` | `"hybrid"` (semantic + LLM), max tools to return (`40`), embedding cache | +| `tool_quality` | `enabled`, `enable_persistence`, `evolve_interval` | Quality tracking, self-evolution every N calls (default: `5`) | +| `skills` | `enabled`, `skill_dirs`, `max_select` | Directories to scan, max skills injected per task (default: `2`) | + +### Security config (`config_security.json`) + +| Field | Description | Default | +|-------|-------------|---------| +| `allow_shell_commands` | Enable shell execution | `true` | +| `blocked_commands` | Platform-specific blacklists (common/linux/darwin/windows) | `rm -rf`, `shutdown`, `dd`, etc. | +| `sandbox_enabled` | Enable sandboxing for all operations | `false` | +| Per-backend overrides | Shell, MCP, GUI, Web each have independent security policies | Inherit global | + diff --git a/openspace/config/__init__.py b/openspace/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86aa2a70a55bc1436ea1a2122da7b4b2507ab6da --- /dev/null +++ b/openspace/config/__init__.py @@ -0,0 +1,32 @@ +from .grounding import * +from .loader import * +from .constants import * +from .utils import * +from . import constants + +__all__ = [ + # Grounding Config + "BackendConfig", + "ShellConfig", + "WebConfig", + "MCPConfig", + "GUIConfig", + "ToolSearchConfig", + "SessionConfig", + "SecurityPolicy", + "GroundingConfig", + + # Loader + "CONFIG_DIR", + "load_config", + "get_config", + "reset_config", + "save_config", + "load_agents_config", + "get_agent_config", + + # Utils + "get_config_value", + "load_json_file", + "save_json_file", +] + constants.__all__ \ No newline at end of file diff --git a/openspace/config/config_agents.json b/openspace/config/config_agents.json new file mode 100644 index 0000000000000000000000000000000000000000..909df1f9103cb591fccae23705b96c368e9da83c --- /dev/null +++ b/openspace/config/config_agents.json @@ -0,0 +1,11 @@ +{ + "agents": [ + { + "name": "GroundingAgent", + "class_name": "GroundingAgent", + "backend_scope": ["shell", "mcp", "system"], + "max_iterations": 30, + "visual_analysis_timeout": 60.0 + } + ] +} \ No newline at end of file diff --git a/openspace/config/config_dev.json.example b/openspace/config/config_dev.json.example new file mode 100644 index 0000000000000000000000000000000000000000..585d22ffd78c72321de688baa5e4778ec58c283a --- /dev/null +++ b/openspace/config/config_dev.json.example @@ -0,0 +1,12 @@ +{ + "comment": "[Optional] Loading grounding.json → security.json → dev.json (dev.json overrides the former ones)", + + "debug": true, + "log_level": "DEBUG", + + "security_policies": { + "global": { + "blocked_commands": [] + } + } +} \ No newline at end of file diff --git a/openspace/config/config_grounding.json b/openspace/config/config_grounding.json new file mode 100644 index 0000000000000000000000000000000000000000..79508a109441eef332cfed89e96932ea270e771a --- /dev/null +++ b/openspace/config/config_grounding.json @@ -0,0 +1,82 @@ +{ + "shell": { + "mode": "local", + "timeout": 60, + "max_retries": 3, + "retry_interval": 3.0, + "default_shell": "/bin/bash", + "working_dir": null, + "env": {}, + "conda_env": null, + "default_port": 5000 + }, + "mcp": { + "timeout": 30, + "max_retries": 3, + "retry_interval": 2.0, + "sandbox": false, + "auto_initialize": true, + "eager_sessions": false, + "sse_read_timeout": 300.0, + "check_dependencies": true, + "auto_install": true + }, + "gui": { + "mode": "local", + "timeout": 90, + "max_retries": 3, + "retry_interval": 5.0, + "driver_type": "pyautogui", + "failsafe": false, + "screenshot_on_error": true, + "pkgs_prefix": "import pyautogui; import time; pyautogui.FAILSAFE = {failsafe}; {command}" + }, + "tool_search": { + "embedding_model": "BAAI/bge-small-en-v1.5", + "max_tools": 40, + "search_mode": "hybrid", + "enable_llm_filter": true, + "llm_filter_threshold": 50, + "enable_cache_persistence": true, + "cache_dir": null + }, + "tool_quality": { + "enabled": true, + "enable_persistence": true, + "cache_dir": null, + "auto_evaluate_descriptions": true, + "enable_quality_ranking": true, + "evolve_interval": 5 + }, + "skills": { + "enabled": true, + "skill_dirs": [], + "max_select": 2 + }, + + "tool_cache_ttl": 600, + "tool_cache_maxsize": 500, + + "debug": false, + "log_level": "INFO", + "enabled_backends": [ + { + "name": "shell", + "provider_cls": "openspace.grounding.backends.shell.ShellProvider" + }, + { + "name": "web", + "provider_cls": "openspace.grounding.backends.web.WebProvider" + }, + { + "name": "mcp", + "provider_cls": "openspace.grounding.backends.mcp.MCPProvider" + }, + { + "name": "gui", + "provider_cls": "openspace.grounding.backends.gui.GUIProvider" + } + ], + + "_comment_system_backend": "Note: 'system' backend is automatically registered and always available. It provides meta-level tools for querying system state. Do not add it to enabled_backends as it requires special initialization." +} \ No newline at end of file diff --git a/openspace/config/config_mcp.json.example b/openspace/config/config_mcp.json.example new file mode 100644 index 0000000000000000000000000000000000000000..adca29e465fdbdf318a3a8dfda02714d64a4e107 --- /dev/null +++ b/openspace/config/config_mcp.json.example @@ -0,0 +1,11 @@ +{ + "mcpServers": { + "github": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_TOKEN}" + } + } + } +} diff --git a/openspace/config/config_security.json b/openspace/config/config_security.json new file mode 100644 index 0000000000000000000000000000000000000000..c64f75e8dd9595fef40b44f84d10ff9e1f6f4ac9 --- /dev/null +++ b/openspace/config/config_security.json @@ -0,0 +1,68 @@ +{ + "security_policies": { + "global": { + "allow_shell_commands": true, + "allow_network_access": true, + "allow_file_access": true, + "blocked_commands": { + "common": ["rm", "-rf", "shutdown", "reboot", "poweroff", "halt"], + "linux": ["mkfs", "dd", "iptables", "systemctl", "init", "kill", "-9", "pkill"], + "darwin": ["diskutil", "dd", "pfctl", "launchctl", "killall"], + "windows": ["del", "format", "rd", "rmdir", "/s", "/q", "taskkill", "/f"] + }, + "sandbox_enabled": false + }, + "backend": { + "shell": { + "allow_shell_commands": true, + "allow_file_access": true, + "blocked_commands": { + "common": ["rm", "-rf", "shutdown", "reboot", "poweroff", "halt"], + "linux": [ + "mkfs", "mkfs.ext4", "mkfs.xfs", + "dd", + "iptables", "ip6tables", "nftables", + "systemctl", "service", + "fdisk", "parted", "gdisk", + "mount", "umount", + "chmod", "777", + "chown", "root", + "passwd", + "useradd", "userdel", "usermod", + "kill", "-9", "pkill", "killall" + ], + "darwin": [ + "diskutil", + "dd", + "pfctl", + "launchctl", + "dscl", + "chmod", "777", + "chown", "root", + "passwd", + "killall", + "pmset" + ], + "windows": [ + "del", "erase", + "format", + "rd", "rmdir", "/s", "/q", + "diskpart", + "reg", "delete", + "net", "user", + "taskkill", "/f", + "wmic" + ] + }, + "sandbox_enabled": false + }, + "mcp": { + "sandbox_enabled": false + }, + "web": { + "allow_network_access": true, + "allowed_domains": [] + } + } + } +} diff --git a/openspace/config/constants.py b/openspace/config/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..e847eb67b6ad5fb1cb3b9e462096ba3188060240 --- /dev/null +++ b/openspace/config/constants.py @@ -0,0 +1,23 @@ +from pathlib import Path + +CONFIG_GROUNDING = "config_grounding.json" +CONFIG_SECURITY = "config_security.json" +CONFIG_MCP = "config_mcp.json" +CONFIG_DEV = "config_dev.json" +CONFIG_AGENTS = "config_agents.json" + +LOG_LEVELS = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + +# Project root directory (OpenSpace/) +PROJECT_ROOT = Path(__file__).parent.parent.parent + + +__all__ = [ + "CONFIG_GROUNDING", + "CONFIG_SECURITY", + "CONFIG_MCP", + "CONFIG_DEV", + "CONFIG_AGENTS", + "LOG_LEVELS", + "PROJECT_ROOT", +] \ No newline at end of file diff --git a/openspace/config/grounding.py b/openspace/config/grounding.py new file mode 100644 index 0000000000000000000000000000000000000000..853bbfb3b846e73bf79ec9c989126ff45ed9a022 --- /dev/null +++ b/openspace/config/grounding.py @@ -0,0 +1,311 @@ +from typing import Dict, Optional, Any, List, Literal +try: + from pydantic import BaseModel, Field, field_validator + PYDANTIC_V2 = True +except ImportError: + from pydantic import BaseModel, Field, validator as field_validator + PYDANTIC_V2 = False + +from openspace.grounding.core.types import ( + SessionConfig, + SecurityPolicy, + BackendType +) +from .constants import LOG_LEVELS + + +class ConfigMixin: + """Mixin to add utility methods for config access""" + + def get_value(self, key: str, default=None): + """ + Safely get config value, works with both dict and Pydantic models. + + Args: + key: Configuration key + default: Default value if key not found + """ + if isinstance(self, dict): + return self.get(key, default) + else: + return getattr(self, key, default) + + +class BackendConfig(BaseModel, ConfigMixin): + """Base backend configuration""" + enabled: bool = Field(True, description="Whether the backend is enabled") + timeout: int = Field(30, ge=1, le=300, description="Timeout in seconds") + max_retries: int = Field(3, ge=0, le=10, description="Maximum retry attempts") + + +class ShellConfig(BackendConfig): + """ + Shell backend configuration + + Attributes: + enabled: Whether shell backend is enabled + mode: Execution mode - "local" runs scripts in-process via subprocess, + "server" connects to a running local_server via HTTP + timeout: Default timeout for shell operations (seconds) + max_retries: Maximum number of retry attempts for failed operations + retry_interval: Wait time between retries (seconds) + default_shell: Path to default shell executable + working_dir: Default working directory for bash scripts + env: Default environment variables for shell operations + conda_env: Conda environment name to activate before execution (optional) + default_port: Default port for shell server connection (only used in server mode) + """ + mode: Literal["local", "server"] = Field("local", description="Execution mode: 'local' (in-process subprocess) or 'server' (HTTP local_server)") + retry_interval: float = Field(3.0, ge=0.1, le=60.0, description="Wait time between retries in seconds") + default_shell: str = Field("/bin/bash", description="Default shell path") + working_dir: Optional[str] = Field(None, description="Default working directory for bash scripts") + env: Dict[str, str] = Field(default_factory=dict, description="Default environment variables") + conda_env: Optional[str] = Field(None, description="Conda environment name to activate (e.g., 'myenv')") + default_port: int = Field(5000, ge=1, le=65535, description="Default port for shell server") + use_clawwork_productivity: bool = Field( + False, + description="If True and livebench is installed, add ClawWork productivity tools (search_web, read_webpage, create_file, read_file, execute_code_sandbox, create_video) for fair comparison with ClawWork." + ) + productivity_date: str = Field( + "default", + description="Date segment for productivity sandbox paths (e.g. 'default' or 'YYYY-MM-DD'). Used when use_clawwork_productivity is True." + ) + + @field_validator('default_shell') + @classmethod + def validate_shell(cls, v): + if not v or not isinstance(v, str): + raise ValueError("Shell path must be a non-empty string") + return v + + @field_validator('working_dir') + @classmethod + def validate_working_dir(cls, v): + if v is not None and not isinstance(v, str): + raise ValueError("Working directory must be a string") + return v + +class WebConfig(BackendConfig): + """ + Web backend configuration - AI Deep Research + + Attributes: + enabled: Whether web backend is enabled + timeout: Default timeout for web operations (seconds) + max_retries: Maximum number of retry attempts + + Note: + All web-specific parameters (API key, base URL) are loaded from + environment variables or use default values in WebSession: + - OPENROUTER_API_KEY: API key for deep research (required) + - Deep research base URL defaults to "https://openrouter.ai/api/v1" + """ + pass + + +class MCPConfig(BackendConfig): + """MCP backend configuration""" + sandbox: bool = Field(False, description="Whether to enable sandbox") + auto_initialize: bool = Field(True, description="Whether to auto initialize") + eager_sessions: bool = Field(False, description="Whether to eagerly create sessions for all servers on initialization") + retry_interval: float = Field(2.0, ge=0.1, le=60.0, description="Wait time between retries in seconds") + servers: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="MCP servers configuration, loaded from config_mcp.json") + sse_read_timeout: float = Field(300.0, ge=1.0, le=3600.0, description="SSE read timeout in seconds for HTTP/Sandbox connectors") + + +class GUIConfig(BackendConfig): + """ + GUI backend configuration + + Attributes: + mode: Execution mode - "local" runs GUI operations in-process, + "server" connects to a running local_server via HTTP + """ + mode: Literal["local", "server"] = Field("local", description="Execution mode: 'local' (in-process) or 'server' (HTTP local_server)") + retry_interval: float = Field(5.0, ge=0.1, le=60.0, description="Wait time between retries in seconds") + driver_type: str = Field("pyautogui", description="GUI driver type") + failsafe: bool = Field(False, description="Whether to enable pyautogui failsafe mode") + screenshot_on_error: bool = Field(True, description="Whether to capture screenshot on error") + pkgs_prefix: str = Field( + "import pyautogui; import time; pyautogui.FAILSAFE = {failsafe}; {command}", + description="Python command prefix for pyautogui setup" + ) + + +class ToolSearchConfig(BaseModel): + """Tool search and ranking configuration""" + embedding_model: str = Field( + "BAAI/bge-small-en-v1.5", + description="Embedding model name for semantic search" + ) + max_tools: int = Field( + 20, + ge=1, + le=1000, + description="Maximum number of tools to return from search" + ) + search_mode: str = Field( + "hybrid", + description="Default search mode: semantic, keyword, or hybrid" + ) + enable_llm_filter: bool = Field( + True, + description="Whether to use LLM for backend/server filtering" + ) + llm_filter_threshold: int = Field( + 50, + ge=1, + le=1000, + description="Only apply LLM filter when tool count exceeds this threshold" + ) + enable_cache_persistence: bool = Field( + False, + description="Whether to persist embeddings to disk" + ) + cache_dir: Optional[str] = Field( + None, + description="Directory for embedding cache. None means use default /.openspace/embedding_cache" + ) + + @field_validator('search_mode') + @classmethod + def validate_search_mode(cls, v): + valid_modes = ['semantic', 'keyword', 'hybrid'] + if v.lower() not in valid_modes: + raise ValueError(f"Search mode must be one of {valid_modes}, got: {v}") + return v.lower() + + +class ToolQualityConfig(BaseModel): + """Tool quality tracking configuration""" + enabled: bool = Field( + True, + description="Whether to enable tool quality tracking" + ) + enable_persistence: bool = Field( + True, + description="Whether to persist quality data to disk" + ) + cache_dir: Optional[str] = Field( + None, + description="Directory for quality cache. None means use default /.openspace/tool_quality" + ) + auto_evaluate_descriptions: bool = Field( + True, + description="Whether to automatically evaluate tool descriptions using LLM" + ) + enable_quality_ranking: bool = Field( + True, + description="Whether to incorporate quality scores in tool ranking" + ) + evolve_interval: int = Field( + 5, + ge=1, + le=100, + description="Trigger quality evolution every N tool executions" + ) + + +class SkillConfig(BaseModel): + """Skill engine configuration + + Controls how skills are discovered, selected and injected. + Built-in skills (``openspace/skills/``) are always auto-discovered. + """ + enabled: bool = Field(True, description="Enable skill matching and injection") + skill_dirs: List[str] = Field( + default_factory=list, + description="Extra skill directories. Built-in openspace/skills/ is always included." + ) + max_select: int = Field( + 2, ge=1, le=20, + description="Maximum number of skills to inject per task" + ) + + +class GroundingConfig(BaseModel): + """ + Main configuration for Grounding module. + + Contains configuration for all grounding backends and grounding-level settings. + Note: Local server connection uses defaults or environment variables (LOCAL_SERVER_URL). + """ + # Backend configurations + shell: ShellConfig = Field(default_factory=ShellConfig) + web: WebConfig = Field(default_factory=WebConfig) + mcp: MCPConfig = Field(default_factory=MCPConfig) + gui: GUIConfig = Field(default_factory=GUIConfig) + system: BackendConfig = Field(default_factory=BackendConfig) + + # Grounding-level settings + tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig) + tool_quality: ToolQualityConfig = Field(default_factory=ToolQualityConfig) + skills: SkillConfig = Field(default_factory=SkillConfig) + + enabled_backends: List[Dict[str, str]] = Field( + default_factory=list, + description="List of enabled backends, each item: {'name': str, 'provider_cls': str}" + ) + + session_defaults: SessionConfig = Field( + default_factory=lambda: SessionConfig( + session_name="", + backend_type=BackendType.SHELL, + timeout=30, + auto_reconnect=True, + health_check_interval=30 + ) + ) + + tool_cache_ttl: int = Field( + 300, + ge=1, + le=3600, + description="Tool cache time-to-live in seconds" + ) + tool_cache_maxsize: int = Field( + 300, + ge=1, + le=10000, + description="Maximum number of tool cache entries" + ) + + debug: bool = Field(False, description="Debug mode") + log_level: str = Field("INFO", description="Log level") + security_policies: Dict[str, Any] = Field(default_factory=dict) + + @field_validator('log_level') + @classmethod + def validate_log_level(cls, v): + if v.upper() not in LOG_LEVELS: + raise ValueError(f"Log level must be one of {LOG_LEVELS}, got: {v}") + return v.upper() + + def get_backend_config(self, backend_type: str) -> BackendConfig: + """Get configuration for specified backend""" + name = backend_type.lower() + if not hasattr(self, name): + from openspace.utils.logging import Logger + logger = Logger.get_logger(__name__) + logger.warning(f"Unknown backend type: {backend_type}") + return BackendConfig() + return getattr(self, name) + + def get_security_policy(self, backend_type: str) -> SecurityPolicy: + global_policy = self.security_policies.get("global", {}) + backend_policy = self.security_policies.get("backend", {}).get(backend_type.lower(), {}) + merged_policy = {**global_policy, **backend_policy} + return SecurityPolicy.from_dict(merged_policy) + + +__all__ = [ + "BackendConfig", + "ShellConfig", + "WebConfig", + "MCPConfig", + "GUIConfig", + "ToolSearchConfig", + "ToolQualityConfig", + "SkillConfig", + "GroundingConfig", +] \ No newline at end of file diff --git a/openspace/config/loader.py b/openspace/config/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..cd41416839bf7d63d22a316fd1408a742c0cd3b6 --- /dev/null +++ b/openspace/config/loader.py @@ -0,0 +1,177 @@ +import threading +from pathlib import Path +from typing import Union, Iterable, Dict, Any, Optional + +from .grounding import GroundingConfig +from .constants import ( + CONFIG_GROUNDING, + CONFIG_SECURITY, + CONFIG_DEV, + CONFIG_MCP, + CONFIG_AGENTS +) +from openspace.utils.logging import Logger +from .utils import load_json_file, save_json_file as save_json + +logger = Logger.get_logger(__name__) + + +CONFIG_DIR = Path(__file__).parent + +# Global configuration singleton +_config: GroundingConfig | None = None +_config_lock = threading.RLock() # Use RLock to support recursive locking + + +def _deep_merge_dict(base: dict, update: dict) -> dict: + """Deep merge two dictionaries, update's values will override base's values""" + result = base.copy() + for key, value in update.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = _deep_merge_dict(result[key], value) + else: + result[key] = value + return result + +def _load_json_file(path: Path) -> Dict[str, Any]: + """Load single JSON configuration file. + + This function wraps the generic load_json_file and adds global configuration specific error handling and logging. + """ + if not path.exists(): + logger.debug(f"Configuration file does not exist, skipping: {path}") + return {} + + try: + data = load_json_file(path) + logger.info(f"Loaded configuration file: {path}") + return data + except Exception as e: + logger.warning(f"Failed to load configuration file {path}: {e}") + return {} + +def _load_multiple_files(paths: Iterable[Path]) -> Dict[str, Any]: + """Load configuration from multiple files""" + merged = {} + for path in paths: + data = _load_json_file(path) + if data: + merged = _deep_merge_dict(merged, data) + return merged + +def load_config(*config_paths: Union[str, Path]) -> GroundingConfig: + """ + Load configuration files + """ + global _config + + with _config_lock: + if config_paths: + paths = [Path(p) for p in config_paths] + else: + paths = [ + CONFIG_DIR / CONFIG_GROUNDING, + CONFIG_DIR / CONFIG_SECURITY, + CONFIG_DIR / CONFIG_DEV, # Optional: development environment configuration + ] + + # Load and merge configuration + raw_data = _load_multiple_files(paths) + + # Load MCP configuration (separate processing) + # Check if mcpServers already provided in merged custom configs + has_custom_mcp_servers = "mcpServers" in raw_data + + if has_custom_mcp_servers: + # Use mcpServers from custom config + if "mcp" not in raw_data: + raw_data["mcp"] = {} + raw_data["mcp"]["servers"] = raw_data.pop("mcpServers") + logger.debug(f"Using custom MCP servers from provided config ({len(raw_data['mcp']['servers'])} servers)") + else: + # Load default MCP servers from config_mcp.json + mcp_data = _load_json_file(CONFIG_DIR / CONFIG_MCP) + if mcp_data and "mcpServers" in mcp_data: + if "mcp" not in raw_data: + raw_data["mcp"] = {} + raw_data["mcp"]["servers"] = mcp_data["mcpServers"] + logger.debug(f"Loaded MCP servers from default config_mcp.json ({len(raw_data['mcp']['servers'])} servers)") + + # Validate and create configuration object + try: + _config = GroundingConfig.model_validate(raw_data) + except Exception as e: + logger.error(f"Validation failed, using default configuration: {e}") + _config = GroundingConfig() + + # Adjust log level according to configuration + if _config.debug: + Logger.set_debug(2) + elif _config.log_level: + try: + Logger.configure(level=_config.log_level) + except Exception as e: + logger.warning(f"Failed to set log level {_config.log_level}: {e}") + + return _config + +def get_config() -> GroundingConfig: + """ + Get global configuration instance. + + Usage: + - Get configuration in Provider: get_config().get_backend_config('shell') + - Get security policy in Tool: get_config().get_security_policy('shell') + """ + global _config + + if _config is None: + with _config_lock: + if _config is None: + load_config() + + return _config + +def reset_config() -> None: + """Reset configuration (for testing)""" + global _config + with _config_lock: + _config = None + +def save_config(config: GroundingConfig, path: Union[str, Path]) -> None: + save_json(config.model_dump(), path) + logger.info(f"Configuration saved to: {path}") + + +def load_agents_config() -> Dict[str, Any]: + agents_config_path = CONFIG_DIR / CONFIG_AGENTS + return _load_json_file(agents_config_path) + + +def get_agent_config(agent_name: str) -> Optional[Dict[str, Any]]: + """ + Get the configuration of the specified agent + """ + agents_config = load_agents_config() + + if "agents" not in agents_config: + logger.warning(f"No 'agents' key found in {CONFIG_AGENTS}") + return None + + for agent_cfg in agents_config.get("agents", []): + if agent_cfg.get("name") == agent_name: + return agent_cfg + + logger.warning(f"Agent '{agent_name}' not found in {CONFIG_AGENTS}") + return None + + +__all__ = [ + "CONFIG_DIR", + "load_config", + "get_config", + "reset_config", + "save_config", + "load_agents_config", + "get_agent_config" +] \ No newline at end of file diff --git a/openspace/config/utils.py b/openspace/config/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..be91583e2c9f453a35b998a9e756ccab9af1eabc --- /dev/null +++ b/openspace/config/utils.py @@ -0,0 +1,30 @@ +import json +from pathlib import Path +from typing import Any + + +def get_config_value(config: Any, key: str, default=None): + if isinstance(config, dict): + return config.get(key, default) + else: + return getattr(config, key, default) + + +def load_json_file(filepath: str | Path) -> dict[str, Any]: + filepath = Path(filepath) if isinstance(filepath, str) else filepath + + with open(filepath, 'r', encoding='utf-8') as f: + return json.load(f) + + +def save_json_file(data: dict[str, Any], filepath: str | Path, indent: int = 2) -> None: + filepath = Path(filepath) if isinstance(filepath, str) else filepath + + # Ensure directory exists + filepath.parent.mkdir(parents=True, exist_ok=True) + + with open(filepath, 'w', encoding='utf-8') as f: + json.dump(data, f, indent=indent, ensure_ascii=False) + + +__all__ = ["get_config_value", "load_json_file", "save_json_file"] diff --git a/openspace/dashboard_server.py b/openspace/dashboard_server.py new file mode 100644 index 0000000000000000000000000000000000000000..14d5b66e81f911b599b7b2e31e49f520dfab6366 --- /dev/null +++ b/openspace/dashboard_server.py @@ -0,0 +1,639 @@ +from __future__ import annotations + +import argparse +import os +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +PROJECT_ROOT = Path(__file__).resolve().parent.parent + +from flask import Flask, abort, jsonify, send_from_directory, url_for, request + +from openspace.recording.action_recorder import analyze_agent_actions, load_agent_actions +from openspace.recording.utils import load_recording_session +from openspace.skill_engine import SkillStore +from openspace.skill_engine.types import SkillRecord + +API_PREFIX = "/api/v1" +FRONTEND_DIST_DIR = PROJECT_ROOT / "frontend" / "dist" +WORKFLOW_ROOTS = [ + PROJECT_ROOT / "logs" / "recordings", + PROJECT_ROOT / "logs" / "trajectories", + PROJECT_ROOT / "gdpval_bench" / "results", +] + +PIPELINE_STAGES = [ + { + "id": "initialize", + "title": "Initialize", + "description": "Load LLM, grounding backends, recording, registry, analyzer, and evolver.", + }, + { + "id": "select-skills", + "title": "Skill Selection", + "description": "Select candidate skills and write selection metadata before execution.", + }, + { + "id": "phase-1-skill", + "title": "Skill Phase", + "description": "Run the task with injected skill context whenever matching skills exist.", + }, + { + "id": "phase-2-fallback", + "title": "Tool Fallback", + "description": "Fallback to tool-only execution when the skill-guided phase fails or no skills match.", + }, + { + "id": "analysis", + "title": "Execution Analysis", + "description": "Persist metadata, trajectory, and post-run execution judgments.", + }, + { + "id": "evolution", + "title": "Skill Evolution", + "description": "Trigger fix / derived / captured evolution and periodic quality checks.", + }, +] + +_STORE: SkillStore | None = None + + +def create_app() -> Flask: + app = Flask(__name__, static_folder=None) + + @app.before_request + def check_api_key(): + # Allow preflight requests (CORS) + if request.method == "OPTIONS": + return + + expected_key = os.environ.get("OPENSPACE_API_KEY") + if expected_key: + auth_header = request.headers.get("Authorization") + if not auth_header or auth_header != f"Bearer {expected_key}": + abort(401, description="Unauthorized: Invalid or missing API Key") + + @app.route(f"{API_PREFIX}/health", methods=["GET"]) + def health() -> Any: + workflows = _discover_workflow_dirs() + store = _get_store() + return jsonify( + { + "status": "ok", + "project_root": str(PROJECT_ROOT), + "db_path": str(store.db_path), + "db_exists": store.db_path.exists(), + "frontend_dist_exists": FRONTEND_DIST_DIR.exists(), + "workflow_roots": [str(path) for path in WORKFLOW_ROOTS], + "workflow_count": len(workflows), + } + ) + + @app.route(f"{API_PREFIX}/overview", methods=["GET"]) + def overview() -> Any: + store = _get_store() + skills = list(store.load_all(active_only=False).values()) + workflows = [_build_workflow_summary(path) for path in _discover_workflow_dirs()] + top_skills = _sort_skills(skills, sort_key="score")[:5] + recent_skills = _sort_skills(skills, sort_key="updated")[:5] + average_score = round( + sum(_skill_score(record) for record in skills) / len(skills), 1 + ) if skills else 0.0 + average_workflow_success = round( + (sum((item.get("success_rate") or 0.0) for item in workflows) / len(workflows)) * 100, + 1, + ) if workflows else 0.0 + + return jsonify( + { + "health": { + "status": "ok", + "db_path": str(store.db_path), + "workflow_count": len(workflows), + "frontend_dist_exists": FRONTEND_DIST_DIR.exists(), + }, + "pipeline": PIPELINE_STAGES, + "skills": { + "summary": _build_skill_stats(store, skills), + "average_score": average_score, + "top": [_serialize_skill(item) for item in top_skills], + "recent": [_serialize_skill(item) for item in recent_skills], + }, + "workflows": { + "total": len(workflows), + "average_success_rate": average_workflow_success, + "recent": workflows[:5], + }, + } + ) + + @app.route(f"{API_PREFIX}/skills", methods=["GET"]) + def list_skills() -> Any: + store = _get_store() + active_only = _bool_arg("active_only", True) + limit = _int_arg("limit", 100) + sort_key = (_str_arg("sort", "score") or "score").lower() + skills = list(store.load_all(active_only=active_only).values()) + query = (_str_arg("query", "") or "").strip().lower() + if query: + skills = [ + record + for record in skills + if query in record.name.lower() + or query in record.skill_id.lower() + or query in record.description.lower() + or any(query in tag.lower() for tag in record.tags) + ] + items = [_serialize_skill(item) for item in _sort_skills(skills, sort_key=sort_key)[:limit]] + return jsonify({"items": items, "count": len(items), "active_only": active_only}) + + @app.route(f"{API_PREFIX}/skills/stats", methods=["GET"]) + def skill_stats() -> Any: + store = _get_store() + skills = list(store.load_all(active_only=False).values()) + return jsonify(_build_skill_stats(store, skills)) + + @app.route(f"{API_PREFIX}/skills/", methods=["GET"]) + def skill_detail(skill_id: str) -> Any: + store = _get_store() + record = store.load_record(skill_id) + if not record: + abort(404, description=f"Unknown skill_id: {skill_id}") + + detail = _serialize_skill(record, include_recent_analyses=True) + detail["lineage_graph"] = _build_lineage_payload(skill_id, store) + detail["recent_analyses"] = [analysis.to_dict() for analysis in store.load_analyses(skill_id=skill_id, limit=10)] + detail["source"] = _load_skill_source(record) + return jsonify(detail) + + @app.route(f"{API_PREFIX}/skills//lineage", methods=["GET"]) + def skill_lineage(skill_id: str) -> Any: + store = _get_store() + if not store.load_record(skill_id): + abort(404, description=f"Unknown skill_id: {skill_id}") + return jsonify(_build_lineage_payload(skill_id, store)) + + @app.route(f"{API_PREFIX}/skills//source", methods=["GET"]) + def skill_source(skill_id: str) -> Any: + store = _get_store() + record = store.load_record(skill_id) + if not record: + abort(404, description=f"Unknown skill_id: {skill_id}") + return jsonify(_load_skill_source(record)) + + @app.route(f"{API_PREFIX}/workflows", methods=["GET"]) + def list_workflows() -> Any: + items = [_build_workflow_summary(path) for path in _discover_workflow_dirs()] + return jsonify({"items": items, "count": len(items)}) + + @app.route(f"{API_PREFIX}/workflows/", methods=["GET"]) + def workflow_detail(workflow_id: str) -> Any: + workflow_dir = _get_workflow_dir(workflow_id) + if not workflow_dir: + abort(404, description=f"Unknown workflow: {workflow_id}") + + session = load_recording_session(str(workflow_dir)) + actions = load_agent_actions(str(workflow_dir)) + metadata = session.get("metadata") or {} + trajectory = session.get("trajectory") or [] + plans = session.get("plans") or [] + decisions = session.get("decisions") or [] + action_stats = analyze_agent_actions(actions) + + enriched_trajectory = [] + for step in trajectory: + step_copy = dict(step) + screenshot_rel = step_copy.get("screenshot") + if screenshot_rel: + step_copy["screenshot_url"] = url_for( + "workflow_artifact", + workflow_id=workflow_id, + artifact_path=screenshot_rel, + ) + enriched_trajectory.append(step_copy) + + timeline = _build_timeline(actions, enriched_trajectory) + artifacts = _build_workflow_artifacts(workflow_dir, workflow_id, metadata) + + return jsonify( + { + **_build_workflow_summary(workflow_dir), + "metadata": metadata, + "statistics": session.get("statistics") or {}, + "trajectory": enriched_trajectory, + "plans": plans, + "decisions": decisions, + "agent_actions": actions, + "agent_statistics": action_stats, + "timeline": timeline, + "artifacts": artifacts, + } + ) + + @app.route(f"{API_PREFIX}/workflows//artifacts/", methods=["GET"]) + def workflow_artifact(workflow_id: str, artifact_path: str) -> Any: + workflow_dir = _get_workflow_dir(workflow_id) + if not workflow_dir: + abort(404, description=f"Unknown workflow: {workflow_id}") + + target = (workflow_dir / artifact_path).resolve() + root = workflow_dir.resolve() + if root not in target.parents and target != root: + abort(404) + if not target.exists() or not target.is_file(): + abort(404) + return send_from_directory(str(target.parent), target.name) + + @app.route("/", defaults={"path": ""}) + @app.route("/") + def serve_frontend(path: str) -> Any: + if path.startswith("api/"): + abort(404) + + if FRONTEND_DIST_DIR.exists(): + requested = FRONTEND_DIST_DIR / path if path else FRONTEND_DIST_DIR / "index.html" + if path and requested.exists() and requested.is_file(): + return send_from_directory(str(FRONTEND_DIST_DIR), path) + return send_from_directory(str(FRONTEND_DIST_DIR), "index.html") + + return jsonify( + { + "message": "OpenSpace dashboard API is running.", + "frontend": "Build frontend/ first or run the Vite dev server.", + } + ) + + return app + + +def _get_store() -> SkillStore: + global _STORE + if _STORE is None: + _STORE = SkillStore() + return _STORE + + +def _bool_arg(name: str, default: bool) -> bool: + from flask import request + + raw = request.args.get(name) + if raw is None: + return default + return raw.lower() not in {"0", "false", "no", "off"} + + +def _int_arg(name: str, default: int) -> int: + from flask import request + + raw = request.args.get(name) + if raw is None: + return default + try: + return int(raw) + except ValueError: + return default + + +def _str_arg(name: str, default: str) -> str: + from flask import request + + return request.args.get(name, default) + + +def _skill_score(record: SkillRecord) -> float: + return round(record.effective_rate * 100, 1) + + +def _serialize_skill(record: SkillRecord, *, include_recent_analyses: bool = False) -> Dict[str, Any]: + payload = record.to_dict() + if not include_recent_analyses: + payload.pop("recent_analyses", None) + + path = payload.get("path", "") + lineage = payload.get("lineage") or {} + payload.update( + { + "skill_dir": str(Path(path).parent) if path else "", + "origin": lineage.get("origin", ""), + "generation": lineage.get("generation", 0), + "parent_skill_ids": lineage.get("parent_skill_ids", []), + "applied_rate": round(record.applied_rate, 4), + "completion_rate": round(record.completion_rate, 4), + "effective_rate": round(record.effective_rate, 4), + "fallback_rate": round(record.fallback_rate, 4), + "score": _skill_score(record), + } + ) + return payload + + +def _naive_dt(dt: datetime) -> datetime: + """Strip tzinfo so naive/aware datetimes can be compared safely.""" + return dt.replace(tzinfo=None) if dt.tzinfo else dt + + +def _sort_skills(records: Iterable[SkillRecord], *, sort_key: str) -> List[SkillRecord]: + if sort_key == "updated": + return sorted(records, key=lambda item: _naive_dt(item.last_updated), reverse=True) + if sort_key == "name": + return sorted(records, key=lambda item: item.name.lower()) + return sorted( + records, + key=lambda item: (_skill_score(item), item.total_selections, _naive_dt(item.last_updated).timestamp()), + reverse=True, + ) + + +def _build_skill_stats(store: SkillStore, skills: List[SkillRecord]) -> Dict[str, Any]: + stats = store.get_stats(active_only=False) + avg_score = round(sum(_skill_score(item) for item in skills) / len(skills), 1) if skills else 0.0 + skills_with_recent_analysis = sum(1 for item in skills if item.recent_analyses) + return { + **stats, + "average_score": avg_score, + "skills_with_activity": sum(1 for item in skills if item.total_selections > 0), + "skills_with_recent_analysis": skills_with_recent_analysis, + "top_by_effective_rate": [_serialize_skill(item) for item in _sort_skills(skills, sort_key="score")[:5]], + } + + +def _load_skill_source(record: SkillRecord) -> Dict[str, Any]: + skill_path = Path(record.path) + if not skill_path.exists() or not skill_path.is_file(): + return {"exists": False, "path": record.path, "content": None} + try: + return { + "exists": True, + "path": str(skill_path), + "content": skill_path.read_text(encoding="utf-8"), + } + except OSError: + return {"exists": False, "path": str(skill_path), "content": None} + + +def _build_lineage_payload(skill_id: str, store: SkillStore) -> Dict[str, Any]: + records = store.load_all(active_only=False) + if skill_id not in records: + return {"skill_id": skill_id, "nodes": [], "edges": [], "total_nodes": 0} + + children_by_parent: Dict[str, set[str]] = {} + for item in records.values(): + for parent_id in item.lineage.parent_skill_ids: + children_by_parent.setdefault(parent_id, set()).add(item.skill_id) + + related_ids = {skill_id} + frontier = [skill_id] + while frontier: + current = frontier.pop() + record = records.get(current) + if not record: + continue + for parent_id in record.lineage.parent_skill_ids: + if parent_id not in related_ids: + related_ids.add(parent_id) + frontier.append(parent_id) + for child_id in children_by_parent.get(current, set()): + if child_id not in related_ids: + related_ids.add(child_id) + frontier.append(child_id) + + nodes = [] + edges = [] + for related_id in sorted(related_ids): + record = records.get(related_id) + if not record: + continue + nodes.append( + { + "skill_id": record.skill_id, + "name": record.name, + "description": record.description, + "origin": record.lineage.origin.value, + "generation": record.lineage.generation, + "created_at": record.lineage.created_at.isoformat(), + "visibility": record.visibility.value, + "is_active": record.is_active, + "tags": list(record.tags), + "score": _skill_score(record), + "effective_rate": round(record.effective_rate, 4), + "total_selections": record.total_selections, + } + ) + for parent_id in record.lineage.parent_skill_ids: + if parent_id in related_ids: + edges.append({"source": parent_id, "target": record.skill_id}) + + return { + "skill_id": skill_id, + "nodes": nodes, + "edges": edges, + "total_nodes": len(nodes), + } + + +def _discover_workflow_dirs() -> List[Path]: + discovered: Dict[str, Path] = {} + for root in WORKFLOW_ROOTS: + if not root.exists(): + continue + _scan_workflow_tree(root, discovered) + return sorted(discovered.values(), key=lambda item: item.stat().st_mtime, reverse=True) + + +def _scan_workflow_tree(directory: Path, discovered: Dict[str, Path], *, _depth: int = 0, _max_depth: int = 6) -> None: + if _depth > _max_depth: + return + try: + children = list(directory.iterdir()) + except OSError: + return + for child in children: + if not child.is_dir(): + continue + if (child / "metadata.json").exists() or (child / "traj.jsonl").exists(): + discovered.setdefault(child.name, child) + else: + _scan_workflow_tree(child, discovered, _depth=_depth + 1, _max_depth=_max_depth) + + +def _get_workflow_dir(workflow_id: str) -> Optional[Path]: + for path in _discover_workflow_dirs(): + if path.name == workflow_id: + return path + return None + + +def _build_workflow_summary(workflow_dir: Path) -> Dict[str, Any]: + session = load_recording_session(str(workflow_dir)) + metadata = session.get("metadata") or {} + statistics = session.get("statistics") or {} + actions = load_agent_actions(str(workflow_dir)) + screenshots_dir = workflow_dir / "screenshots" + screenshot_count = len(list(screenshots_dir.glob("*.png"))) if screenshots_dir.exists() else 0 + + video_candidates = [workflow_dir / "screen_recording.mp4", workflow_dir / "recording.mp4"] + video_url = None + for candidate in video_candidates: + if candidate.exists(): + rel = candidate.relative_to(workflow_dir).as_posix() + video_url = url_for("workflow_artifact", workflow_id=workflow_dir.name, artifact_path=rel) + break + + outcome = metadata.get("execution_outcome") or {} + # Instruction fallback chain: top-level → retrieved_tools.instruction → skill_selection.task + instruction = ( + metadata.get("instruction") + or (metadata.get("retrieved_tools") or {}).get("instruction") + or (metadata.get("skill_selection") or {}).get("task") + or "" + ) + + # Resolve start/end times with trajectory fallback + start_time = metadata.get("start_time") + end_time = metadata.get("end_time") + trajectory = session.get("trajectory") or [] + + # If end_time is missing, infer from last trajectory step + if not end_time and trajectory: + last_ts = trajectory[-1].get("timestamp") + if last_ts: + end_time = last_ts + + # Compute execution_time: prefer outcome, fallback to timestamp diff + execution_time = outcome.get("execution_time", 0) + if not execution_time and start_time and end_time: + try: + t0 = datetime.fromisoformat(start_time) + t1 = datetime.fromisoformat(end_time) + execution_time = round((t1 - t0).total_seconds(), 2) + except (ValueError, TypeError): + pass + + # Resolve status: prefer outcome, fallback heuristic + status = outcome.get("status", "") + if not status: + total_steps = statistics.get("total_steps", 0) + if total_steps > 0: + status = "success" + elif trajectory: + status = "completed" + else: + status = "unknown" + + # Resolve iterations: prefer outcome, fallback to conversation count + iterations = outcome.get("iterations", 0) + if not iterations and trajectory: + iterations = len(trajectory) + + return { + "id": workflow_dir.name, + "path": str(workflow_dir), + "task_id": metadata.get("task_id") or metadata.get("task_name") or workflow_dir.name, + "task_name": metadata.get("task_name") or metadata.get("task_id") or workflow_dir.name, + "instruction": instruction, + "status": status, + "iterations": iterations, + "execution_time": execution_time, + "start_time": start_time, + "end_time": end_time, + "total_steps": statistics.get("total_steps", 0), + "success_count": statistics.get("success_count", 0), + "success_rate": statistics.get("success_rate", 0.0), + "backend_counts": statistics.get("backends", {}), + "tool_counts": statistics.get("tools", {}), + "agent_action_count": len(actions), + "has_video": bool(video_url), + "video_url": video_url, + "screenshot_count": screenshot_count, + "selected_skills": (metadata.get("skill_selection") or {}).get("selected", []), + } + + +def _build_timeline(actions: List[Dict[str, Any]], trajectory: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + events: List[Dict[str, Any]] = [] + for action in actions: + events.append( + { + "timestamp": action.get("timestamp", ""), + "type": "agent_action", + "step": action.get("step"), + "label": action.get("action_type", "agent_action"), + "agent_name": action.get("agent_name", ""), + "agent_type": action.get("agent_type", ""), + "details": action, + } + ) + for step in trajectory: + events.append( + { + "timestamp": step.get("timestamp", ""), + "type": "tool_execution", + "step": step.get("step"), + "label": step.get("tool", "tool_execution"), + "backend": step.get("backend", ""), + "status": (step.get("result") or {}).get("status", "unknown"), + "details": step, + } + ) + events.sort(key=lambda item: (item.get("timestamp", ""), item.get("step") or 0)) + return events + + +def _build_workflow_artifacts(workflow_dir: Path, workflow_id: str, metadata: Dict[str, Any]) -> Dict[str, Any]: + screenshots: List[Dict[str, Any]] = [] + screenshots_dir = workflow_dir / "screenshots" + if screenshots_dir.exists(): + for image in sorted(screenshots_dir.glob("*.png")): + rel = image.relative_to(workflow_dir).as_posix() + screenshots.append( + { + "name": image.name, + "path": rel, + "url": url_for("workflow_artifact", workflow_id=workflow_id, artifact_path=rel), + } + ) + + init_screenshot = metadata.get("init_screenshot") + init_screenshot_url = ( + url_for("workflow_artifact", workflow_id=workflow_id, artifact_path=init_screenshot) + if isinstance(init_screenshot, str) + else None + ) + + video_url = None + for rel in ("screen_recording.mp4", "recording.mp4"): + candidate = workflow_dir / rel + if candidate.exists(): + video_url = url_for("workflow_artifact", workflow_id=workflow_id, artifact_path=rel) + break + + return { + "init_screenshot_url": init_screenshot_url, + "screenshots": screenshots, + "video_url": video_url, + } + + +def main() -> None: + parser = argparse.ArgumentParser(description="OpenSpace dashboard API server") + parser.add_argument("--host", default="127.0.0.1", help="Dashboard API host") + parser.add_argument("--port", type=int, default=7788, help="Dashboard API port") + parser.add_argument("--debug", action="store_true", help="Enable Flask debug mode") + args = parser.parse_args() + + app = create_app() + + from werkzeug.serving import run_simple + run_simple( + args.host, + args.port, + app, + threaded=True, + use_debugger=args.debug, + use_reloader=args.debug, + ) + + +if __name__ == "__main__": + main() diff --git a/openspace/grounding/backends/__init__.py b/openspace/grounding/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b31ae2bb01b070b65b0ed951dd795ebd958b2a --- /dev/null +++ b/openspace/grounding/backends/__init__.py @@ -0,0 +1,34 @@ +# Use lazy imports to avoid loading all backends unconditionally + +def _lazy_import_provider(provider_name: str): + """Lazy import provider class""" + if provider_name == 'mcp': + from .mcp.provider import MCPProvider + return MCPProvider + elif provider_name == 'shell': + from .shell.provider import ShellProvider + return ShellProvider + elif provider_name == 'web': + from .web.provider import WebProvider + return WebProvider + elif provider_name == 'gui': + from .gui.provider import GUIProvider + return GUIProvider + else: + raise ImportError(f"Unknown provider: {provider_name}") + + +class _ProviderRegistry: + """Lazy provider registry""" + def __getitem__(self, key): + return _lazy_import_provider(key) + + def __contains__(self, key): + return key in ['mcp', 'shell', 'web', 'gui'] + +BACKEND_PROVIDERS = _ProviderRegistry() + +__all__ = [ + 'BACKEND_PROVIDERS', + '_lazy_import_provider' +] \ No newline at end of file diff --git a/openspace/grounding/backends/gui/__init__.py b/openspace/grounding/backends/gui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c874851b2388d8e0fa059cd91e2bff011720ec20 --- /dev/null +++ b/openspace/grounding/backends/gui/__init__.py @@ -0,0 +1,25 @@ +from .provider import GUIProvider +from .session import GUISession +from .transport.connector import GUIConnector +from .transport.local_connector import LocalGUIConnector + +try: + from .anthropic_client import AnthropicGUIClient + from . import anthropic_utils + _anthropic_available = True +except ImportError: + _anthropic_available = False + +__all__ = [ + # Core Provider and Session + "GUIProvider", + "GUISession", + + # Transport layer + "GUIConnector", + "LocalGUIConnector", +] + +# Add Anthropic modules to exports if available +if _anthropic_available: + __all__.extend(["AnthropicGUIClient", "anthropic_utils"]) \ No newline at end of file diff --git a/openspace/grounding/backends/gui/anthropic_client.py b/openspace/grounding/backends/gui/anthropic_client.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1dba6a7385d36d72fff8a77d1163eec36ab03f --- /dev/null +++ b/openspace/grounding/backends/gui/anthropic_client.py @@ -0,0 +1,575 @@ +import base64 +import os +import time +from typing import Any, Dict, Optional, Tuple, List +from openspace.utils.logging import Logger +from PIL import Image +import io + +logger = Logger.get_logger(__name__) + +try: + from anthropic import ( + Anthropic, + AnthropicBedrock, + AnthropicVertex, + APIError, + APIResponseValidationError, + APIStatusError, + ) + from anthropic.types.beta import ( + BetaMessageParam, + BetaTextBlockParam, + ) + ANTHROPIC_AVAILABLE = True +except ImportError: + logger.warning("Anthropic SDK not available. Install with: pip install anthropic") + ANTHROPIC_AVAILABLE = False + +# Import utility functions +from .anthropic_utils import ( + APIProvider, + PROVIDER_TO_DEFAULT_MODEL_NAME, + COMPUTER_USE_BETA_FLAG, + PROMPT_CACHING_BETA_FLAG, + get_system_prompt, + inject_prompt_caching, + maybe_filter_to_n_most_recent_images, + response_to_params, +) + +# API retry configuration +API_RETRY_TIMES = 10 +API_RETRY_INTERVAL = 5 # seconds + + +class AnthropicGUIClient: + """ + Anthropic LLM Client for GUI operations. + Uses Claude Sonnet 4.5 with computer-use-2025-01-24 API. + + Features: + - Vision-based screen understanding + - Automatic screenshot resizing (configurable display size) + - Coordinate scaling between display and actual screen + """ + + def __init__( + self, + model: str = "claude-sonnet-4-5", + platform: str = "Ubuntu", + api_key: Optional[str] = None, + provider: str = "anthropic", + max_tokens: int = 4096, + screen_size: Tuple[int, int] = (1920, 1080), + display_size: Tuple[int, int] = (1024, 768), # Computer use display size + pyautogui_size: Optional[Tuple[int, int]] = None, # PyAutoGUI working size + only_n_most_recent_images: int = 3, + enable_prompt_caching: bool = True, + backup_api_key: Optional[str] = None, + ): + """ + Initialize Anthropic GUI Client for Claude Sonnet 4.5. + + Args: + model: Model name (only "claude-sonnet-4-5" supported) + platform: Platform type (Ubuntu, Windows, or macOS) + api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var) + provider: API provider (only "anthropic" supported) + max_tokens: Maximum tokens for response + screen_size: Actual screenshot resolution (width, height) - physical pixels + display_size: Display size for computer use tool (width, height) + Screenshots will be resized to this size before sending to API + pyautogui_size: PyAutoGUI working size (logical pixels). If None, assumed same as screen_size. + On Retina/HiDPI displays, this may be screen_size / 2 + only_n_most_recent_images: Number of recent screenshots to keep in history + enable_prompt_caching: Whether to enable prompt caching for cost optimization + backup_api_key: Backup API key (defaults to ANTHROPIC_API_KEY_BACKUP env var) + """ + if not ANTHROPIC_AVAILABLE: + raise RuntimeError("Anthropic SDK not installed. Install with: pip install anthropic") + + # Only support claude-sonnet-4-5 + if model != "claude-sonnet-4-5": + logger.warning(f"Model '{model}' not supported. Using 'claude-sonnet-4-5'") + model = "claude-sonnet-4-5" + + self.model = model + self.platform = platform + self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY") + if not self.api_key: + raise ValueError("Anthropic API key not provided. Set ANTHROPIC_API_KEY env var or pass api_key parameter") + + # Backup API key for failover + self.backup_api_key = backup_api_key or os.environ.get("ANTHROPIC_API_KEY_BACKUP") + + # Only support anthropic provider + if provider != "anthropic": + logger.warning(f"Provider '{provider}' not supported. Using 'anthropic'") + provider = "anthropic" + + self.provider = APIProvider(provider) + self.max_tokens = max_tokens + self.screen_size = screen_size + self.display_size = display_size + self.pyautogui_size = pyautogui_size or screen_size # Default to screen_size if not specified + self.only_n_most_recent_images = only_n_most_recent_images + self.enable_prompt_caching = enable_prompt_caching + + # Message history + self.messages: List[BetaMessageParam] = [] + + # Calculate resize factor for coordinate scaling + # Step 1: LLM coordinates (display_size) -> Physical pixels (screen_size) + # Step 2: Physical pixels -> PyAutoGUI logical pixels (pyautogui_size) + self.resize_factor = ( + self.pyautogui_size[0] / display_size[0], # x scale factor + self.pyautogui_size[1] / display_size[1] # y scale factor + ) + + logger.info( + f"Initialized AnthropicGUIClient:\n" + f" Model: {model}\n" + f" Platform: {platform}\n" + f" Screen Size (physical): {screen_size}\n" + f" PyAutoGUI Size (logical): {self.pyautogui_size}\n" + f" Display Size (LLM): {display_size}\n" + f" Resize Factor (LLM->PyAutoGUI): {self.resize_factor}\n" + f" Prompt Caching: {enable_prompt_caching}" + ) + + def _create_client(self, api_key: Optional[str] = None): + """Create Anthropic client (only supports anthropic provider).""" + key = api_key or self.api_key + return Anthropic(api_key=key, max_retries=4) + + def _resize_screenshot(self, screenshot_bytes: bytes) -> bytes: + """ + Resize screenshot to display size for Computer Use API. + + For computer-use-2025-01-24, the screenshot must be resized to the + display_width_px x display_height_px specified in the tool definition. + """ + screenshot_image = Image.open(io.BytesIO(screenshot_bytes)) + resized_image = screenshot_image.resize(self.display_size, Image.Resampling.LANCZOS) + + output_buffer = io.BytesIO() + resized_image.save(output_buffer, format='PNG') + return output_buffer.getvalue() + + def _scale_coordinates(self, x: int, y: int) -> Tuple[int, int]: + """ + Scale coordinates from display size to actual screen size. + + The API returns coordinates in display_size (e.g., 1024x768). + We need to scale them to actual screen_size (e.g., 1920x1080) for execution. + + Args: + x, y: Coordinates in display size space + + Returns: + Scaled coordinates in actual screen size space + """ + scaled_x = int(x * self.resize_factor[0]) + scaled_y = int(y * self.resize_factor[1]) + return scaled_x, scaled_y + + async def plan_action( + self, + task_description: str, + screenshot: bytes, + action_history: List[Dict[str, Any]] = None, + ) -> Tuple[Optional[str], List[str]]: + """ + Plan next action based on task and current screenshot. + Includes prompt caching, error handling, and backup API key support. + + Args: + task_description: Task to accomplish + screenshot: Current screenshot (PNG bytes) + action_history: Previous actions (for context) + + Returns: + Tuple of (reasoning, list of pyautogui commands) + """ + # Resize screenshot + resized_screenshot = self._resize_screenshot(screenshot) + screenshot_b64 = base64.b64encode(resized_screenshot).decode('utf-8') + + # Initialize messages with first task + screenshot + if not self.messages: + # IMPORTANT: Image should come BEFORE text for better model understanding + # This matches OSWorld's implementation which has proven effectiveness + self.messages.append({ + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": screenshot_b64, + }, + }, + {"type": "text", "text": task_description}, + ] + }) + + # Filter images BEFORE adding new screenshot to control message size + # This is critical to avoid exceeding the 25MB API limit + image_truncation_threshold = 10 + if self.only_n_most_recent_images and len(self.messages) > 1: + # Reserve 1 slot for the screenshot we're about to add + maybe_filter_to_n_most_recent_images( + self.messages, + max(1, self.only_n_most_recent_images - 1), + min_removal_threshold=1, # More aggressive filtering + ) + + # Add tool result from previous action if exists + if self.messages and self.messages[-1]["role"] == "assistant": + last_content = self.messages[-1]["content"] + if isinstance(last_content, list) and any( + block.get("type") == "tool_use" for block in last_content + ): + tool_use_id = next( + block["id"] for block in last_content + if block.get("type") == "tool_use" + ) + self._add_tool_result(tool_use_id, "Success", resized_screenshot) + + # Define tools and betas for claude-sonnet-4-5 with computer-use-2025-01-24 + tools = [{ + 'name': 'computer', + 'type': 'computer_20250124', + 'display_width_px': self.display_size[0], + 'display_height_px': self.display_size[1], + 'display_number': 1 + }] + betas = [COMPUTER_USE_BETA_FLAG] + + # Prepare system prompt with optional caching + system = BetaTextBlockParam( + type="text", + text=get_system_prompt(self.platform) + ) + + # Enable prompt caching if supported and enabled + if self.enable_prompt_caching: + betas.append(PROMPT_CACHING_BETA_FLAG) + inject_prompt_caching(self.messages) + system["cache_control"] = {"type": "ephemeral"} # type: ignore + + # Model name - use claude-sonnet-4-5 directly + model_name = "claude-sonnet-4-5" + + # Enable thinking for complex computer use tasks + extra_body = {"thinking": {"type": "enabled", "budget_tokens": 2048}} + + # Log request details for debugging + # Count current images in messages + total_images = sum( + 1 + for message in self.messages + for item in (message.get("content", []) if isinstance(message.get("content"), list) else []) + if isinstance(item, dict) and item.get("type") == "image" + ) + tool_result_images = sum( + 1 + for message in self.messages + for item in (message.get("content", []) if isinstance(message.get("content"), list) else []) + if isinstance(item, dict) and item.get("type") == "tool_result" + for content in item.get("content", []) + if isinstance(content, dict) and content.get("type") == "image" + ) + logger.info( + f"Anthropic API request:\n" + f" Model: {model_name}\n" + f" Display Size: {self.display_size}\n" + f" Betas: {betas}\n" + f" Images: {total_images} ({tool_result_images} in tool_results)\n" + f" Messages: {len(self.messages)}" + ) + + # Try API call with retry and backup + client = self._create_client() + response = None + + try: + # Retry loop with automatic image count reduction on 25MB error + for attempt in range(API_RETRY_TIMES): + try: + response = client.beta.messages.create( + max_tokens=self.max_tokens, + messages=self.messages, + model=model_name, + system=[system], + tools=tools, + betas=betas, + extra_body=extra_body + ) + logger.info(f"API call succeeded on attempt {attempt + 1}") + break + + except (APIError, APIStatusError, APIResponseValidationError) as e: + error_msg = str(e) + logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}") + + # Handle 25MB payload limit error (including HTTP 413) + if ("25000000" in error_msg or + "Member must have length less than or equal to" in error_msg or + "request_too_large" in error_msg or + "413" in str(e)): + logger.warning("Detected 25MB limit error, reducing image count") + current_count = self.only_n_most_recent_images + new_count = max(1, current_count // 2) + self.only_n_most_recent_images = new_count + + maybe_filter_to_n_most_recent_images( + self.messages, + new_count, + min_removal_threshold=1, # Aggressive filtering when hitting limit + ) + logger.info(f"Image count reduced from {current_count} to {new_count}") + + if attempt < API_RETRY_TIMES - 1: + time.sleep(API_RETRY_INTERVAL) + else: + raise + + except (APIError, APIStatusError, APIResponseValidationError) as e: + logger.error(f"Primary API key failed: {e}") + + # Try backup API key if available + if self.backup_api_key: + logger.warning("Retrying with backup API key...") + try: + backup_client = self._create_client(self.backup_api_key) + response = backup_client.beta.messages.create( + max_tokens=self.max_tokens, + messages=self.messages, + model=model_name, + system=[system], + tools=tools, + betas=betas, + extra_body=extra_body + ) + logger.info("Successfully used backup API key") + except Exception as backup_e: + logger.error(f"Backup API key also failed: {backup_e}") + return None, ["FAIL"] + else: + return None, ["FAIL"] + + except Exception as e: + logger.error(f"Unexpected error: {e}") + return None, ["FAIL"] + + if not response: + return None, ["FAIL"] + + # Parse response using utility function + response_params = response_to_params(response) + + # Extract reasoning and commands + reasoning = "" + commands = [] + + for block in response_params: + block_type = block.get("type") + + if block_type == "text": + reasoning = block.get("text", "") + elif block_type == "thinking": + reasoning = block.get("thinking", "") + elif block_type == "tool_use": + tool_input = block.get("input", {}) + command = self._parse_computer_tool_use(tool_input) + if command: + commands.append(command) + else: + logger.warning(f"Failed to parse tool_use: {tool_input}") + + # Store assistant response + self.messages.append({ + "role": "assistant", + "content": response_params + }) + + logger.info(f"Parsed {len(commands)} commands from response") + + return reasoning, commands + + def _add_tool_result( + self, + tool_use_id: str, + result: str, + screenshot_bytes: Optional[bytes] = None + ): + """ + Add tool result to message history. + IMPORTANT: Put screenshot BEFORE text for consistency with initial message. + """ + # Build content list with image first (if provided), then text + content_list = [] + + # Add screenshot first if provided (consistent with initial message ordering) + if screenshot_bytes is not None: + screenshot_b64 = base64.b64encode(screenshot_bytes).decode('utf-8') + content_list.append({ + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": screenshot_b64 + } + }) + + # Then add text result + content_list.append({"type": "text", "text": result}) + + tool_result_content = [{ + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": content_list + }] + + self.messages.append({ + "role": "user", + "content": tool_result_content + }) + + def _parse_computer_tool_use(self, tool_input: Dict[str, Any]) -> Optional[str]: + """ + Parse Anthropic computer tool use to pyautogui command. + + Args: + tool_input: Tool input from Anthropic (action, coordinate, text, etc.) + + Returns: + PyAutoGUI command string or control command (DONE, FAIL) + """ + action = tool_input.get("action") + if not action: + return None + + # Action conversion + action_conversion = { + "left click": "click", + "right click": "right_click" + } + action = action_conversion.get(action, action) + + text = tool_input.get("text") + coordinate = tool_input.get("coordinate") + scroll_direction = tool_input.get("scroll_direction") + scroll_amount = tool_input.get("scroll_amount", 5) + + # Scale coordinates to actual screen size + if coordinate: + coordinate = self._scale_coordinates(coordinate[0], coordinate[1]) + + # Build commands + command = "" + + if action == "mouse_move": + if coordinate: + x, y = coordinate + command = f"pyautogui.moveTo({x}, {y}, duration=0.5)" + + elif action in ("left_click", "click"): + if coordinate: + x, y = coordinate + command = f"pyautogui.click({x}, {y})" + else: + command = "pyautogui.click()" + + elif action == "right_click": + if coordinate: + x, y = coordinate + command = f"pyautogui.rightClick({x}, {y})" + else: + command = "pyautogui.rightClick()" + + elif action == "double_click": + if coordinate: + x, y = coordinate + command = f"pyautogui.doubleClick({x}, {y})" + else: + command = "pyautogui.doubleClick()" + + elif action == "middle_click": + if coordinate: + x, y = coordinate + command = f"pyautogui.middleClick({x}, {y})" + else: + command = "pyautogui.middleClick()" + + elif action == "left_click_drag": + if coordinate: + x, y = coordinate + command = f"pyautogui.dragTo({x}, {y}, duration=0.5)" + + elif action == "key": + if text: + keys = text.split('+') + # Key conversion + key_conversion = { + "page_down": "pagedown", + "page_up": "pageup", + "super_l": "win", + "super": "command", + "escape": "esc" + } + converted_keys = [key_conversion.get(k.strip().lower(), k.strip().lower()) for k in keys] + + # Press and release keys + for key in converted_keys: + command += f"pyautogui.keyDown('{key}'); " + for key in reversed(converted_keys): + command += f"pyautogui.keyUp('{key}'); " + # Remove trailing semicolon and space + command = command.rstrip('; ') + + elif action == "type": + if text: + command = f"pyautogui.typewrite({repr(text)}, interval=0.01)" + + elif action == "scroll": + if scroll_direction in ("up", "down"): + scroll_value = scroll_amount if scroll_direction == "up" else -scroll_amount + if coordinate: + x, y = coordinate + command = f"pyautogui.scroll({scroll_value}, {x}, {y})" + else: + command = f"pyautogui.scroll({scroll_value})" + elif scroll_direction in ("left", "right"): + scroll_value = scroll_amount if scroll_direction == "right" else -scroll_amount + if coordinate: + x, y = coordinate + command = f"pyautogui.hscroll({scroll_value}, {x}, {y})" + else: + command = f"pyautogui.hscroll({scroll_value})" + + elif action == "screenshot": + # Screenshot is automatically handled by the system + # Return special marker to indicate no action needed + return "SCREENSHOT" + + elif action == "wait": + # Wait for specified duration + duration = tool_input.get("duration", 1) + command = f"pyautogui.sleep({duration})" + + elif action == "done": + return "DONE" + + elif action == "fail": + return "FAIL" + + return command if command else None + + def reset(self): + """Reset message history.""" + self.messages = [] + logger.info("Reset AnthropicGUIClient message history") \ No newline at end of file diff --git a/openspace/grounding/backends/gui/anthropic_utils.py b/openspace/grounding/backends/gui/anthropic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..824580181263d2d13d867333a76bdb713af81f73 --- /dev/null +++ b/openspace/grounding/backends/gui/anthropic_utils.py @@ -0,0 +1,241 @@ +from typing import List, cast +from enum import Enum +from datetime import datetime +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +try: + from anthropic.types.beta import ( + BetaCacheControlEphemeralParam, + BetaContentBlockParam, + BetaImageBlockParam, + BetaMessage, + BetaMessageParam, + BetaTextBlock, + BetaTextBlockParam, + BetaToolResultBlockParam, + BetaToolUseBlockParam, + ) + ANTHROPIC_AVAILABLE = True +except ImportError: + ANTHROPIC_AVAILABLE = False + + +# Beta flags +# For claude-sonnet-4-5 with computer-use-2025-01-24 +COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24" +PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31" + + +class APIProvider(Enum): + """API Provider enumeration""" + ANTHROPIC = "anthropic" + # BEDROCK = "bedrock" + # VERTEX = "vertex" + + +# Provider to model name mapping (simplified for claude-sonnet-4-5 only) +PROVIDER_TO_DEFAULT_MODEL_NAME: dict = { + (APIProvider.ANTHROPIC, "claude-sonnet-4-5"): "claude-sonnet-4-5", + # (APIProvider.BEDROCK, "claude-sonnet-4-5"): "us.anthropic.claude-sonnet-4-5-v1:0", + # (APIProvider.VERTEX, "claude-sonnet-4-5"): "claude-sonnet-4-5-v1", +} + + +def get_system_prompt(platform: str = "Ubuntu") -> str: + """ + Get system prompt based on platform. + + Args: + platform: Platform type (Ubuntu, Windows, macOS, or Darwin) + + Returns: + System prompt string + """ + # Normalize platform name + platform_lower = platform.lower() + + if platform_lower in ["windows", "win32"]: + return f""" +* You are utilising a Windows virtual machine using x86_64 architecture with internet access. +* You can use the computer tool to interact with the desktop: take screenshots, click, type, and control applications. +* To accomplish tasks, you MUST use the computer tool to see the screen and take actions. +* To open browser, please just click on the Chrome icon. Note, Chrome is what is installed on your system. +* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. +* DO NOT ask users for clarification during task execution. DO NOT stop to request more information from users. Always take action using available tools. +* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. +* The current date is {datetime.today().strftime('%A, %B %d, %Y')}. +* Home directory of this Windows system is 'C:\\Users\\user'. +* When you want to open some applications on Windows, please use Double Click on it instead of clicking once. +* After each action, the system will provide you with a new screenshot showing the result. +* Continue taking actions until the task is complete. +""" + elif platform_lower in ["macos", "darwin", "mac"]: + return f""" +* You are utilising a macOS system with internet access. +* You can use the computer tool to interact with the desktop: take screenshots, click, type, and control applications. +* To accomplish tasks, you MUST use the computer tool to see the screen and take actions. +* To open browser, please just click on the Chrome icon. Note, Chrome is what is installed on your system. +* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. +* DO NOT ask users for clarification during task execution. DO NOT stop to request more information from users. Always take action using available tools. +* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. +* The current date is {datetime.today().strftime('%A, %B %d, %Y')}. +* Home directory of this macOS system is typically '/Users/[username]' or can be accessed via '~'. +* On macOS, use Command (⌘) key combinations instead of Ctrl (e.g., Command+C for copy). +* After each action, the system will provide you with a new screenshot showing the result. +* Continue taking actions until the task is complete. +* When the task is completed, simply describe what you've done in your response WITHOUT using the tool again. +""" + else: # Ubuntu/Linux + return f""" +* You are utilising an Ubuntu virtual machine using x86_64 architecture with internet access. +* You can use the computer tool to interact with the desktop: take screenshots, click, type, and control applications. +* To accomplish tasks, you MUST use the computer tool to see the screen and take actions. +* To open browser, please just click on the Chrome icon. Note, Chrome is what is installed on your system. +* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. +* DO NOT ask users for clarification during task execution. DO NOT stop to request more information from users. Always take action using available tools. +* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. +* The current date is {datetime.today().strftime('%A, %B %d, %Y')}. +* Home directory of this Ubuntu system is '/home/user'. +* After each action, the system will provide you with a new screenshot showing the result. +* Continue taking actions until the task is complete. +""" + + +def inject_prompt_caching(messages: List[BetaMessageParam]) -> None: + """ + Set cache breakpoints for the 3 most recent turns. + One cache breakpoint is left for tools/system prompt, to be shared across sessions. + + Args: + messages: Message history (modified in place) + """ + if not ANTHROPIC_AVAILABLE: + return + + breakpoints_remaining = 3 + for message in reversed(messages): + if message["role"] == "user" and isinstance( + content := message["content"], list + ): + if breakpoints_remaining: + breakpoints_remaining -= 1 + # Use type ignore to bypass TypedDict check until SDK types are updated + content[-1]["cache_control"] = BetaCacheControlEphemeralParam( # type: ignore + {"type": "ephemeral"} + ) + else: + content[-1].pop("cache_control", None) + # we'll only ever have one extra turn per loop + break + + +def maybe_filter_to_n_most_recent_images( + messages: List[BetaMessageParam], + images_to_keep: int, + min_removal_threshold: int, +) -> None: + """ + With the assumption that images are screenshots that are of diminishing value as + the conversation progresses, remove all but the final `images_to_keep` tool_result + images in place, with a chunk of min_removal_threshold to reduce the amount we + break the implicit prompt cache. + + Args: + messages: Message history (modified in place) + images_to_keep: Number of recent images to keep + min_removal_threshold: Minimum number of images to remove at once (for cache efficiency) + """ + if not ANTHROPIC_AVAILABLE or images_to_keep is None: + return + + tool_result_blocks = cast( + list[BetaToolResultBlockParam], + [ + item + for message in messages + for item in ( + message["content"] if isinstance(message["content"], list) else [] + ) + if isinstance(item, dict) and item.get("type") == "tool_result" + ], + ) + + total_images = sum( + 1 + for tool_result in tool_result_blocks + for content in tool_result.get("content", []) + if isinstance(content, dict) and content.get("type") == "image" + ) + + images_to_remove = total_images - images_to_keep + # for better cache behavior, we want to remove in chunks + images_to_remove -= images_to_remove % min_removal_threshold + + for tool_result in tool_result_blocks: + if isinstance(tool_result.get("content"), list): + new_content = [] + for content in tool_result.get("content", []): + if isinstance(content, dict) and content.get("type") == "image": + if images_to_remove > 0: + images_to_remove -= 1 + continue + new_content.append(content) + tool_result["content"] = new_content + + +def response_to_params(response: BetaMessage) -> List[BetaContentBlockParam]: + """ + Convert Anthropic response to parameter list. + Handles both text blocks, tool use blocks, and thinking blocks. + + Args: + response: Anthropic API response + + Returns: + List of content blocks + """ + if not ANTHROPIC_AVAILABLE: + return [] + + res: List[BetaContentBlockParam] = [] + if response.content: + for block in response.content: + # Check block type using type attribute + # Note: type may be a string or enum, so convert to string for comparison + block_type = str(getattr(block, "type", "")) + + if block_type == "text": + # Regular text block + if isinstance(block, BetaTextBlock) and block.text: + res.append(BetaTextBlockParam(type="text", text=block.text)) + elif block_type == "thinking": + # Thinking block (for Claude 4 and Sonnet 3.7) + thinking_block = { + "type": "thinking", + "thinking": getattr(block, "thinking", ""), + } + if hasattr(block, "signature"): + thinking_block["signature"] = getattr(block, "signature", None) + res.append(cast(BetaContentBlockParam, thinking_block)) + elif block_type == "tool_use": + # Tool use block - only include required fields to avoid API errors + # (e.g., 'caller' field is not permitted by Anthropic API) + tool_use_dict = { + "type": "tool_use", + "id": block.id, + "name": block.name, + "input": block.input, + } + res.append(cast(BetaToolUseBlockParam, tool_use_dict)) + else: + # Unknown block type - try to handle generically + try: + res.append(cast(BetaContentBlockParam, block.model_dump())) + except Exception as e: + logger.warning(f"Failed to parse block type {block_type}: {e}") + return res + else: + return [] + diff --git a/openspace/grounding/backends/gui/config.py b/openspace/grounding/backends/gui/config.py new file mode 100644 index 0000000000000000000000000000000000000000..ef08739774400d62c5bfab19bffb653d23d5f306 --- /dev/null +++ b/openspace/grounding/backends/gui/config.py @@ -0,0 +1,76 @@ +from typing import Dict, Any, Optional +import os +import platform as platform_module +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +def build_llm_config(user_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """ + Build complete LLM configuration with auto-detection and environment variables. + + Auto-detects: + - API key from environment variables (ANTHROPIC_API_KEY) + - Platform from system (macOS/Windows/Ubuntu) + - Provider defaults to 'anthropic' + + User-provided config values will override auto-detected values. + + Args: + user_config: User-provided configuration (optional) + + Returns: + Complete LLM configuration dict + + Example: + >>> # Auto-detect everything + >>> config = build_llm_config() + + >>> # Override specific values + >>> config = build_llm_config({ + ... "model": "claude-3-5-sonnet-20241022", + ... "max_tokens": 8192 + ... }) + """ + if user_config is None: + user_config = {} + + # Auto-detect platform + system = platform_module.system() + if system == "Darwin": + detected_platform = "macOS" + elif system == "Windows": + detected_platform = "Windows" + else: # Linux + detected_platform = "Ubuntu" + + # Auto-detect API key from environment + api_key = os.environ.get("ANTHROPIC_API_KEY") + if not api_key: + logger.warning( + "ANTHROPIC_API_KEY not found in environment. " + "Please set it: export ANTHROPIC_API_KEY='your-key'" + ) + + # Build configuration with precedence: user_config > auto-detected > defaults + config = { + "type": user_config.get("type", "anthropic"), + "model": user_config.get("model", "claude-sonnet-4-5"), + "platform": user_config.get("platform", detected_platform), + "api_key": user_config.get("api_key", api_key), + "provider": user_config.get("provider", "anthropic"), + "max_tokens": user_config.get("max_tokens", 4096), + "only_n_most_recent_images": user_config.get("only_n_most_recent_images", 3), + "enable_prompt_caching": user_config.get("enable_prompt_caching", True), + } + + # Optional: screen_size (will be auto-detected from screenshot later) + if "screen_size" in user_config: + config["screen_size"] = user_config["screen_size"] + + logger.info(f"Built LLM config - Platform: {config['platform']}, Model: {config['model']}") + if config["api_key"]: + logger.info(f"API key loaded: {config['api_key'][:10]}...") + + return config \ No newline at end of file diff --git a/openspace/grounding/backends/gui/provider.py b/openspace/grounding/backends/gui/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..3f9389e28a620f60ea0b9900b8ec7b380161eb74 --- /dev/null +++ b/openspace/grounding/backends/gui/provider.py @@ -0,0 +1,143 @@ +from typing import Dict, Any, Union +from openspace.grounding.core.types import BackendType, SessionConfig +from openspace.grounding.core.provider import Provider +from openspace.grounding.core.session import BaseSession +from openspace.config import get_config +from openspace.config.utils import get_config_value +from openspace.platforms import get_local_server_config +from openspace.utils.logging import Logger +from .transport.connector import GUIConnector +from .transport.local_connector import LocalGUIConnector +from .session import GUISession + +logger = Logger.get_logger(__name__) + + +class GUIProvider(Provider): + """ + Provider for GUI desktop environment. + Manages communication with desktop_env through HTTP API or local in-process execution. + + Supports two modes: + - "local": Execute GUI operations directly in-process (no server needed) + - "server": Connect to a running local_server via HTTP API + + Supports automatic default session creation: + - If no session exists, a default session will be created on first use + - Default session uses configuration from config file or environment + """ + + DEFAULT_SID = BackendType.GUI.value + + def __init__(self, config: Dict[str, Any] = None): + """ + Initialize GUI provider. + + Args: + config: Provider configuration + """ + super().__init__(BackendType.GUI, config) + self.connectors: Dict[str, Union[GUIConnector, LocalGUIConnector]] = {} + + async def initialize(self) -> None: + """ + Initialize the provider and create default session. + """ + if not self.is_initialized: + logger.info("Initializing GUI provider") + # Auto-create default session + await self.create_session(SessionConfig( + session_name=self.DEFAULT_SID, + backend_type=BackendType.GUI, + connection_params={} + )) + self.is_initialized = True + + async def create_session(self, session_config: SessionConfig) -> BaseSession: + """ + Create GUI session. + + Args: + session_config: Session configuration + + Returns: + GUISession instance + """ + # Load GUI backend configuration + gui_config = get_config().get_backend_config("gui") + + # Determine execution mode: "local" or "server" + mode = getattr(gui_config, "mode", "local") + + # Extract connection parameters + conn_params = session_config.connection_params + timeout = get_config_value(conn_params, 'timeout', gui_config.timeout) + retry_times = get_config_value(conn_params, 'retry_times', gui_config.max_retries) + retry_interval = get_config_value(conn_params, 'retry_interval', gui_config.retry_interval) + + # Build pkgs_prefix with failsafe setting + failsafe_str = "True" if gui_config.failsafe else "False" + pkgs_prefix = get_config_value( + conn_params, + 'pkgs_prefix', + gui_config.pkgs_prefix.format(failsafe=failsafe_str, command="{command}") + ) + + if mode == "local": + # ---------- LOCAL MODE ---------- + logger.info("GUI backend using LOCAL mode (no server required)") + connector = LocalGUIConnector( + timeout=timeout, + retry_times=retry_times, + retry_interval=retry_interval, + pkgs_prefix=pkgs_prefix, + ) + else: + # ---------- SERVER MODE ---------- + logger.info("GUI backend using SERVER mode (connecting to local_server)") + local_server_config = get_local_server_config() + vm_ip = get_config_value(conn_params, 'vm_ip', local_server_config['host']) + server_port = get_config_value(conn_params, 'server_port', local_server_config['port']) + + connector = GUIConnector( + vm_ip=vm_ip, + server_port=server_port, + timeout=timeout, + retry_times=retry_times, + retry_interval=retry_interval, + pkgs_prefix=pkgs_prefix, + ) + + # Create session + session = GUISession( + connector=connector, + session_id=session_config.session_name, + backend_type=BackendType.GUI, + config=session_config, + ) + + # Store connector and session + self.connectors[session_config.session_name] = connector + self._sessions[session_config.session_name] = session + + logger.info(f"Created GUI session: {session_config.session_name} (mode={mode})") + return session + + async def close_session(self, session_name: str) -> None: + """ + Close GUI session. + + Args: + session_name: Name of the session to close + """ + if session_name in self._sessions: + session = self._sessions[session_name] + await session.disconnect() + del self._sessions[session_name] + + if session_name in self.connectors: + connector = self.connectors[session_name] + await connector.disconnect() + del self.connectors[session_name] + + logger.info(f"Closed GUI session: {session_name}") \ No newline at end of file diff --git a/openspace/grounding/backends/gui/session.py b/openspace/grounding/backends/gui/session.py new file mode 100644 index 0000000000000000000000000000000000000000..5815a2d8690646b2d5d2c8b15c3595c273bf4bd7 --- /dev/null +++ b/openspace/grounding/backends/gui/session.py @@ -0,0 +1,188 @@ +from typing import Dict, Any, Union +import os +from openspace.grounding.core.session import BaseSession +from openspace.grounding.core.types import BackendType, SessionStatus, SessionConfig +from openspace.utils.logging import Logger +from .transport.connector import GUIConnector +from .transport.local_connector import LocalGUIConnector +from .tool import GUIAgentTool +from .config import build_llm_config + +logger = Logger.get_logger(__name__) + + +class GUISession(BaseSession): + """ + Session for GUI desktop environment. + Manages connection and tools for GUI automation. + """ + + def __init__( + self, + connector: Union[GUIConnector, LocalGUIConnector], + session_id: str, + backend_type: BackendType.GUI, + config: SessionConfig, + auto_connect: bool = True, + auto_initialize: bool = True, + ): + """ + Initialize GUI session. + + Args: + connector: GUI HTTP connector + session_id: Unique session identifier + backend_type: Backend type (GUI) + config: Session configuration + auto_connect: Auto-connect on context enter + auto_initialize: Auto-initialize on context enter + """ + super().__init__( + connector=connector, + session_id=session_id, + backend_type=backend_type, + auto_connect=auto_connect, + auto_initialize=auto_initialize, + ) + self.config = config + self.gui_connector = connector + + async def initialize(self) -> Dict[str, Any]: + """ + Initialize session: connect and discover tools. + + Returns: + Session information dict + """ + logger.info(f"Initializing GUI session: {self.session_id}") + + # Ensure connected + if not self.connector.is_connected: + await self.connect() + + # Create LLM client if configured + llm_client = None + user_llm_config = self.config.connection_params.get("llm_config") + + # Build complete LLM config with auto-detection + # If user provides llm_config, merge with auto-detected values + # If user doesn't provide llm_config, try to auto-build one if ANTHROPIC_API_KEY exists + if user_llm_config or os.environ.get("ANTHROPIC_API_KEY"): + llm_config = build_llm_config(user_llm_config) + + if llm_config.get("type") == "anthropic": + # Check if API key is available + if not llm_config.get("api_key"): + logger.warning( + "Anthropic API key not found. Skipping LLM client initialization. " + "Set ANTHROPIC_API_KEY environment variable or provide api_key in llm_config." + ) + else: + try: + from .anthropic_client import AnthropicGUIClient + + # Detect actual screen size from screenshot (most accurate) + # PyAutoGUI may report logical resolution, but we need the actual screenshot size + try: + screenshot_bytes = await self.gui_connector.get_screenshot() + if screenshot_bytes: + from PIL import Image + import io + img = Image.open(io.BytesIO(screenshot_bytes)) + actual_screen_size = img.size + logger.info(f"Auto-detected screen size from screenshot: {actual_screen_size}") + screen_size = actual_screen_size + else: + raise RuntimeError("Could not get screenshot") + except Exception as e: + # Fallback to pyautogui detection + actual_screen_size = await self.gui_connector.get_screen_size() + if actual_screen_size: + logger.info(f"Auto-detected screen size from pyautogui: {actual_screen_size}") + screen_size = actual_screen_size + else: + # Final fallback to configured value + screen_size = llm_config.get("screen_size", (1920, 1080)) + logger.warning(f"Could not auto-detect screen size, using configured: {screen_size}") + + # Detect PyAutoGUI working size (logical pixels) + pyautogui_size = await self.gui_connector.get_screen_size() + if pyautogui_size: + logger.info(f"PyAutoGUI working size (logical): {pyautogui_size}") + else: + # If we can't detect PyAutoGUI size, assume it's the same as screen size + pyautogui_size = screen_size + logger.warning(f"Could not detect PyAutoGUI size, assuming same as screen: {pyautogui_size}") + + llm_client = AnthropicGUIClient( + model=llm_config["model"], + platform=llm_config["platform"], + api_key=llm_config["api_key"], + provider=llm_config["provider"], + screen_size=screen_size, + pyautogui_size=pyautogui_size, + max_tokens=llm_config["max_tokens"], + only_n_most_recent_images=llm_config["only_n_most_recent_images"], + ) + logger.info( + f"Initialized Anthropic LLM client - " + f"Model: {llm_config['model']}, Platform: {llm_config['platform']}" + ) + except Exception as e: + logger.warning(f"Failed to initialize Anthropic client: {e}") + + # Get recording_manager from connection_params if available + recording_manager = self.config.connection_params.get("recording_manager") + + # Create GUI Agent Tool + self.tools = [ + GUIAgentTool( + connector=self.gui_connector, + llm_client=llm_client, + recording_manager=recording_manager + ) + ] + + logger.info(f"Initialized GUI session with {len(self.tools)} tool(s)") + + # Return session info + session_info = { + "session_id": self.session_id, + "backend_type": self.backend_type.value, + "vm_ip": self.gui_connector.vm_ip, + "server_port": self.gui_connector.server_port, + "num_tools": len(self.tools), + "tools": [tool.name for tool in self.tools], + "llm_client": "anthropic" if llm_client else "none", + } + + return session_info + + async def connect(self) -> None: + """Connect to GUI desktop environment""" + if self.connector.is_connected: + return + + self.status = SessionStatus.CONNECTING + logger.info(f"Connecting to desktop_env at {self.gui_connector.base_url}") + + await self.connector.connect() + + self.status = SessionStatus.CONNECTED + logger.info("Connected to desktop environment") + + async def disconnect(self) -> None: + """Disconnect from GUI desktop environment""" + if not self.connector.is_connected: + return + + logger.info("Disconnecting from desktop environment") + await self.connector.disconnect() + + self.status = SessionStatus.DISCONNECTED + logger.info("Disconnected from desktop environment") + + @property + def is_connected(self) -> bool: + """Check if session is connected""" + return self.connector.is_connected \ No newline at end of file diff --git a/openspace/grounding/backends/gui/tool.py b/openspace/grounding/backends/gui/tool.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb90efa6d318c6a9c8a9a063320d8ee147a44f5 --- /dev/null +++ b/openspace/grounding/backends/gui/tool.py @@ -0,0 +1,712 @@ +import base64 +from typing import Any, Dict +from openspace.grounding.core.tool.base import BaseTool +from openspace.grounding.core.types import BackendType, ToolResult, ToolStatus +from .transport.connector import GUIConnector +from .transport.actions import ACTION_SPACE, KEYBOARD_KEYS +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +class GUIAgentTool(BaseTool): + """ + LLM-powered GUI Agent Tool. + + This tool acts as an intelligent agent that: + - Takes a task description as input + - Observes the desktop via screenshot + - Uses LLM/VLM to understand and plan actions + - Outputs action space commands + - Executes actions through the connector + """ + + _name = "gui_agent" + _description = """Vision-based GUI automation agent for tasks requiring graphical interface interaction. + + Use this tool when the task involves: + - Operating desktop applications with graphical interfaces (browsers, editors, design tools, etc.) + - Tasks that require visual understanding of UI elements, layouts, or content + - Multi-step workflows that need click, drag, type, or other GUI interactions + - Scenarios where programmatic APIs or command-line tools are unavailable or insufficient + + The agent observes screen state through screenshots, uses vision-language models to understand + the interface, plans appropriate actions, and executes GUI operations autonomously. + + IMPORTANT - max_steps Parameter Guidelines: + - Simple tasks (1-2 actions): 15-20 steps + - Medium tasks (3-5 actions): 25-35 steps + - Complex tasks (6+ actions, like web navigation): 35-50 steps + - When uncertain, prefer larger values (35+) to avoid premature termination + - Default is 25, but increase for multi-step workflows + + Input: + - task_description: Natural language task description + - max_steps: Maximum actions (default 25, increase for complex tasks) + + Output: Task execution results with action history and completion status + """ + + backend_type = BackendType.GUI + + def __init__(self, connector: GUIConnector, llm_client=None, recording_manager=None, **kwargs): + """ + Initialize GUI Agent Tool. + + Args: + connector: GUI connector for communication with desktop_env + llm_client: LLM/VLM client for vision-based planning (optional) + recording_manager: RecordingManager for recording intermediate steps (optional) + **kwargs: Additional arguments for BaseTool + """ + super().__init__(**kwargs) + self.connector = connector + self.llm_client = llm_client # Will be injected later + self.recording_manager = recording_manager # For recording intermediate steps + self.action_history = [] # Track executed actions + + async def _arun( + self, + task_description: str, + max_steps: int = 50, + ) -> ToolResult: + """ + Execute a GUI automation task using LLM planning. + + This is the main entry point that: + 1. Gets current screenshot + 2. Uses LLM to plan next action based on task and screenshot + 3. Executes the planned action + 4. Repeats until task is complete or max_steps reached + + Args: + task_description: Natural language description of the task + max_steps: Maximum number of actions to execute (default 25) + Recommended values based on task complexity: + - Simple (1-2 actions): 15-20 + - Medium (3-5 actions): 25-35 + - Complex (6+ actions, web navigation, multi-app): 35-50 + When in doubt, use higher values to avoid premature termination + + Returns: + ToolResult with task execution status + """ + if not task_description: + return ToolResult( + status=ToolStatus.ERROR, + error="task_description is required" + ) + + logger.info(f"Starting GUI task: {task_description}") + self.action_history = [] + + # Execute task with LLM planning loop + try: + result = await self._execute_task_with_planning( + task_description=task_description, + max_steps=max_steps, + ) + return result + + except Exception as e: + logger.error(f"Task execution failed: {e}") + return ToolResult( + status=ToolStatus.ERROR, + error=str(e), + metadata={ + "task_description": task_description, + "actions_executed": len(self.action_history), + "action_history": self.action_history, + } + ) + + async def _execute_task_with_planning( + self, + task_description: str, + max_steps: int, + ) -> ToolResult: + """ + Execute task with LLM-based planning loop. + + Planning loop: + 1. Observe: Get screenshot + 2. Plan: LLM decides next action + 3. Execute: Perform the action + 4. Verify: Check if task is complete + 5. Repeat until done or max_steps + + Args: + task_description: Task to complete + max_steps: Maximum planning iterations + + Returns: + ToolResult with execution details + """ + # Collect all screenshots for visual analysis + all_screenshots = [] + # Collect intermediate steps + intermediate_steps = [] + + for step in range(max_steps): + logger.info(f"Planning step {step + 1}/{max_steps}") + + # Step 1: Observe current state + screenshot = await self.connector.get_screenshot() + if not screenshot: + return ToolResult( + status=ToolStatus.ERROR, + error="Failed to get screenshot for planning", + metadata={"step": step, "action_history": self.action_history} + ) + + # Collect screenshot for visual analysis + all_screenshots.append(screenshot) + + # Step 2: Plan next action using LLM + planned_action = await self._plan_next_action( + task_description=task_description, + screenshot=screenshot, + action_history=self.action_history, + ) + + # Check if task is complete + if planned_action["action_type"] == "DONE": + logger.info("Task marked as complete by LLM") + reasoning = planned_action.get("reasoning", "Task completed successfully") + + intermediate_steps.append({ + "step_number": step + 1, + "action": "DONE", + "reasoning": reasoning, + "status": "done", + }) + + return ToolResult( + status=ToolStatus.SUCCESS, + content=f"Task completed: {task_description}\n\nFinal state: {reasoning}", + metadata={ + "steps_taken": step + 1, + "action_history": self.action_history, + "screenshots": all_screenshots, + "intermediate_steps": intermediate_steps, + "final_reasoning": reasoning, + } + ) + + # Check if task failed + if planned_action["action_type"] == "FAIL": + logger.warning("Task marked as failed by LLM") + reason = planned_action.get("reason", "Task cannot be completed") + + intermediate_steps.append({ + "step_number": step + 1, + "action": "FAIL", + "reasoning": planned_action.get("reasoning", ""), + "status": "failed", + }) + + return ToolResult( + status=ToolStatus.ERROR, + error=reason, + metadata={ + "steps_taken": step + 1, + "action_history": self.action_history, + "screenshots": all_screenshots, + "intermediate_steps": intermediate_steps, + } + ) + + # Check if action is WAIT (screenshot observation, continue to next step) + if planned_action["action_type"] == "WAIT": + logger.info("Screenshot observation step, continuing planning loop") + intermediate_steps.append({ + "step_number": step + 1, + "action": "WAIT", + "reasoning": planned_action.get("reasoning", ""), + "status": "observation", + }) + continue + + # Step 3: Execute the planned action + execution_result = await self._execute_planned_action(planned_action) + + # Record action in history + self.action_history.append({ + "step": step + 1, + "planned_action": planned_action, + "execution_result": execution_result, + }) + + intermediate_steps.append({ + "step_number": step + 1, + "action": planned_action.get("action_type", "unknown"), + "reasoning": planned_action.get("reasoning", ""), + "status": execution_result.get("status", "unknown"), + }) + + # Check execution result + if execution_result.get("status") != "success": + logger.warning(f"Action execution failed: {execution_result.get('error')}") + # Continue to next iteration for retry planning + + # Max steps reached + return ToolResult( + status=ToolStatus.ERROR, + error=f"Task incomplete after {max_steps} steps", + metadata={ + "task_description": task_description, + "steps_taken": max_steps, + "action_history": self.action_history, + "screenshots": all_screenshots, + "intermediate_steps": intermediate_steps, + } + ) + + async def _plan_next_action( + self, + task_description: str, + screenshot: bytes, + action_history: list, + ) -> Dict[str, Any]: + """ + Use LLM/VLM to plan the next action. + + This method sends: + - Task description + - Current screenshot (vision input) + - Action history (context) + - Available ACTION_SPACE + + And gets back a structured action plan. + + Args: + task_description: The task to accomplish + screenshot: Current desktop screenshot (PNG/JPEG bytes) + action_history: Previously executed actions + + Returns: + Dict with action_type and parameters + """ + if self.llm_client is None: + # Fallback: Simple heuristic or manual mode + logger.warning("No LLM client configured, using fallback mode") + return { + "action_type": "FAIL", + "reason": "LLM client not configured" + } + + # Check if using Anthropic client + try: + from .anthropic_client import AnthropicGUIClient + is_anthropic = isinstance(self.llm_client, AnthropicGUIClient) + except ImportError: + is_anthropic = False + + if is_anthropic: + # Use Anthropic client + try: + reasoning, commands = await self.llm_client.plan_action( + task_description=task_description, + screenshot=screenshot, + action_history=action_history, + ) + + if commands == ["FAIL"]: + return { + "action_type": "FAIL", + "reason": "Anthropic planning failed" + } + + if commands == ["DONE"]: + return { + "action_type": "DONE", + "reasoning": reasoning + } + + if commands == ["SCREENSHOT"]: + # Screenshot is automatically handled by system + # Continue to next planning step + logger.info("LLM requested screenshot (observation step)") + return { + "action_type": "WAIT", + "reasoning": reasoning or "Observing screen state" + } + + # If no commands but has reasoning, task is complete + # (Anthropic returns text-only when task is done) + if not commands and reasoning: + logger.info("LLM returned text-only response, interpreting as task completion") + return { + "action_type": "DONE", + "reasoning": reasoning + } + + # No commands and no reasoning = error + if not commands: + return { + "action_type": "FAIL", + "reason": "No commands generated and no completion message" + } + + # Return first command (Anthropic returns pyautogui commands directly) + return { + "action_type": "PYAUTOGUI_COMMAND", + "command": commands[0], + "reasoning": reasoning + } + + except Exception as e: + logger.error(f"Anthropic planning failed: {e}") + return { + "action_type": "FAIL", + "reason": f"Planning error: {str(e)}" + } + + # Generic LLM client (for future integration with other LLMs) + # Encode screenshot to base64 for LLM + screenshot_b64 = base64.b64encode(screenshot).decode('utf-8') + + # Prepare prompt for LLM + prompt = self._build_planning_prompt( + task_description=task_description, + action_history=action_history, + ) + + # Call LLM with vision input + try: + response = await self.llm_client.plan_action( + prompt=prompt, + image_base64=screenshot_b64, + action_space=ACTION_SPACE, + keyboard_keys=KEYBOARD_KEYS, + ) + + # Parse LLM response to action dict + action = self._parse_llm_response(response) + + logger.info(f"LLM planned action: {action['action_type']}") + return action + + except Exception as e: + logger.error(f"LLM planning failed: {e}") + return { + "action_type": "FAIL", + "reason": f"Planning error: {str(e)}" + } + + def _build_planning_prompt( + self, + task_description: str, + action_history: list, + ) -> str: + """ + Build prompt for LLM planning. + + Args: + task_description: The task to accomplish + action_history: Previously executed actions + + Returns: + Formatted prompt string + """ + prompt = f"""You are a GUI automation agent. Your task is to complete the following: + +Task: {task_description} + +You can observe the current desktop state through the provided screenshot. +You must plan the next action to take from the available ACTION_SPACE. + +Available actions: +- Mouse: MOVE_TO, CLICK, RIGHT_CLICK, DOUBLE_CLICK, DRAG_TO, SCROLL +- Keyboard: TYPING, PRESS, KEY_DOWN, KEY_UP, HOTKEY +- Control: WAIT, DONE, FAIL + +""" + + if action_history: + prompt += f"\nPrevious actions taken ({len(action_history)}):\n" + for i, action in enumerate(action_history[-5:], 1): # Last 5 actions + prompt += f"{i}. {action['planned_action']['action_type']}" + if 'parameters' in action['planned_action']: + prompt += f" - {action['planned_action']['parameters']}" + prompt += "\n" + + prompt += """ +Based on the screenshot and task, output the next action in JSON format: +{ + "action_type": "ACTION_TYPE", + "parameters": {...}, + "reasoning": "Why this action is needed" +} + +If the task is complete, output: {"action_type": "DONE"} +If the task cannot be completed, output: {"action_type": "FAIL", "reason": "explanation"} +""" + + return prompt + + def _parse_llm_response(self, response: str) -> Dict[str, Any]: + """ + Parse LLM response to extract action. + + Args: + response: LLM response (should be JSON) + + Returns: + Action dict with action_type and parameters + """ + import json + + try: + # Try to parse as JSON + action = json.loads(response) + + # Validate action + if "action_type" not in action: + raise ValueError("Missing action_type in LLM response") + + return action + + except json.JSONDecodeError: + logger.error(f"Failed to parse LLM response as JSON: {response[:200]}") + return { + "action_type": "FAIL", + "reason": "Invalid LLM response format" + } + + async def _execute_planned_action( + self, + action: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Execute a planned action through the connector. + + Args: + action: Action dict with action_type and parameters + + Returns: + Execution result dict + """ + action_type = action["action_type"] + + # Handle Anthropic's direct pyautogui commands + if action_type == "PYAUTOGUI_COMMAND": + command = action.get("command", "") + logger.info(f"Executing pyautogui command: {command}") + + try: + result = await self.connector.execute_python_command(command) + return { + "status": "success" if result else "error", + "action_type": action_type, + "command": command, + "result": result + } + except Exception as e: + logger.error(f"Command execution error: {e}") + return { + "status": "error", + "action_type": action_type, + "error": str(e) + } + + # Handle standard action space commands + parameters = action.get("parameters", {}) + logger.info(f"Executing action: {action_type}") + + try: + result = await self.connector.execute_action(action_type, parameters) + return result + + except Exception as e: + logger.error(f"Action execution error: {e}") + return { + "status": "error", + "action_type": action_type, + "error": str(e) + } + + # Helper methods for direct action execution + + async def execute_action( + self, + action_type: str, + parameters: Dict[str, Any] + ) -> ToolResult: + """ + Direct action execution (bypass LLM planning). + + Args: + action_type: Action type from ACTION_SPACE + parameters: Action parameters + + Returns: + ToolResult with execution status + """ + result = await self.connector.execute_action(action_type, parameters) + + if result.get("status") == "success": + return ToolResult( + status=ToolStatus.SUCCESS, + content=f"Executed {action_type}", + metadata=result + ) + else: + return ToolResult( + status=ToolStatus.ERROR, + error=result.get("error", "Unknown error"), + metadata=result + ) + + async def get_screenshot(self) -> ToolResult: + """Get current desktop screenshot.""" + screenshot = await self.connector.get_screenshot() + if screenshot: + return ToolResult( + status=ToolStatus.SUCCESS, + content=screenshot, + metadata={"type": "screenshot", "size": len(screenshot)} + ) + else: + return ToolResult( + status=ToolStatus.ERROR, + error="Failed to capture screenshot" + ) + + async def _record_intermediate_step( + self, + step_number: int, + planned_action: Dict[str, Any], + execution_result: Dict[str, Any], + screenshot: bytes, + task_description: str, + ): + """ + Record an intermediate step of GUI agent execution. + + This method records each planning-action cycle to the recording system, + providing detailed traces of GUI agent's decision-making process. + + Args: + step_number: Step number in the execution sequence + planned_action: Action planned by LLM + execution_result: Result of executing the action + screenshot: Screenshot before executing the action + task_description: Overall task description + """ + # Try to get recording_manager dynamically if not set at initialization + recording_manager = self.recording_manager + if not recording_manager and hasattr(self, '_runtime_info') and self._runtime_info: + # Try to get from grounding_client + grounding_client = self._runtime_info.grounding_client + if grounding_client and hasattr(grounding_client, 'recording_manager'): + recording_manager = grounding_client.recording_manager + logger.debug(f"Step {step_number}: Dynamically retrieved recording_manager from grounding_client") + + if not recording_manager: + logger.debug(f"Step {step_number}: No recording_manager available, skipping intermediate step recording") + return + + # Check if recording is active + try: + from openspace.recording.manager import RecordingManager + if not RecordingManager.is_recording(): + logger.debug(f"Step {step_number}: RecordingManager not started") + return + except Exception as e: + logger.debug(f"Step {step_number}: Failed to check recording status: {e}") + return + + # Check if recorder is initialized + if not hasattr(recording_manager, '_recorder') or not recording_manager._recorder: + logger.warning(f"Step {step_number}: recording_manager._recorder not initialized") + return + + # Build command string for display + action_type = planned_action.get("action_type", "unknown") + command = self._format_action_command(planned_action) + + # Build result summary + status = execution_result.get("status", "unknown") + is_success = status in ("success", "done", "observation") + + # Build result content + if status == "done": + result_content = f"Task completed at step {step_number}" + elif status == "failed": + result_content = execution_result.get("message", "Task failed") + elif status == "observation": + result_content = execution_result.get("message", "Screenshot observation") + else: + result_content = execution_result.get("result", execution_result.get("message", str(execution_result))) + + # Build parameters for recording + parameters = { + "task_description": task_description, + "step_number": step_number, + "action_type": action_type, + "planned_action": planned_action, + } + + # Record to trajectory recorder (handles screenshot saving) + try: + await recording_manager._recorder.record_step( + backend="gui", + tool="gui_agent_step", + command=command, + result={ + "status": "success" if is_success else "error", + "output": str(result_content)[:200], + }, + parameters=parameters, + screenshot=screenshot, + extra={ + "gui_step_number": step_number, + "reasoning": planned_action.get("reasoning", ""), + } + ) + + logger.info(f"✓ Recorded GUI intermediate step {step_number}: {command}") + + except Exception as e: + logger.error(f"✗ Failed to record intermediate step {step_number}: {e}", exc_info=True) + + def _format_action_command(self, planned_action: Dict[str, Any]) -> str: + """ + Format planned action into a human-readable command string. + + Args: + planned_action: Action dictionary from LLM planning + + Returns: + Formatted command string + """ + action_type = planned_action.get("action_type", "unknown") + + # Handle special action types + if action_type == "DONE": + return "DONE (task completed)" + elif action_type == "FAIL": + reason = planned_action.get("reason", "unknown") + return f"FAIL ({reason})" + elif action_type == "WAIT": + return "WAIT (screenshot observation)" + + # Handle PyAutoGUI commands + elif action_type == "PYAUTOGUI_COMMAND": + command = planned_action.get("command", "") + # Truncate long commands + if len(command) > 100: + return command[:100] + "..." + return command + + # Handle standard action space commands + else: + parameters = planned_action.get("parameters", {}) + if parameters: + # Format first 2 parameters + param_items = list(parameters.items())[:2] + param_str = ", ".join([f"{k}={v}" for k, v in param_items]) + return f"{action_type}({param_str})" + else: + return action_type \ No newline at end of file diff --git a/openspace/grounding/backends/gui/transport/actions.py b/openspace/grounding/backends/gui/transport/actions.py new file mode 100644 index 0000000000000000000000000000000000000000..ed9951b51bbe0644946313ff5bfc6f1391e0f365 --- /dev/null +++ b/openspace/grounding/backends/gui/transport/actions.py @@ -0,0 +1,232 @@ +""" +GUI Action Space Definitions. +""" +from typing import Dict, Any + +# Screen resolution constants +X_MAX = 1920 +Y_MAX = 1080 + +# Keyboard keys constants +KEYBOARD_KEYS = [ + '\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', + '[', '\\', ']', '^', '_', '`', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + '{', '|', '}', '~', + 'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace', + 'browserback', 'browserfavorites', 'browserforward', 'browserhome', 'browserrefresh', 'browsersearch', 'browserstop', + 'capslock', 'clear', 'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete', 'divide', + 'down', 'end', 'enter', 'esc', 'escape', 'execute', + 'f1', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', + 'f2', 'f20', 'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', + 'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja', 'kana', 'kanji', + 'launchapp1', 'launchapp2', 'launchmail', 'launchmediaselect', 'left', 'modechange', 'multiply', + 'nexttrack', 'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6', 'num7', 'num8', 'num9', + 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn', 'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', + 'prntscrn', 'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator', + 'shift', 'shiftleft', 'shiftright', 'sleep', 'stop', 'subtract', 'tab', 'up', + 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen', + 'command', 'option', 'optionleft', 'optionright' +] + +# Action Space Definition +ACTION_SPACE = [ + { + "action_type": "MOVE_TO", + "note": "move the cursor to the specified position", + "parameters": { + "x": {"type": float, "range": [0, X_MAX], "optional": False}, + "y": {"type": float, "range": [0, Y_MAX], "optional": False}, + } + }, + { + "action_type": "CLICK", + "note": "click the left button if button not specified, otherwise click the specified button", + "parameters": { + "button": {"type": str, "range": ["left", "right", "middle"], "optional": True}, + "x": {"type": float, "range": [0, X_MAX], "optional": True}, + "y": {"type": float, "range": [0, Y_MAX], "optional": True}, + "num_clicks": {"type": int, "range": [1, 2, 3], "optional": True}, + } + }, + { + "action_type": "MOUSE_DOWN", + "note": "press the mouse button", + "parameters": { + "button": {"type": str, "range": ["left", "right", "middle"], "optional": True} + } + }, + { + "action_type": "MOUSE_UP", + "note": "release the mouse button", + "parameters": { + "button": {"type": str, "range": ["left", "right", "middle"], "optional": True} + } + }, + { + "action_type": "RIGHT_CLICK", + "note": "right click at position", + "parameters": { + "x": {"type": float, "range": [0, X_MAX], "optional": True}, + "y": {"type": float, "range": [0, Y_MAX], "optional": True} + } + }, + { + "action_type": "DOUBLE_CLICK", + "note": "double click at position", + "parameters": { + "x": {"type": float, "range": [0, X_MAX], "optional": True}, + "y": {"type": float, "range": [0, Y_MAX], "optional": True} + } + }, + { + "action_type": "DRAG_TO", + "note": "drag the cursor to position", + "parameters": { + "x": {"type": float, "range": [0, X_MAX], "optional": False}, + "y": {"type": float, "range": [0, Y_MAX], "optional": False} + } + }, + { + "action_type": "SCROLL", + "note": "scroll the mouse wheel", + "parameters": { + "dx": {"type": int, "range": None, "optional": False}, + "dy": {"type": int, "range": None, "optional": False} + } + }, + { + "action_type": "TYPING", + "note": "type the specified text", + "parameters": { + "text": {"type": str, "range": None, "optional": False} + } + }, + { + "action_type": "PRESS", + "note": "press the specified key", + "parameters": { + "key": {"type": str, "range": KEYBOARD_KEYS, "optional": False} + } + }, + { + "action_type": "KEY_DOWN", + "note": "press down the specified key", + "parameters": { + "key": {"type": str, "range": KEYBOARD_KEYS, "optional": False} + } + }, + { + "action_type": "KEY_UP", + "note": "release the specified key", + "parameters": { + "key": {"type": str, "range": KEYBOARD_KEYS, "optional": False} + } + }, + { + "action_type": "HOTKEY", + "note": "press key combination", + "parameters": { + "keys": {"type": list, "range": [KEYBOARD_KEYS], "optional": False} + } + }, + { + "action_type": "WAIT", + "note": "wait until next action", + }, + { + "action_type": "FAIL", + "note": "mark task as failed", + }, + { + "action_type": "DONE", + "note": "mark task as done", + } +] + + +def build_pyautogui_command(action_type: str, parameters: Dict[str, Any]) -> str: + """ + Build pyautogui command from action type and parameters. + + Args: + action_type: Type of action (e.g., 'CLICK', 'TYPING') + parameters: Action parameters + + Returns: + Python command string + """ + if action_type == "MOVE_TO": + if "x" in parameters and "y" in parameters: + x, y = parameters["x"], parameters["y"] + return f"pyautogui.moveTo({x}, {y}, 0.5, pyautogui.easeInQuad)" + else: + return "pyautogui.moveTo()" + + elif action_type == "CLICK": + button = parameters.get("button", "left") + num_clicks = parameters.get("num_clicks", 1) + + if "x" in parameters and "y" in parameters: + x, y = parameters["x"], parameters["y"] + return f"pyautogui.click(button='{button}', x={x}, y={y}, clicks={num_clicks})" + else: + return f"pyautogui.click(button='{button}', clicks={num_clicks})" + + elif action_type == "MOUSE_DOWN": + button = parameters.get("button", "left") + return f"pyautogui.mouseDown(button='{button}')" + + elif action_type == "MOUSE_UP": + button = parameters.get("button", "left") + return f"pyautogui.mouseUp(button='{button}')" + + elif action_type == "RIGHT_CLICK": + if "x" in parameters and "y" in parameters: + x, y = parameters["x"], parameters["y"] + return f"pyautogui.rightClick(x={x}, y={y})" + else: + return "pyautogui.rightClick()" + + elif action_type == "DOUBLE_CLICK": + if "x" in parameters and "y" in parameters: + x, y = parameters["x"], parameters["y"] + return f"pyautogui.doubleClick(x={x}, y={y})" + else: + return "pyautogui.doubleClick()" + + elif action_type == "DRAG_TO": + if "x" in parameters and "y" in parameters: + x, y = parameters["x"], parameters["y"] + return f"pyautogui.dragTo({x}, {y}, 1.0, button='left')" + + elif action_type == "SCROLL": + dx = parameters.get("dx", 0) + dy = parameters.get("dy", 0) + return f"pyautogui.scroll({dy})" + + elif action_type == "TYPING": + text = parameters.get("text", "") + # Use repr() for proper string escaping + return f"pyautogui.typewrite({repr(text)})" + + elif action_type == "PRESS": + key = parameters.get("key", "") + return f"pyautogui.press('{key}')" + + elif action_type == "KEY_DOWN": + key = parameters.get("key", "") + return f"pyautogui.keyDown('{key}')" + + elif action_type == "KEY_UP": + key = parameters.get("key", "") + return f"pyautogui.keyUp('{key}')" + + elif action_type == "HOTKEY": + keys = parameters.get("keys", []) + if keys: + keys_str = ", ".join([f"'{k}'" for k in keys]) + return f"pyautogui.hotkey({keys_str})" + + return None \ No newline at end of file diff --git a/openspace/grounding/backends/gui/transport/connector.py b/openspace/grounding/backends/gui/transport/connector.py new file mode 100644 index 0000000000000000000000000000000000000000..8a85fb51afaa7361f423b32376fb0ce40ab025bb --- /dev/null +++ b/openspace/grounding/backends/gui/transport/connector.py @@ -0,0 +1,389 @@ +import asyncio +import re +from typing import Any, Dict, Optional +from openspace.grounding.core.transport.connectors import AioHttpConnector +from .actions import build_pyautogui_command, KEYBOARD_KEYS +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +class GUIConnector(AioHttpConnector): + """ + Connector for desktop_env HTTP API. + Provides action execution and observation methods. + """ + + def __init__( + self, + vm_ip: str, + server_port: int = 5000, + timeout: int = 90, + retry_times: int = 3, + retry_interval: float = 5.0, + pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}", + ): + """ + Initialize GUI connector. + + Args: + vm_ip: IP address of the VM running desktop_env + server_port: Port of the desktop_env HTTP server + timeout: Request timeout in seconds + retry_times: Number of retries for failed requests + retry_interval: Interval between retries in seconds + pkgs_prefix: Python command prefix for pyautogui setup + """ + base_url = f"http://{vm_ip}:{server_port}" + super().__init__(base_url, timeout=timeout) + + self.vm_ip = vm_ip + self.server_port = server_port + self.retry_times = retry_times + self.retry_interval = retry_interval + self.pkgs_prefix = pkgs_prefix + self.timeout = timeout + + async def _retry_invoke( + self, + operation_name: str, + operation_func, + *args, + **kwargs + ): + """ + Execute operation with retry logic. + + Args: + operation_name: Name of operation for logging + operation_func: Async function to execute + *args: Positional arguments for operation_func + **kwargs: Keyword arguments for operation_func + + Returns: + Operation result + + Raises: + Exception: Last exception after all retries fail + """ + last_exc: Exception | None = None + + for attempt in range(1, self.retry_times + 1): + try: + result = await operation_func(*args, **kwargs) + logger.debug("%s executed successfully (attempt %d/%d)", operation_name, attempt, self.retry_times) + return result + except asyncio.TimeoutError as exc: + logger.error("%s timed out", operation_name) + raise RuntimeError(f"{operation_name} timed out after {self.timeout} seconds") from exc + except Exception as exc: + last_exc = exc + if attempt == self.retry_times: + break + logger.warning( + "%s failed (attempt %d/%d): %s, retrying in %.1f seconds...", + operation_name, attempt, self.retry_times, exc, self.retry_interval + ) + await asyncio.sleep(self.retry_interval) + + error_msg = f"{operation_name} failed after {self.retry_times} retries" + logger.error(error_msg) + raise last_exc or RuntimeError(error_msg) + + @staticmethod + def _is_valid_image_response(content_type: str, data: Optional[bytes]) -> bool: + """Validate image response using magic bytes.""" + if not isinstance(data, (bytes, bytearray)) or not data: + return False + # PNG magic + if len(data) >= 8 and data[:8] == b"\x89PNG\r\n\x1a\n": + return True + # JPEG magic + if len(data) >= 3 and data[:3] == b"\xff\xd8\xff": + return True + # Fallback to content-type + if content_type and ("image/png" in content_type or "image/jpeg" in content_type): + return True + return False + + @staticmethod + def _fix_pyautogui_less_than_bug(command: str) -> str: + """ + Fix PyAutoGUI '<' character bug by converting it to hotkey("shift", ',') calls. + + This fixes the known PyAutoGUI issue where typing '<' produces '>' instead. + References: + - https://github.com/asweigart/pyautogui/issues/198 + - https://github.com/xlang-ai/OSWorld/issues/257 + + Args: + command (str): The original pyautogui command + + Returns: + str: The fixed command with '<' characters handled properly + """ + # Pattern to match press('<') or press('\u003c') calls + press_pattern = r'pyautogui\.press\(["\'](?:<|\\u003c)["\']\)' + + # Handle press('<') calls + def replace_press_less_than(match): + return 'pyautogui.hotkey("shift", ",")' + + # First handle press('<') calls + command = re.sub(press_pattern, replace_press_less_than, command) + + # Pattern to match typewrite calls with quoted strings + typewrite_pattern = r'pyautogui\.typewrite\((["\'])(.*?)\1\)' + + # Then handle typewrite calls + def process_typewrite_match(match): + quote_char = match.group(1) + content = match.group(2) + + # Preprocess: Try to decode Unicode escapes like \u003c to actual '<' + # This handles cases where '<' is represented as escaped Unicode + try: + # Attempt to decode unicode escapes + decoded_content = content.encode('utf-8').decode('unicode_escape') + content = decoded_content + except UnicodeDecodeError: + # If decoding fails, proceed with original content to avoid breaking existing logic + pass # Graceful degradation - fall back to original content if decoding fails + + # Check if content contains '<' + if '<' not in content: + return match.group(0) + + # Split by '<' and rebuild + parts = content.split('<') + result_parts = [] + + for i, part in enumerate(parts): + if i == 0: + # First part + if part: + result_parts.append(f"pyautogui.typewrite({quote_char}{part}{quote_char})") + else: + # Add hotkey for '<' and then typewrite for the rest + result_parts.append('pyautogui.hotkey("shift", ",")') + if part: + result_parts.append(f"pyautogui.typewrite({quote_char}{part}{quote_char})") + + return '; '.join(result_parts) + + command = re.sub(typewrite_pattern, process_typewrite_match, command) + + return command + + async def get_screen_size(self) -> Optional[tuple[int, int]]: + """ + Get actual screen size from desktop environment using pyautogui. + + Returns: + (width, height) tuple, or None on failure + """ + try: + command = "print(pyautogui.size())" + result = await self.execute_python_command(command) + if result and result.get("status") == "success": + output = result.get("output", "") + # Parse output like "Size(width=2880, height=1800)" + import re + match = re.search(r'width=(\d+).*height=(\d+)', output) + if match: + width = int(match.group(1)) + height = int(match.group(2)) + logger.info(f"Detected screen size: {width}x{height}") + return (width, height) + logger.warning(f"Failed to detect screen size, output: {result}") + return None + except Exception as e: + logger.error(f"Failed to get screen size: {e}") + return None + + async def get_screenshot(self) -> Optional[bytes]: + """ + Get screenshot from desktop environment. + + Returns: + Screenshot image bytes (PNG/JPEG), or None on failure + """ + try: + async def _get(): + response = await self._request("GET", "/screenshot", timeout=10) + if response.status == 200: + content_type = response.headers.get("Content-Type", "") + content = await response.read() + if self._is_valid_image_response(content_type, content): + return content + else: + raise ValueError("Invalid screenshot format") + else: + raise RuntimeError(f"HTTP {response.status}") + + return await self._retry_invoke("get_screenshot", _get) + except Exception as e: + logger.error(f"Failed to get screenshot: {e}") + return None + + async def execute_python_command(self, command: str) -> Optional[Dict[str, Any]]: + """ + Execute a Python command on desktop environment. + Used for pyautogui commands. + + Args: + command: Python command to execute + + Returns: + Response dict with execution result, or None on failure + """ + try: + # Apply '<' character fix for PyAutoGUI bug + fixed_command = self._fix_pyautogui_less_than_bug(command) + + command_list = ["python", "-c", self.pkgs_prefix.format(command=fixed_command)] + payload = {"command": command_list, "shell": False} + + async def _execute(): + return await self.post_json("/execute", payload) + + return await self._retry_invoke("execute_python_command", _execute) + except Exception as e: + logger.error(f"Failed to execute command: {e}") + return None + + async def execute_action(self, action_type: str, parameters: Dict[str, Any] = None) -> Dict[str, Any]: + """ + Execute a desktop action. + This is the main method for action space execution. + + Args: + action_type: Action type (e.g., 'CLICK', 'TYPING') + parameters: Action parameters + + Returns: + Result dict with execution status + """ + parameters = parameters or {} + + # Handle control actions + if action_type in ['WAIT', 'FAIL', 'DONE']: + return { + "status": "success", + "action_type": action_type, + "message": f"Control action {action_type} acknowledged" + } + + # Validate keyboard keys + if action_type in ['PRESS', 'KEY_DOWN', 'KEY_UP']: + key = parameters.get('key') + if key and key not in KEYBOARD_KEYS: + return { + "status": "error", + "action_type": action_type, + "error": f"Invalid key: {key}. Must be in supported keyboard keys." + } + + if action_type == 'HOTKEY': + keys = parameters.get('keys', []) + invalid_keys = [k for k in keys if k not in KEYBOARD_KEYS] + if invalid_keys: + return { + "status": "error", + "action_type": action_type, + "error": f"Invalid keys: {invalid_keys}" + } + + # Build pyautogui command + command = build_pyautogui_command(action_type, parameters) + + if command is None: + return { + "status": "error", + "action_type": action_type, + "error": f"Unsupported action type: {action_type}" + } + + # Execute command + result = await self.execute_python_command(command) + + if result: + return { + "status": "success", + "action_type": action_type, + "parameters": parameters, + "result": result + } + else: + return { + "status": "error", + "action_type": action_type, + "parameters": parameters, + "error": "Command execution failed" + } + + async def get_accessibility_tree(self, max_depth: int = 5) -> Optional[Dict[str, Any]]: + """ + Get accessibility tree from desktop environment. + + Args: + max_depth: Maximum depth of accessibility tree traversal + + Returns: + Accessibility tree as dict, or None on failure + """ + try: + async def _get(): + response = await self._request("GET", "/accessibility", timeout=10) + if response.status == 200: + data = await response.json() + return data.get("AT") + else: + raise RuntimeError(f"HTTP {response.status}") + + return await self._retry_invoke("get_accessibility_tree", _get) + except Exception as e: + logger.error(f"Failed to get accessibility tree: {e}") + return None + + async def get_cursor_position(self) -> Optional[tuple[int, int]]: + """ + Get current mouse cursor position. + Useful for GUI debugging and relative positioning. + + Returns: + (x, y) tuple, or None on failure + """ + try: + async def _get(): + result = await self.get_json("/cursor_position") + return (result.get("x"), result.get("y")) + + return await self._retry_invoke("get_cursor_position", _get) + except Exception as e: + logger.error(f"Failed to get cursor position: {e}") + return None + + async def invoke(self, name: str, params: dict[str, Any]) -> Any: + """ + Unified RPC entry for operations. + Required by BaseConnector. + + Args: + name: Operation name (action_type or observation method) + params: Operation parameters + + Returns: + Operation result + """ + # Handle observation methods + if name == "screenshot": + return await self.get_screenshot() + elif name == "accessibility_tree": + max_depth = params.get("max_depth", 5) if params else 5 + return await self.get_accessibility_tree(max_depth) + elif name == "cursor_position": + return await self.get_cursor_position() + else: + # Treat as action + return await self.execute_action(name.upper(), params or {}) \ No newline at end of file diff --git a/openspace/grounding/backends/gui/transport/local_connector.py b/openspace/grounding/backends/gui/transport/local_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..27d2ae38c7e1cb6b091ba6b1a92e553f5adb2aec --- /dev/null +++ b/openspace/grounding/backends/gui/transport/local_connector.py @@ -0,0 +1,364 @@ +""" +Local GUI Connector — execute GUI operations directly in-process. + +This connector has the **same public API** as GUIConnector (HTTP version) +but uses local pyautogui / ScreenshotHelper / AccessibilityHelper, +removing the need for a local_server. + +Return format is kept identical so that GUISession / GUIAgentTool +work without any changes. +""" + +import asyncio +import os +import platform +import re +import tempfile +import uuid +from typing import Any, Dict, Optional + +from openspace.grounding.core.transport.connectors.base import BaseConnector +from openspace.grounding.core.transport.task_managers.noop import NoOpConnectionManager +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +platform_name = platform.system() + + +class LocalGUIConnector(BaseConnector[Any]): + """ + GUI connector that runs desktop automation **locally** using pyautogui / + ScreenshotHelper / AccessibilityHelper, bypassing the Flask local_server. + + Public API is compatible with ``GUIConnector`` so that ``GUISession`` + works without modification. + """ + + def __init__( + self, + timeout: int = 90, + retry_times: int = 3, + retry_interval: float = 5.0, + pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}", + ): + super().__init__(NoOpConnectionManager()) + self.timeout = timeout + self.retry_times = retry_times + self.retry_interval = retry_interval + self.pkgs_prefix = pkgs_prefix + + # Compatibility attributes expected by GUISession + self.vm_ip = "localhost" + self.server_port = 0 + self.base_url = "local://localhost" + + # Lazy-initialized helpers (avoid import side effects at class load) + self._screenshot_helper = None + self._accessibility_helper = None + + def _get_screenshot_helper(self): + if self._screenshot_helper is None: + from openspace.local_server.utils import ScreenshotHelper + self._screenshot_helper = ScreenshotHelper() + return self._screenshot_helper + + def _get_accessibility_helper(self): + if self._accessibility_helper is None: + from openspace.local_server.utils import AccessibilityHelper + self._accessibility_helper = AccessibilityHelper() + return self._accessibility_helper + + # ------------------------------------------------------------------ + # connect / disconnect + # ------------------------------------------------------------------ + + async def connect(self) -> None: + """No real connection for local mode.""" + if self._connected: + return + await super().connect() + logger.info("LocalGUIConnector: ready (local mode, no server required)") + + # ------------------------------------------------------------------ + # Retry wrapper (same interface as GUIConnector._retry_invoke) + # ------------------------------------------------------------------ + + async def _retry_invoke( + self, + operation_name: str, + operation_func, + *args, + **kwargs, + ): + last_exc: Exception | None = None + for attempt in range(1, self.retry_times + 1): + try: + result = await operation_func(*args, **kwargs) + logger.debug( + "%s executed successfully (attempt %d/%d)", + operation_name, attempt, self.retry_times, + ) + return result + except asyncio.TimeoutError as exc: + logger.error("%s timed out", operation_name) + raise RuntimeError( + f"{operation_name} timed out after {self.timeout} seconds" + ) from exc + except Exception as exc: + last_exc = exc + if attempt == self.retry_times: + break + logger.warning( + "%s failed (attempt %d/%d): %s, retrying in %.1f seconds...", + operation_name, attempt, self.retry_times, exc, self.retry_interval, + ) + await asyncio.sleep(self.retry_interval) + + error_msg = f"{operation_name} failed after {self.retry_times} retries" + logger.error(error_msg) + raise last_exc or RuntimeError(error_msg) + + # ------------------------------------------------------------------ + # PyAutoGUI '<' bug fix (same as GUIConnector) + # ------------------------------------------------------------------ + + @staticmethod + def _fix_pyautogui_less_than_bug(command: str) -> str: + """Fix PyAutoGUI '<' character bug.""" + press_pattern = r'pyautogui\.press\(["\'](?:<|\\u003c)["\']\)' + + def replace_press_less_than(match): + return 'pyautogui.hotkey("shift", ",")' + + command = re.sub(press_pattern, replace_press_less_than, command) + + typewrite_pattern = r'pyautogui\.typewrite\((["\'])(.*?)\1\)' + + def process_typewrite_match(match): + quote_char = match.group(1) + content = match.group(2) + try: + decoded_content = content.encode("utf-8").decode("unicode_escape") + content = decoded_content + except UnicodeDecodeError: + pass + if "<" not in content: + return match.group(0) + parts = content.split("<") + result_parts = [] + for i, part in enumerate(parts): + if i == 0: + if part: + result_parts.append( + f"pyautogui.typewrite({quote_char}{part}{quote_char})" + ) + else: + result_parts.append('pyautogui.hotkey("shift", ",")') + if part: + result_parts.append( + f"pyautogui.typewrite({quote_char}{part}{quote_char})" + ) + return "; ".join(result_parts) + + command = re.sub(typewrite_pattern, process_typewrite_match, command) + return command + + # ------------------------------------------------------------------ + # Image response validation (same as GUIConnector) + # ------------------------------------------------------------------ + + @staticmethod + def _is_valid_image_response(content_type: str, data: Optional[bytes]) -> bool: + if not isinstance(data, (bytes, bytearray)) or not data: + return False + if len(data) >= 8 and data[:8] == b"\x89PNG\r\n\x1a\n": + return True + if len(data) >= 3 and data[:3] == b"\xff\xd8\xff": + return True + if content_type and ("image/png" in content_type or "image/jpeg" in content_type): + return True + return False + + # ------------------------------------------------------------------ + # Public API (same signatures as GUIConnector) + # ------------------------------------------------------------------ + + async def get_screen_size(self) -> Optional[tuple[int, int]]: + """Get screen size using pyautogui.""" + try: + command = "print(pyautogui.size())" + result = await self.execute_python_command(command) + if result and result.get("status") == "success": + output = result.get("output", "") + match = re.search(r"width=(\d+).*height=(\d+)", output) + if match: + width = int(match.group(1)) + height = int(match.group(2)) + logger.info("Detected screen size: %dx%d", width, height) + return (width, height) + logger.warning("Failed to detect screen size, output: %s", result) + return None + except Exception as e: + logger.error("Failed to get screen size: %s", e) + return None + + async def get_screenshot(self) -> Optional[bytes]: + """Capture screenshot locally using ScreenshotHelper.""" + try: + async def _get(): + helper = self._get_screenshot_helper() + tmp_path = os.path.join( + tempfile.gettempdir(), f"screenshot_{uuid.uuid4().hex}.png" + ) + if helper.capture(tmp_path, with_cursor=True): + with open(tmp_path, "rb") as f: + data = f.read() + os.remove(tmp_path) + return data + else: + raise RuntimeError("Screenshot capture failed") + + return await self._retry_invoke("get_screenshot", _get) + except Exception as e: + logger.error("Failed to get screenshot: %s", e) + return None + + async def execute_python_command(self, command: str) -> Optional[Dict[str, Any]]: + """Execute a pyautogui Python command locally via subprocess.""" + try: + fixed_command = self._fix_pyautogui_less_than_bug(command) + full_command = self.pkgs_prefix.format(command=fixed_command) + + async def _execute(): + python_cmd = "python" if platform_name == "Windows" else "python3" + proc = await asyncio.create_subprocess_exec( + python_cmd, "-c", full_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout_b, stderr_b = await asyncio.wait_for( + proc.communicate(), timeout=self.timeout + ) + stdout = stdout_b.decode("utf-8", errors="replace") if stdout_b else "" + stderr = stderr_b.decode("utf-8", errors="replace") if stderr_b else "" + returncode = proc.returncode or 0 + return { + "status": "success" if returncode == 0 else "error", + "output": stdout + stderr, + "error": stderr if returncode != 0 else "", + "returncode": returncode, + } + + return await self._retry_invoke("execute_python_command", _execute) + except Exception as e: + logger.error("Failed to execute command: %s", e) + return None + + async def execute_action( + self, action_type: str, parameters: Dict[str, Any] | None = None + ) -> Dict[str, Any]: + """Execute a desktop action (same logic as GUIConnector).""" + parameters = parameters or {} + + if action_type in ["WAIT", "FAIL", "DONE"]: + return { + "status": "success", + "action_type": action_type, + "message": f"Control action {action_type} acknowledged", + } + + # Import action builder (same module used by GUIConnector) + from openspace.grounding.backends.gui.transport.actions import ( + build_pyautogui_command, + KEYBOARD_KEYS, + ) + + if action_type in ["PRESS", "KEY_DOWN", "KEY_UP"]: + key = parameters.get("key") + if key and key not in KEYBOARD_KEYS: + return { + "status": "error", + "action_type": action_type, + "error": f"Invalid key: {key}. Must be in supported keyboard keys.", + } + if action_type == "HOTKEY": + keys = parameters.get("keys", []) + invalid_keys = [k for k in keys if k not in KEYBOARD_KEYS] + if invalid_keys: + return { + "status": "error", + "action_type": action_type, + "error": f"Invalid keys: {invalid_keys}", + } + + command = build_pyautogui_command(action_type, parameters) + if command is None: + return { + "status": "error", + "action_type": action_type, + "error": f"Unsupported action type: {action_type}", + } + + result = await self.execute_python_command(command) + if result: + return { + "status": "success", + "action_type": action_type, + "parameters": parameters, + "result": result, + } + else: + return { + "status": "error", + "action_type": action_type, + "parameters": parameters, + "error": "Command execution failed", + } + + async def get_accessibility_tree( + self, max_depth: int = 5 + ) -> Optional[Dict[str, Any]]: + """Get accessibility tree locally.""" + try: + async def _get(): + helper = self._get_accessibility_helper() + return helper.get_tree(max_depth=max_depth) + + return await self._retry_invoke("get_accessibility_tree", _get) + except Exception as e: + logger.error("Failed to get accessibility tree: %s", e) + return None + + async def get_cursor_position(self) -> Optional[tuple[int, int]]: + """Get cursor position locally.""" + try: + async def _get(): + helper = self._get_screenshot_helper() + return helper.get_cursor_position() + + return await self._retry_invoke("get_cursor_position", _get) + except Exception as e: + logger.error("Failed to get cursor position: %s", e) + return None + + # ------------------------------------------------------------------ + # BaseConnector abstract methods + # ------------------------------------------------------------------ + + async def invoke(self, name: str, params: dict[str, Any]) -> Any: + if name == "screenshot": + return await self.get_screenshot() + elif name == "accessibility_tree": + max_depth = params.get("max_depth", 5) if params else 5 + return await self.get_accessibility_tree(max_depth) + elif name == "cursor_position": + return await self.get_cursor_position() + else: + return await self.execute_action(name.upper(), params or {}) + + async def request(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError( + "LocalGUIConnector does not support raw HTTP requests" + ) + diff --git a/openspace/grounding/backends/mcp/__init__.py b/openspace/grounding/backends/mcp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2c71a6c6ff75d931fdc9a382da22ccd9cbec11c --- /dev/null +++ b/openspace/grounding/backends/mcp/__init__.py @@ -0,0 +1,41 @@ +""" +MCP Backend for OpenSpace Grounding. + +This module provides the MCP (Model Context Protocol) backend implementation +for the grounding framework. It includes: + +- MCPProvider: Manages multiple MCP server sessions +- MCPSession: Handles individual MCP server connections +- MCPClient: High-level client for MCP server configuration +- MCPInstallerManager: Manages automatic installation of MCP dependencies +- MCPToolCache: Caches tool metadata to avoid starting servers on list_tools +""" + +from .provider import MCPProvider +from .session import MCPSession +from .client import MCPClient +from .installer import ( + MCPInstallerManager, + get_global_installer, + set_global_installer, + MCPDependencyError, + MCPCommandNotFoundError, + MCPInstallationCancelledError, + MCPInstallationFailedError, +) +from .tool_cache import MCPToolCache, get_tool_cache + +__all__ = [ + "MCPProvider", + "MCPSession", + "MCPClient", + "MCPInstallerManager", + "get_global_installer", + "set_global_installer", + "MCPDependencyError", + "MCPCommandNotFoundError", + "MCPInstallationCancelledError", + "MCPInstallationFailedError", + "MCPToolCache", + "get_tool_cache", +] \ No newline at end of file diff --git a/openspace/grounding/backends/mcp/client.py b/openspace/grounding/backends/mcp/client.py new file mode 100644 index 0000000000000000000000000000000000000000..a6833e20d5c9c4083c1282dcee5fb9b9c018f0e2 --- /dev/null +++ b/openspace/grounding/backends/mcp/client.py @@ -0,0 +1,409 @@ +""" +Client for managing MCP servers and sessions. + +This module provides a high-level client that manages MCP servers, connectors, +and sessions from configuration. +""" +import asyncio +import warnings +from typing import Any, Optional + +from openspace.grounding.core.types import SandboxOptions +from openspace.config.utils import get_config_value, save_json_file, load_json_file +from .config import create_connector_from_config +from .session import MCPSession +from .installer import MCPInstallerManager, MCPDependencyError + +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +class MCPClient: + """Client for managing MCP servers and sessions. + + This class provides a unified interface for working with MCP servers, + handling configuration, connector creation, and session management. + """ + + def __init__( + self, + config: str | dict[str, Any] | None = None, + sandbox: bool = False, + sandbox_options: SandboxOptions | None = None, + timeout: float = 30.0, + sse_read_timeout: float = 300.0, + max_retries: int = 3, + retry_interval: float = 2.0, + installer: Optional[MCPInstallerManager] = None, + check_dependencies: bool = True, + tool_call_max_retries: int = 3, + tool_call_retry_delay: float = 1.0, + ) -> None: + """Initialize a new MCP client. + + Args: + config: Either a dict containing configuration or a path to a JSON config file. + If None, an empty configuration is used. + sandbox: Whether to use sandboxed execution mode for running MCP servers. + sandbox_options: Optional sandbox configuration options. + timeout: Timeout for operations in seconds (default: 30.0) + sse_read_timeout: SSE read timeout in seconds (default: 300.0) + max_retries: Maximum number of retry attempts for failed operations (default: 3) + retry_interval: Wait time between retries in seconds (default: 2.0) + installer: Optional installer manager for dependency installation + check_dependencies: Whether to check and install dependencies (default: True) + tool_call_max_retries: Maximum number of retries for tool calls (default: 3) + tool_call_retry_delay: Initial delay between tool call retries in seconds (default: 1.0) + """ + self.config: dict[str, Any] = {} + self.sandbox = sandbox + self.sandbox_options = sandbox_options + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout + self.max_retries = max_retries + self.retry_interval = retry_interval + self.installer = installer + self.check_dependencies = check_dependencies + self.tool_call_max_retries = tool_call_max_retries + self.tool_call_retry_delay = tool_call_retry_delay + self.sessions: dict[str, MCPSession] = {} + self.active_sessions: list[str] = [] + + # Load configuration if provided + if config is not None: + if isinstance(config, str): + self.config = load_json_file(config) + else: + self.config = config + + def _get_mcp_servers(self) -> dict[str, Any]: + """Internal helper to get mcpServers configuration. + + Tries both 'mcpServers' and 'servers' keys for compatibility. + + Returns: + Dictionary of MCP server configurations, empty dict if none found. + """ + servers = get_config_value(self.config, "mcpServers", None) + if servers is None: + servers = get_config_value(self.config, "servers", {}) + return servers or {} + + @classmethod + def from_dict( + cls, + config: dict[str, Any], + sandbox: bool = False, + sandbox_options: SandboxOptions | None = None, + timeout: float = 30.0, + sse_read_timeout: float = 300.0, + max_retries: int = 3, + retry_interval: float = 2.0, + ) -> "MCPClient": + """Create a MCPClient from a dictionary. + + Args: + config: The configuration dictionary. + sandbox: Whether to use sandboxed execution mode for running MCP servers. + sandbox_options: Optional sandbox configuration options. + timeout: Timeout for operations in seconds (default: 30.0) + sse_read_timeout: SSE read timeout in seconds (default: 300.0) + max_retries: Maximum number of retry attempts (default: 3) + retry_interval: Wait time between retries in seconds (default: 2.0) + """ + return cls(config=config, sandbox=sandbox, sandbox_options=sandbox_options, + timeout=timeout, sse_read_timeout=sse_read_timeout, + max_retries=max_retries, retry_interval=retry_interval) + + @classmethod + def from_config_file( + cls, filepath: str, sandbox: bool = False, sandbox_options: SandboxOptions | None = None, + timeout: float = 30.0, sse_read_timeout: float = 300.0, + max_retries: int = 3, retry_interval: float = 2.0, + ) -> "MCPClient": + """Create a MCPClient from a configuration file. + + Args: + filepath: The path to the configuration file. + sandbox: Whether to use sandboxed execution mode for running MCP servers. + sandbox_options: Optional sandbox configuration options. + timeout: Timeout for operations in seconds (default: 30.0) + sse_read_timeout: SSE read timeout in seconds (default: 300.0) + max_retries: Maximum number of retry attempts (default: 3) + retry_interval: Wait time between retries in seconds (default: 2.0) + """ + return cls(config=load_json_file(filepath), sandbox=sandbox, sandbox_options=sandbox_options, + timeout=timeout, sse_read_timeout=sse_read_timeout, + max_retries=max_retries, retry_interval=retry_interval) + + def add_server( + self, + name: str, + server_config: dict[str, Any], + ) -> None: + """Add a server configuration. + + Args: + name: The name to identify this server. + server_config: The server configuration. + """ + mcp_servers = self._get_mcp_servers() + if "mcpServers" not in self.config: + self.config["mcpServers"] = {} + + self.config["mcpServers"][name] = server_config + logger.debug(f"Added MCP server configuration: {name}") + + def remove_server(self, name: str) -> None: + """Remove a server configuration. + + Args: + name: The name of the server to remove. + """ + mcp_servers = self._get_mcp_servers() + if name in mcp_servers: + # Remove from config + if "mcpServers" in self.config: + self.config["mcpServers"].pop(name, None) + elif "servers" in self.config: + self.config["servers"].pop(name, None) + + # If we removed an active session, remove it from active_sessions + if name in self.active_sessions: + self.active_sessions.remove(name) + + logger.debug(f"Removed MCP server configuration: {name}") + else: + logger.warning(f"Server '{name}' not found in configuration") + + def get_server_names(self) -> list[str]: + """Get the list of configured server names. + + Returns: + List of server names. + """ + return list(self._get_mcp_servers().keys()) + + def save_config(self, filepath: str) -> None: + """Save the current configuration to a file. + + Args: + filepath: The path to save the configuration to. + """ + save_json_file(self.config, filepath) + + async def create_session(self, server_name: str, auto_initialize: bool = True) -> MCPSession: + """Create a session for the specified server with retry logic. + + Args: + server_name: The name of the server to create a session for. + auto_initialize: Whether to automatically initialize the session. + + Returns: + The created MCPSession. + + Raises: + ValueError: If the specified server doesn't exist. + Exception: If session creation fails after all retries. + """ + # Check if session already exists + if server_name in self.sessions: + logger.debug(f"Session for server '{server_name}' already exists, returning existing session") + return self.sessions[server_name] + + # Get server config + servers = self._get_mcp_servers() + + if not servers: + warnings.warn("No MCP servers defined in config", UserWarning, stacklevel=2) + return None + + if server_name not in servers: + raise ValueError(f"Server '{server_name}' not found in config. Available: {list(servers.keys())}") + + server_config = servers[server_name] + + # Retry logic for session creation + last_exc: Exception | None = None + + for attempt in range(1, self.max_retries + 1): + try: + # Create connector with options (now async) + connector = await create_connector_from_config( + server_config, + server_name=server_name, + sandbox=self.sandbox, + sandbox_options=self.sandbox_options, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, + installer=self.installer, + check_dependencies=self.check_dependencies, + tool_call_max_retries=self.tool_call_max_retries, + tool_call_retry_delay=self.tool_call_retry_delay, + ) + + # Create the session with proper initialization parameters + session = MCPSession( + connector=connector, + session_id=f"mcp-{server_name}", + auto_connect=True, + auto_initialize=False, # We'll handle initialization explicitly below + ) + + # Initialize if requested + if auto_initialize: + await session.initialize() + logger.debug(f"Initialized session for server '{server_name}'") + + # Store session + self.sessions[server_name] = session + + # Add to active sessions + if server_name not in self.active_sessions: + self.active_sessions.append(server_name) + + logger.info(f"Created session for MCP server '{server_name}' (attempt {attempt}/{self.max_retries})") + return session + + except MCPDependencyError as e: + # Don't retry dependency errors - they won't succeed on retry + # Error already shown to user by installer, just re-raise + logger.debug(f"Dependency error for server '{server_name}': {type(e).__name__}") + raise + except Exception as e: + last_exc = e + if attempt == self.max_retries: + break + + # Use info level for first attempt (common after fresh install), warning for subsequent + log_level = logger.info if attempt == 1 else logger.warning + log_level( + f"Failed to create session for server '{server_name}' (attempt {attempt}/{self.max_retries}): {e}, " + f"retrying in {self.retry_interval} seconds..." + ) + await asyncio.sleep(self.retry_interval) + + # All retries failed + error_msg = f"Failed to create session for server '{server_name}' after {self.max_retries} retries" + logger.error(error_msg) + raise last_exc or RuntimeError(error_msg) + + async def create_all_sessions( + self, + auto_initialize: bool = True, + ) -> dict[str, MCPSession]: + """Create sessions for all configured servers. + + Args: + auto_initialize: Whether to automatically initialize the sessions. + + Returns: + Dictionary mapping server names to their MCPSession instances. + + Warns: + UserWarning: If no servers are configured. + """ + servers = self._get_mcp_servers() + + if not servers: + warnings.warn("No MCP servers defined in config", UserWarning, stacklevel=2) + return {} + + # Create sessions for all servers (create_session already handles initialization) + logger.debug(f"Creating sessions for {len(servers)} servers") + for name in servers: + try: + await self.create_session(name, auto_initialize) + except Exception as e: + logger.error(f"Failed to create session for server '{name}': {e}") + + logger.info(f"Created {len(self.sessions)} MCP sessions") + return self.sessions + + def get_session(self, server_name: str) -> MCPSession: + """Get an existing session. + + Args: + server_name: The name of the server to get the session for. + If None, uses the first active session. + + Returns: + The MCPSession for the specified server. + + Raises: + ValueError: If no active sessions exist or the specified session doesn't exist. + """ + if server_name not in self.sessions: + raise ValueError(f"No session exists for server '{server_name}'") + + return self.sessions[server_name] + + def get_all_active_sessions(self) -> dict[str, MCPSession]: + """Get all active sessions. + + Returns: + Dictionary mapping server names to their MCPSession instances. + """ + return {name: self.sessions[name] for name in self.active_sessions if name in self.sessions} + + async def close_session(self, server_name: str) -> None: + """Close a session. + + Args: + server_name: The name of the server to close the session for. + + Raises: + ValueError: If no active sessions exist or the specified session doesn't exist. + """ + # Check if the session exists + if server_name not in self.sessions: + logger.warning(f"No session exists for server '{server_name}', nothing to close") + return + + # Get the session + session = self.sessions[server_name] + error_occurred = False + + try: + # Disconnect from the session + logger.debug(f"Closing session for server '{server_name}'") + await session.disconnect() + logger.info(f"Successfully closed session for server '{server_name}'") + except Exception as e: + error_occurred = True + logger.error(f"Error closing session for server '{server_name}': {e}") + finally: + # Remove the session regardless of whether disconnect succeeded + self.sessions.pop(server_name, None) + + # Remove from active_sessions + if server_name in self.active_sessions: + self.active_sessions.remove(server_name) + + if error_occurred: + logger.warning(f"Session for '{server_name}' removed from tracking despite disconnect error") + + async def close_all_sessions(self) -> None: + """Close all active sessions. + + This method ensures all sessions are closed even if some fail. + """ + # Get a list of all session names first to avoid modification during iteration + server_names = list(self.sessions.keys()) + errors = [] + + for server_name in server_names: + try: + logger.debug(f"Closing session for server '{server_name}'") + await self.close_session(server_name) + except Exception as e: + error_msg = f"Failed to close session for server '{server_name}': {e}" + logger.error(error_msg) + errors.append(error_msg) + + # Log summary if there were errors + if errors: + logger.error(f"Encountered {len(errors)} errors while closing sessions") + else: + logger.debug("All sessions closed successfully") diff --git a/openspace/grounding/backends/mcp/config.py b/openspace/grounding/backends/mcp/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8af3645c5d1eea99b4775c317f3ee03690045833 --- /dev/null +++ b/openspace/grounding/backends/mcp/config.py @@ -0,0 +1,132 @@ +""" +Configuration loader for MCP session. + +This module provides functionality to load MCP configuration from JSON files. +""" + +from typing import Any, Optional + +from openspace.grounding.core.types import SandboxOptions +from openspace.config.utils import get_config_value +from .transport.connectors import ( + MCPBaseConnector, + HttpConnector, + SandboxConnector, + StdioConnector, + WebSocketConnector, +) +from .transport.connectors.utils import is_stdio_server +from .installer import MCPInstallerManager + +# Import E2BSandbox +try: + from openspace.grounding.core.security import E2BSandbox + E2B_AVAILABLE = True +except ImportError: + E2BSandbox = None + E2B_AVAILABLE = False + +async def create_connector_from_config( + server_config: dict[str, Any], + server_name: str = "unknown", + sandbox: bool = False, + sandbox_options: SandboxOptions | None = None, + timeout: float = 30.0, + sse_read_timeout: float = 300.0, + installer: Optional[MCPInstallerManager] = None, + check_dependencies: bool = True, + tool_call_max_retries: int = 3, + tool_call_retry_delay: float = 1.0, +) -> MCPBaseConnector: + """Create a connector based on server configuration. + + Args: + server_config: The server configuration section + server_name: Name of the MCP server (for display purposes) + sandbox: Whether to use sandboxed execution mode for running MCP servers. + sandbox_options: Optional sandbox configuration options. + timeout: Timeout for operations in seconds (default: 30.0) + sse_read_timeout: SSE read timeout in seconds (default: 300.0) + installer: Optional installer manager for dependency installation + check_dependencies: Whether to check and install dependencies (default: True) + tool_call_max_retries: Maximum number of retries for tool calls (default: 3) + tool_call_retry_delay: Initial delay between retries in seconds (default: 1.0) + + Returns: + A configured connector instance + + Raises: + RuntimeError: If dependencies are not installed and user declines installation + """ + + # Get original command and args from config + original_command = get_config_value(server_config, "command") + original_args = get_config_value(server_config, "args", []) + + # Check and install dependencies if needed (only for stdio servers) + if is_stdio_server(server_config) and check_dependencies: + # Use provided installer or get global instance + if installer is None: + from .installer import get_global_installer + installer = get_global_installer() + + # Ensure dependencies are installed (using original command/args) + await installer.ensure_dependencies(server_name, original_command, original_args) + + # Stdio connector (command-based) + if is_stdio_server(server_config) and not sandbox: + return StdioConnector( + command=get_config_value(server_config, "command"), + args=get_config_value(server_config, "args"), + env=get_config_value(server_config, "env", None), + ) + + # Sandboxed connector + elif is_stdio_server(server_config) and sandbox: + if not E2B_AVAILABLE: + raise ImportError( + "E2B sandbox support not available. Please install e2b-code-interpreter: " + "'pip install e2b-code-interpreter'" + ) + + # Create E2B sandbox instance + _sandbox_options = sandbox_options or {} + e2b_sandbox = E2BSandbox(_sandbox_options) + + # Extract timeout values from sandbox_options or use defaults + connector_timeout = _sandbox_options.get("timeout", timeout) + connector_sse_timeout = _sandbox_options.get("sse_read_timeout", sse_read_timeout) + + # Create and return sandbox connector + return SandboxConnector( + sandbox=e2b_sandbox, + command=get_config_value(server_config, "command"), + args=get_config_value(server_config, "args"), + env=get_config_value(server_config, "env", None), + supergateway_command=_sandbox_options.get("supergateway_command", "npx -y supergateway"), + port=_sandbox_options.get("port", 3000), + timeout=connector_timeout, + sse_read_timeout=connector_sse_timeout, + ) + + # HTTP connector + elif "url" in server_config: + return HttpConnector( + base_url=get_config_value(server_config, "url"), + headers=get_config_value(server_config, "headers", None), + auth_token=get_config_value(server_config, "auth_token", None), + timeout=timeout, + sse_read_timeout=sse_read_timeout, + tool_call_max_retries=tool_call_max_retries, + tool_call_retry_delay=tool_call_retry_delay, + ) + + # WebSocket connector + elif "ws_url" in server_config: + return WebSocketConnector( + url=get_config_value(server_config, "ws_url"), + headers=get_config_value(server_config, "headers", None), + auth_token=get_config_value(server_config, "auth_token", None), + ) + + raise ValueError("Cannot determine connector type from config") \ No newline at end of file diff --git a/openspace/grounding/backends/mcp/installer.py b/openspace/grounding/backends/mcp/installer.py new file mode 100644 index 0000000000000000000000000000000000000000..1e2c825079780f33553fad064f3eede1170c4f79 --- /dev/null +++ b/openspace/grounding/backends/mcp/installer.py @@ -0,0 +1,697 @@ +import asyncio +import sys +import shutil +from typing import Callable, Awaitable, Optional, Dict, List +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +PromptFunc = Callable[[str], Awaitable[bool]] + +# Global lock to prevent concurrent user prompts +_prompt_lock = asyncio.Lock() + + +class MCPDependencyError(RuntimeError): + """Base exception for MCP dependency errors.""" + pass + + +class MCPCommandNotFoundError(MCPDependencyError): + """Raised when a required command is not available.""" + pass + + +class MCPInstallationCancelledError(MCPDependencyError): + """Raised when user cancels installation.""" + pass + + +class MCPInstallationFailedError(MCPDependencyError): + """Raised when installation fails.""" + pass + + +class Colors: + RESET = "\033[0m" + BOLD = "\033[1m" + RED = "\033[91m" + YELLOW = "\033[93m" + GREEN = "\033[92m" + CYAN = "\033[96m" + GRAY = "\033[90m" + WHITE = "\033[97m" + BLUE = "\033[94m" + + +class MCPInstallerManager: + """ + MCP dependencies package installer manager. + + Responsible for detecting if the MCP server dependencies are installed, and if not, asking the user whether to install them. + """ + + def __init__(self, prompt: PromptFunc | None = None, auto_install: bool = False, verbose: bool = False): + """Initialize the installer manager. + + Args: + prompt: Custom user prompt function, if None, the default CLI prompt is used + auto_install: If True, automatically install dependencies without asking the user + verbose: If True, show detailed installation logs; if False, only show progress indicator + """ + self._prompt: PromptFunc | None = prompt or self._default_cli_prompt + self._auto_install = auto_install + self._verbose = verbose + self._installed_cache: Dict[str, bool] = {} # Cache for checked packages + self._failed_installations: Dict[str, str] = {} # Track failed installations to avoid retry + + async def _default_cli_prompt(self, message: str) -> bool: + """Default CLI prompt function (called within lock by ensure_dependencies).""" + from openspace.utils.display import print_separator, colorize + + print() + print_separator(70, 'c', 2) + print(f" {colorize('MCP dependencies installation prompt', color=Colors.BLUE, bold=True)}") + print_separator(70, 'c', 2) + print(f" {message}") + print_separator(70, 'gr', 2) + print(f" {colorize('[y/yes]', color=Colors.GREEN)} Install | {colorize('[n/no]', color=Colors.RED)} Cancel") + print_separator(70, 'gr', 2) + print(f" {colorize('Your choice:', bold=True)} ", end="", flush=True) + + answer = await asyncio.get_running_loop().run_in_executor(None, sys.stdin.readline) + response = answer.strip().lower() in {"y", "yes"} + + if response: + print(f"{Colors.GREEN}✓ Installation confirmed{Colors.RESET}\n") + else: + print(f"{Colors.RED}✗ Installation cancelled{Colors.RESET}\n") + + return response + + async def _ask_user(self, message: str) -> bool: + """Ask the user whether to install.""" + if self._auto_install: + logger.info("Automatic installation mode enabled, will automatically install dependencies") + return True + + if self._prompt: + try: + return await self._prompt(message) + except Exception as e: + logger.error(f"Error asking user: {e}") + return False + return False + + def _check_command_available(self, command: str) -> bool: + """Check if the command is available. + + Args: + command: The command to check (e.g. "npx", "uvx") + + Returns: + bool: Whether the command is available + """ + return shutil.which(command) is not None + + async def _check_package_installed(self, command: str, args: List[str]) -> bool: + """Check if the package is installed. + + Args: + command: The command to check (e.g. "npx", "uvx") + args: The arguments list + + Returns: + bool: Whether the package is installed + """ + # Build cache key + cache_key = f"{command}:{':'.join(args)}" + + # Check cache + if cache_key in self._installed_cache: + return self._installed_cache[cache_key] + + # For different types of commands, use different check methods + try: + if command == "npx": + # For npx, check if the npm package exists + package_name = self._extract_npm_package(args) + if package_name: + result = await self._check_npm_package(package_name) + self._installed_cache[cache_key] = result + return result + elif command == "uvx": + # For uvx, check if the Python package exists + package_name = self._extract_python_package(args) + if package_name: + result = await self._check_python_package(package_name) + self._installed_cache[cache_key] = result + return result + elif command == "uv": + # For "uv run --with package ...", check if the Python package exists + package_name = self._extract_uv_package(args) + if package_name: + result = await self._check_uv_pip_package(package_name) + self._installed_cache[cache_key] = result + return result + except Exception as e: + logger.debug(f"Error checking package installation status: {e}") + + # Default to assuming not installed + return False + + def _extract_npm_package(self, args: List[str]) -> Optional[str]: + """Extract package name from npx arguments. + + Args: + args: npx arguments list, e.g. ["-y", "mcp-excalidraw-server"] or ["bazi-mcp"] + + Returns: + Package name (without version tag) or None + """ + for i, arg in enumerate(args): + # Skip option parameters + if arg.startswith("-"): + continue + + # Found package name, now strip version tag + package_name = arg + + # Handle scoped packages: @scope/package@version -> @scope/package + if package_name.startswith("@"): + # Scoped package like @rtuin/mcp-mermaid-validator@latest + parts = package_name.split("/", 1) + if len(parts) == 2: + scope = parts[0] + name_with_version = parts[1] + # Remove version tag from name part (e.g., "pkg@latest" -> "pkg") + name = name_with_version.split("@")[0] if "@" in name_with_version else name_with_version + return f"{scope}/{name}" + return package_name + else: + # Regular package like mcp-deepwiki@latest -> mcp-deepwiki + return package_name.split("@")[0] if "@" in package_name else package_name + + return None + + def _extract_python_package(self, args: List[str]) -> Optional[str]: + """Extract package name from uvx arguments. + + Args: + args: uvx arguments list, e.g. ["--from", "office-powerpoint-mcp-server", "ppt_mcp_server"] + or ["--with", "mcp==1.9.0", "sitemap-mcp-server"] + or ["arxiv-mcp-server", "--storage-path", "./path"] + + Returns: + Package name or None + """ + # Find --from parameter (this is the package to install) + for i, arg in enumerate(args): + if arg == "--from" and i + 1 < len(args): + return args[i + 1] + + # Skip option flags and their values, find the main package (FIRST positional arg) + # Options that take a value: --with, --python, --from, --storage-path, etc. + options_with_value = {"--with", "--from", "--python", "-p", "--storage-path"} + skip_next = False + + for arg in args: + if skip_next: + skip_next = False + continue + if arg in options_with_value: + skip_next = True + continue + if arg.startswith("-"): + # Other flags without values (or unknown options with values) + # Also skip the next arg if it looks like an option value (doesn't start with -) + continue + # First non-option argument is the package name + return arg + + return None + + def _extract_uv_package(self, args: List[str]) -> Optional[str]: + """Extract package name from uv run arguments. + + Args: + args: uv arguments list, e.g. ["run", "--with", "biomcp-python", "biomcp", "run"] + + Returns: + Package name or None + """ + # Find --with parameter (this specifies the package to install) + for i, arg in enumerate(args): + if arg == "--with" and i + 1 < len(args): + package_name = args[i + 1] + # Remove version specifier if present (e.g., "mcp==1.9.0" -> "mcp") + if "==" in package_name: + return package_name.split("==")[0] + if ">=" in package_name: + return package_name.split(">=")[0] + return package_name + + return None + + async def _check_npm_package(self, package_name: str) -> bool: + """Check if the npm package is globally installed. + + Args: + package_name: npm package name + + Returns: + bool: Whether the npm package is installed + """ + try: + process = await asyncio.create_subprocess_exec( + "npm", "list", "-g", package_name, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + stdout, stderr = await process.communicate() + + # npm list returns 0 if the package is installed + return process.returncode == 0 + except Exception as e: + logger.debug(f"Error checking npm package {package_name}: {e}") + return False + + async def _check_python_package(self, package_name: str) -> bool: + """Check if the Python package is installed as a uvx tool. + + uvx tools are installed in ~/.local/share/uv/tools/ directory, + not in the current pip environment. + + Args: + package_name: Python package/tool name + + Returns: + bool: Whether the uvx tool is installed + """ + import os + from pathlib import Path + + # Strip version specifier if present (e.g., "mcp==1.9.0" -> "mcp") + clean_name = package_name.split("==")[0].split(">=")[0].split("<=")[0].split(">")[0].split("<")[0] + + # Check if uvx tool exists in the standard uv tools directory + uv_tools_dir = Path.home() / ".local" / "share" / "uv" / "tools" + tool_dir = uv_tools_dir / clean_name + + if tool_dir.exists(): + logger.debug(f"uvx tool '{clean_name}' found at {tool_dir}") + return True + + # Fallback: try running uvx with --help to check if it's available + try: + process = await asyncio.create_subprocess_exec( + "uvx", clean_name, "--help", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + # Just wait briefly, don't need the full output + try: + await asyncio.wait_for(process.communicate(), timeout=5.0) + except asyncio.TimeoutError: + process.kill() + await process.wait() + + # If it didn't error immediately, the tool likely exists + return process.returncode == 0 + except Exception as e: + logger.debug(f"Error checking uvx tool {clean_name}: {e}") + + return False + + async def _check_uv_pip_package(self, package_name: str) -> bool: + """Check if a Python package is installed via uv pip. + + Args: + package_name: Python package name + + Returns: + bool: Whether the package is installed + """ + # Strip version specifier if present + clean_name = package_name.split("==")[0].split(">=")[0].split("<=")[0].split(">")[0].split("<")[0] + + try: + # Try using uv pip show to check if package is installed + process = await asyncio.create_subprocess_exec( + "uv", "pip", "show", clean_name, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + stdout, stderr = await process.communicate() + + if process.returncode == 0: + logger.debug(f"uv pip package '{clean_name}' found") + return True + except Exception as e: + logger.debug(f"Error checking uv pip package {clean_name}: {e}") + + # Fallback: check with regular pip + try: + process = await asyncio.create_subprocess_exec( + "pip", "show", clean_name, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + stdout, stderr = await process.communicate() + + return process.returncode == 0 + except Exception as e: + logger.debug(f"Error checking pip package {clean_name}: {e}") + + return False + + async def _install_package(self, command: str, args: List[str], use_sudo: bool = False) -> bool: + """Execute the install command. + + Args: + command: The command to execute (e.g. "npx", "uvx") + args: The arguments list + use_sudo: Whether to use sudo for installation + + Returns: + bool: Whether the installation is successful + """ + install_command = self._get_install_command(command, args) + + if not install_command: + logger.error("Cannot determine install command") + return False + + # Add sudo if requested + if use_sudo: + install_command = ["sudo"] + install_command + + logger.info(f"Executing install command: {' '.join(install_command)}") + + try: + # For sudo commands, always show verbose output so password prompt is visible + if self._verbose or use_sudo: + # Verbose mode: show all installation logs + from openspace.utils.display import print_separator, colorize + + print_separator(70, 'c', 2) + if use_sudo: + print(f" {colorize('Installing with administrator privileges...', color=Colors.BLUE)}") + print(f" {colorize('>> You will be prompted for your password below <<', color=Colors.YELLOW)}") + else: + print(f" {colorize('Installing dependencies...', color=Colors.BLUE)}") + print(f" {colorize('Command: ' + ' '.join(install_command), color=Colors.GRAY)}") + print_separator(70, 'c', 2) + print() + + # For sudo, don't redirect stdin so password prompt works + if use_sudo: + process = await asyncio.create_subprocess_exec( + *install_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + stdin=None # Let sudo use terminal for password + ) + else: + process = await asyncio.create_subprocess_exec( + *install_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT + ) + + # Real-time output of installation logs + output_lines = [] + while True: + line = await process.stdout.readline() + if not line: + break + line_str = line.decode().rstrip() + output_lines.append(line_str) + print(f"{Colors.GRAY}{line_str}{Colors.RESET}") + + await process.wait() + full_output = '\n'.join(output_lines) + else: + # Quiet mode: only show progress indicator + print(f"\n{Colors.BLUE}Installing dependencies...{Colors.RESET} ", end="", flush=True) + + process = await asyncio.create_subprocess_exec( + *install_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + + # Show spinner animation while installing + spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'] + spinner_idx = 0 + + while True: + try: + await asyncio.wait_for(process.wait(), timeout=0.1) + break + except asyncio.TimeoutError: + print(f"\r{Colors.BLUE}Installing dependencies...{Colors.RESET} {Colors.CYAN}{spinner[spinner_idx]}{Colors.RESET}", end="", flush=True) + spinner_idx = (spinner_idx + 1) % len(spinner) + + # Clear the spinner line + print(f"\r{' ' * 100}\r", end="", flush=True) + + # Collect output + stdout, stderr = await process.communicate() + full_output = (stdout or stderr).decode() if (stdout or stderr) else "" + + if process.returncode == 0: + print(f"{Colors.GREEN}✓ Dependencies installed successfully{Colors.RESET}") + if not use_sudo: + print(f"{Colors.GRAY}(Note: First connection may take a moment to initialize){Colors.RESET}") + # Update cache + cache_key = f"{command}:{':'.join(args)}" + self._installed_cache[cache_key] = True + return True + else: + # Check if it's a permission error + is_permission_error = "EACCES" in full_output or "permission denied" in full_output.lower() + + if is_permission_error and not use_sudo: + print(f"\n{Colors.YELLOW}Permission denied{Colors.RESET}") + print(f"{Colors.GRAY}The installation requires administrator privileges.{Colors.RESET}\n") + + # Ask user if they want to use sudo + message = ( + f"\n{Colors.WHITE}Administrator privileges required{Colors.RESET}\n\n" + f"Command: {Colors.GRAY}{' '.join(install_command)}{Colors.RESET}\n\n" + f"{Colors.YELLOW}Do you want to retry with sudo (requires password)?{Colors.RESET}" + ) + + if await self._ask_user(message): + # No extra print needed, the verbose mode will show clear instructions + return await self._install_package(command, args, use_sudo=True) + else: + print(f"\n{Colors.RED}✗ Installation cancelled{Colors.RESET}") + return False + else: + print(f"{Colors.RED}✗ Dependencies installation failed (return code: {process.returncode}){Colors.RESET}") + # Show error output if not already shown + if not self._verbose and full_output: + # Limit error output to last 20 lines + error_lines = full_output.split('\n') + if len(error_lines) > 20: + error_lines = ['...(truncated)...'] + error_lines[-20:] + print(f"{Colors.GRAY}Error output:\n{chr(10).join(error_lines)}{Colors.RESET}") + + # Add general guidance for manual installation + print(f"\n{Colors.YELLOW}Tip:{Colors.RESET} {Colors.GRAY}If automatic installation fails, please refer to the") + print(f"official documentation of the MCP server for manual installation instructions.{Colors.RESET}\n") + + return False + + except Exception as e: + logger.error(f"Error installing dependencies: {e}") + print(f"{Colors.RED}✗ Error occurred during installation: {e}{Colors.RESET}") + return False + + def _get_install_command(self, command: str, args: List[str]) -> Optional[List[str]]: + """Generate install command based on command type. + + Args: + command: The command to execute (e.g. "npx", "uvx", "uv") + args: The original arguments list + + Returns: + Install command list or None + """ + if command == "npx": + package_name = self._extract_npm_package(args) + if package_name: + return ["npm", "install", "-g", package_name] + elif command == "uvx": + package_name = self._extract_python_package(args) + if package_name: + return ["pip", "install", package_name] + elif command == "uv": + # Handle "uv run --with package_name ..." format + package_name = self._extract_uv_package(args) + if package_name: + return ["uv", "pip", "install", package_name] + + return None + + async def ensure_dependencies( + self, + server_name: str, + command: str, + args: List[str] + ) -> bool: + """Ensure the dependencies of the MCP server are installed. + + This method checks if the dependencies are installed, and if not, asks the user whether to install them. + + Args: + server_name: MCP server name (for display purposes) + command: The command to execute (e.g. "npx", "uvx") + args: The arguments list + + Returns: + bool: Whether the dependencies are installed (installed or successfully installed) + + Raises: + RuntimeError: When the command is not available or the user refuses to install + """ + # Use lock to ensure entire installation process is atomic + async with _prompt_lock: + return await self._ensure_dependencies_impl(server_name, command, args) + + async def _ensure_dependencies_impl( + self, + server_name: str, + command: str, + args: List[str] + ) -> bool: + """Internal implementation of ensure_dependencies (called within lock).""" + # Skip dependency checking for direct script execution commands + # These commands run scripts directly and don't need package installation + SKIP_COMMANDS = {"node", "python", "python3", "bash", "sh", "deno", "bun"} + + if command.lower() in SKIP_COMMANDS: + logger.debug(f"Skipping dependency check for direct script execution command: {command}") + return True + + # Skip dependency checking for GitHub-based npx packages + # These packages are handled directly by npx which downloads, builds, and runs them + # npm install -g doesn't work properly for GitHub packages that require building + if command == "npx": + package_name = self._extract_npm_package(args) + if package_name and package_name.startswith("github:"): + logger.debug(f"Skipping dependency check for GitHub-based npx package: {package_name}") + return True + + # Check if this server has already failed installation + cache_key = f"{server_name}:{command}:{':'.join(args)}" + if cache_key in self._failed_installations: + error_msg = self._failed_installations[cache_key] + logger.debug(f"Skipping installation for '{server_name}' - previously failed") + raise MCPDependencyError(error_msg) + + # Special handling for uvx - check if uv is installed + if command == "uvx": + if not self._check_command_available("uv"): + # Only show once to user, no verbose logging + print(f"\n{Colors.RED}✗ Server '{server_name}' requires 'uv' to be installed{Colors.RESET}") + print(f"{Colors.YELLOW}Please install uv first:") + print(f" • macOS/Linux: curl -LsSf https://astral.sh/uv/install.sh | sh") + print(f" • Or with pip: pip install uv") + print(f" • Or with brew: brew install uv{Colors.RESET}\n") + + error_msg = f"uvx requires 'uv' to be installed (server: {server_name})" + self._failed_installations[cache_key] = error_msg + raise MCPCommandNotFoundError(error_msg) + + # Check if the command is available + if not self._check_command_available(command): + error_msg = ( + f"Command '{command}' is not available.\n" + f"Please install the necessary tools first." + ) + logger.error(error_msg) + self._failed_installations[cache_key] = error_msg + raise MCPCommandNotFoundError(error_msg) + + # Check if the package is installed + if await self._check_package_installed(command, args): + logger.debug(f"The dependencies of the MCP server '{server_name}' are installed") + return True + + # Extract package name for display + if command == "npx": + package_name = self._extract_npm_package(args) + package_type = "npm" + elif command == "uvx": + package_name = self._extract_python_package(args) + package_type = "Python" + elif command == "uv": + package_name = self._extract_uv_package(args) + package_type = "Python" + else: + package_name = f"{command} {' '.join(args)}" + package_type = "package" + + # Build the message for displaying the install command + install_cmd = self._get_install_command(command, args) + + # If we can't determine an install command, show helpful message + if not install_cmd: + print(f"\n{Colors.YELLOW}Cannot automatically install dependencies for '{server_name}'{Colors.RESET}") + print(f"{Colors.GRAY}Command: {command} {' '.join(args)}{Colors.RESET}") + print(f"\n{Colors.WHITE}This MCP server may require manual installation or configuration.{Colors.RESET}") + print(f"{Colors.GRAY}Please refer to the MCP server's official documentation for installation instructions.{Colors.RESET}\n") + + error_msg = f"Manual installation required for '{server_name}' (command: {command})" + self._failed_installations[cache_key] = error_msg + raise MCPDependencyError(error_msg) + + install_cmd_str = ' '.join(install_cmd) + + # Build the message + message = ( + f"\n{Colors.WHITE}The MCP server needs to install dependencies{Colors.RESET}\n\n" + f"Server name: {Colors.CYAN}{server_name}{Colors.RESET}\n" + f"Package type: {Colors.YELLOW}{package_type}{Colors.RESET}\n" + f"Package name: {Colors.YELLOW}{package_name or 'Unknown'}{Colors.RESET}\n" + f"Install command: {Colors.GRAY}{install_cmd_str}{Colors.RESET}\n\n" + f"{Colors.YELLOW}Whether to install this dependency package?{Colors.RESET}" + ) + + # Ask the user + if not await self._ask_user(message): + error_msg = f"User cancelled the dependency installation for '{server_name}'" + logger.warning(error_msg) + self._failed_installations[cache_key] = error_msg + raise MCPInstallationCancelledError(error_msg) + + # Execute installation + success = await self._install_package(command, args) + + if not success: + error_msg = f"Dependency installation failed for '{server_name}'" + logger.error(error_msg) + self._failed_installations[cache_key] = error_msg + raise MCPInstallationFailedError(error_msg) + + return True + + +# Global singleton instance +_global_installer: Optional[MCPInstallerManager] = None + + +def get_global_installer() -> MCPInstallerManager: + """Get the global installer manager instance.""" + global _global_installer + if _global_installer is None: + _global_installer = MCPInstallerManager() + return _global_installer + +def set_global_installer(installer: MCPInstallerManager) -> None: + """Set the global installer manager instance.""" + global _global_installer + _global_installer = installer \ No newline at end of file diff --git a/openspace/grounding/backends/mcp/provider.py b/openspace/grounding/backends/mcp/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..db4cadbc82c1608bc2eed9e12c955302ee02d5c9 --- /dev/null +++ b/openspace/grounding/backends/mcp/provider.py @@ -0,0 +1,473 @@ +""" +MCP Provider implementation. + +This module provides a provider for managing MCP server sessions. +""" +import asyncio +from typing import Dict, List, Optional + +from openspace.grounding.backends.mcp.session import MCPSession +from openspace.grounding.core.provider import Provider +from openspace.grounding.core.types import SessionConfig, BackendType, ToolSchema +from openspace.grounding.backends.mcp.client import MCPClient +from openspace.grounding.backends.mcp.installer import MCPInstallerManager, MCPDependencyError +from openspace.grounding.backends.mcp.tool_cache import get_tool_cache +from openspace.grounding.backends.mcp.tool_converter import _sanitize_mcp_schema +from openspace.grounding.core.tool import BaseTool, RemoteTool +from openspace.utils.logging import Logger +from openspace.config.utils import get_config_value + +logger = Logger.get_logger(__name__) + + +class MCPProvider(Provider[MCPSession]): + """ + MCP Provider manages multiple MCP server sessions. + + Each MCP server defined in config corresponds to one session. + The provider handles lazy/eager session creation and tool aggregation. + """ + + def __init__(self, config: Dict | None = None, installer: Optional[MCPInstallerManager] = None): + """Initialize MCP Provider. + + Args: + config: Configuration dict with MCP server definitions. + Example: {"mcpServers": {"server1": {...}, "server2": {...}}} + installer: Optional installer manager for dependency installation + """ + super().__init__(BackendType.MCP, config) + + # Extract MCP-specific configuration + sandbox = get_config_value(config, "sandbox", False) + timeout = get_config_value(config, "timeout", 30) + sse_read_timeout = get_config_value(config, "sse_read_timeout", 300.0) + max_retries = get_config_value(config, "max_retries", 3) + retry_interval = get_config_value(config, "retry_interval", 2.0) + check_dependencies = get_config_value(config, "check_dependencies", True) + auto_install = get_config_value(config, "auto_install", False) + # Tool call retry settings (for transient errors like 400, 500, etc.) + tool_call_max_retries = get_config_value(config, "tool_call_max_retries", 3) + tool_call_retry_delay = get_config_value(config, "tool_call_retry_delay", 1.0) + + # Create sandbox options if sandbox is enabled + sandbox_options = None + if sandbox: + sandbox_options = { + "timeout": timeout, + "sse_read_timeout": sse_read_timeout, + } + + # Create installer with auto_install setting if not provided + if installer is None and auto_install: + installer = MCPInstallerManager(auto_install=True) + + # Initialize MCPClient with configuration + self._client = MCPClient( + config=config or {}, + sandbox=sandbox, + sandbox_options=sandbox_options, + timeout=timeout, + sse_read_timeout=sse_read_timeout, + max_retries=max_retries, + retry_interval=retry_interval, + installer=installer, + check_dependencies=check_dependencies, + tool_call_max_retries=tool_call_max_retries, + tool_call_retry_delay=tool_call_retry_delay, + ) + + # Map server name to session for quick lookup + self._server_sessions: Dict[str, MCPSession] = {} + + async def initialize(self) -> None: + """Initialize the MCP provider. + + If config["eager_sessions"] is True, creates sessions for all configured servers. + Otherwise, sessions are created lazily on first access. + """ + if self.is_initialized: + return + + # config can be dict or Pydantic model, use utility function + eager = get_config_value(self.config, "eager_sessions", False) + if eager: + servers = self.list_servers() + logger.debug(f"Eagerly initializing {len(servers)} MCP server sessions") + for srv in servers: + if srv not in self._server_sessions: + cfg = SessionConfig( + session_name=f"mcp-{srv}", + backend_type=BackendType.MCP, + connection_params={"server": srv}, + ) + await self.create_session(cfg) + + self.is_initialized = True + logger.info( + f"MCPProvider initialized with {len(self.list_servers())} servers (eager={eager})" + ) + + def list_servers(self) -> List[str]: + """Return all configured MCP server names from MCPClient config. + + Returns: + List of server names + """ + return self._client.get_server_names() + + async def create_session(self, session_config: SessionConfig) -> MCPSession: + """Create a new MCP session for a specific server. + + Args: + session_config: Must contain 'server' in connection_params + + Returns: + MCPSession instance + + Raises: + ValueError: If 'server' not in connection_params + Exception: If session creation or initialization fails + """ + server = get_config_value(session_config.connection_params, "server") + if not server: + raise ValueError("MCPProvider.create_session requires 'server' in connection_params") + + # Generate session_id: mcp- + session_id = f"{self.backend_type.value}-{server}" + + # Check if session already exists + if server in self._server_sessions: + logger.debug(f"Session for server '{server}' already exists, returning existing session") + return self._server_sessions[server] + + # Create session through MCPClient + try: + logger.debug(f"Creating new session for MCP server: {server}") + session = await self._client.create_session(server, auto_initialize=True) + session.session_id = session_id + + # Store in both maps + self._server_sessions[server] = session + self._sessions[session_id] = session + + logger.info(f"Created MCP session '{session_id}' for server '{server}'") + return session + except MCPDependencyError as e: + # Dependency errors already shown to user, just debug log + logger.debug(f"Dependency error for server '{server}': {type(e).__name__}") + raise + except Exception as e: + logger.error(f"Failed to create session for server '{server}': {e}") + raise + + async def close_session(self, session_name: str) -> None: + """Close an MCP session by session name. + + Args: + session_name: Session name in format 'mcp-' + """ + # Parse server name from session_name (format: mcp-) + try: + prefix, server_name = session_name.split("-", 1) + if prefix != self.backend_type.value: + raise ValueError(f"Invalid MCP session name format: {session_name}, expected 'mcp-'") + except ValueError as e: + logger.warning(f"Invalid session_name format: {session_name} - {e}") + return + + # Check if session exists + if session_name not in self._sessions and server_name not in self._server_sessions: + logger.warning(f"Session '{session_name}' not found, nothing to close") + return + + error_occurred = False + try: + logger.debug(f"Closing MCP session '{session_name}' (server: {server_name})") + await self._client.close_session(server_name) + logger.info(f"Successfully closed MCP session '{session_name}'") + except Exception as e: + error_occurred = True + logger.error(f"Error closing MCP session '{session_name}': {e}") + finally: + # Clean up both maps regardless of errors + self._server_sessions.pop(server_name, None) + self._sessions.pop(session_name, None) + + if error_occurred: + logger.warning(f"Session '{session_name}' removed from tracking despite close error") + + async def list_tools(self, session_name: str | None = None, use_cache: bool = True) -> List[BaseTool]: + """List tools from MCP sessions. + + Args: + session_name: If provided, only list tools from that session. + If None, list tools from all sessions. + use_cache: If True, try to load from cache first (no server startup). + If False, start servers and get live tools. + + Returns: + List of BaseTool instances + """ + await self.ensure_initialized() + + # Case 1: List tools from specific session (always live, no cache) + if session_name: + sess = self._sessions.get(session_name) + if sess: + try: + tools = await sess.list_tools() + server_name = session_name.replace(f"{self.backend_type.value}-", "", 1) + for tool in tools: + tool.bind_runtime_info( + backend=self.backend_type, + session_name=session_name, + server_name=server_name, + ) + return tools + except Exception as e: + logger.error(f"Error listing tools from session '{session_name}': {e}") + return [] + else: + logger.warning(f"Session '{session_name}' not found") + return [] + + # Case 2: List tools from all servers + # Try cache first if enabled + if use_cache: + cache = get_tool_cache() + if cache.has_cache(): + tools = self._load_tools_from_cache() + if tools: + logger.info(f"Loaded {len(tools)} tools from cache (no server startup)") + return tools + + # No cache or cache disabled, start servers + return await self._list_tools_live() + + def _load_tools_from_cache(self) -> List[BaseTool]: + """Load tools from cache file without starting servers. + + Priority: + 1. Try to load from sanitized cache (mcp_tool_cache_sanitized.json) + 2. If not exists, load from raw cache and sanitize, then save sanitized version + """ + cache = get_tool_cache() + config_servers = self.list_servers() + + # Try sanitized cache first + if cache.has_sanitized_cache(): + logger.debug("Loading from sanitized cache") + all_cached_tools = cache.get_all_sanitized_tools() + return self._build_tools_from_cache(all_cached_tools, config_servers) + + # Fall back to raw cache, sanitize and save + if cache.has_cache(): + logger.info("Sanitized cache not found, building from raw cache...") + all_cached_tools = cache.get_all_tools() + sanitized_servers = self._sanitize_and_save_cache(all_cached_tools, cache) + return self._build_tools_from_cache(sanitized_servers, config_servers) + + return [] + + def _sanitize_and_save_cache( + self, + raw_tools: Dict[str, List[Dict]], + cache + ) -> Dict[str, List[Dict]]: + """Sanitize raw cache and save to sanitized cache file.""" + sanitized_servers: Dict[str, List[Dict]] = {} + + for server_name, tool_list in raw_tools.items(): + sanitized_tools = [] + for tool_meta in tool_list: + raw_params = tool_meta.get("parameters", {}) + sanitized_params = _sanitize_mcp_schema(raw_params) + sanitized_tools.append({ + "name": tool_meta["name"], + "description": tool_meta.get("description", ""), + "parameters": sanitized_params, + }) + sanitized_servers[server_name] = sanitized_tools + + # Save sanitized cache for future use + cache.save_sanitized(sanitized_servers) + logger.info(f"Created sanitized cache with {len(sanitized_servers)} servers") + + return sanitized_servers + + def _build_tools_from_cache( + self, + all_cached_tools: Dict[str, List[Dict]], + config_servers: List[str] + ) -> List[BaseTool]: + """Build BaseTool instances from cached tool metadata.""" + tools: List[BaseTool] = [] + + for server_name in config_servers: + tool_list = all_cached_tools.get(server_name) + if not tool_list: + continue + + session_name = f"{self.backend_type.value}-{server_name}" + for tool_meta in tool_list: + schema = ToolSchema( + name=tool_meta["name"], + description=tool_meta.get("description", ""), + parameters=tool_meta.get("parameters", {}), + backend_type=BackendType.MCP, + ) + tool = RemoteTool(schema=schema, connector=None, backend=BackendType.MCP) + tool.bind_runtime_info( + backend=self.backend_type, + session_name=session_name, + server_name=server_name, + ) + tools.append(tool) + + return tools + + async def _list_tools_live(self) -> List[BaseTool]: + """List tools by starting all servers. + + Uses a semaphore to serialize session creation, avoiding TaskGroup race conditions + that occur when multiple MCP connections are initialized concurrently. + """ + servers = self.list_servers() + + if not servers: + logger.warning("No MCP servers configured") + return [] + + # Find servers that don't have sessions yet + to_create = [s for s in servers if s not in self._server_sessions] + + # Create missing sessions with serialized execution using semaphore + if to_create: + logger.info(f"Creating {len(to_create)} MCP sessions (serialized to avoid race conditions)") + + # Use semaphore with limit=1 to serialize session creation + # This avoids TaskGroup race conditions in concurrent HTTP connection setup + semaphore = asyncio.Semaphore(1) + + async def _create_with_semaphore(server: str): + async with semaphore: + logger.debug(f"Creating session for '{server}'") + return await self._lazy_create(server) + + tasks = [_create_with_semaphore(s) for s in to_create] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Log errors + for i, result in enumerate(results): + if isinstance(result, MCPDependencyError): + logger.debug(f"Dependency error for '{to_create[i]}': {type(result).__name__}") + elif isinstance(result, Exception): + logger.error(f"Failed to create session for '{to_create[i]}': {result}") + + # Aggregate tools from all sessions + uniq: Dict[tuple[str, str], BaseTool] = {} + failed_servers = [] + + logger.debug(f"Listing tools from {len(self._server_sessions)} sessions") + for server, sess in self._server_sessions.items(): + try: + tools = await sess.list_tools() + session_name = f"{self.backend_type.value}-{server}" + for tool in tools: + key = (server, tool.schema.name) + if key not in uniq: + tool.bind_runtime_info( + backend=self.backend_type, + session_name=session_name, + server_name=server, + ) + uniq[key] = tool + except Exception as e: + failed_servers.append(server) + logger.error(f"Error listing tools from server '{server}': {e}") + + if failed_servers: + logger.warning(f"Failed to list tools from {len(failed_servers)} server(s): {failed_servers}") + + tools_list = list(uniq.values()) + logger.debug(f"Listed {len(tools_list)} unique tools from {len(self._server_sessions)} MCP servers") + + # Save to cache for next time + await self._save_tools_to_cache(tools_list) + + return tools_list + + async def _save_tools_to_cache(self, tools: List[BaseTool]) -> None: + """Save tools metadata to cache file.""" + cache = get_tool_cache() + + # Group tools by server + servers: Dict[str, List[Dict]] = {} + for tool in tools: + server_name = tool.runtime_info.server_name if tool.is_bound else "unknown" + if server_name not in servers: + servers[server_name] = [] + servers[server_name].append({ + "name": tool.schema.name, + "description": tool.schema.description or "", + "parameters": tool.schema.parameters or {}, + }) + + cache.save(servers) + + async def ensure_server_session(self, server_name: str) -> Optional[MCPSession]: + """Ensure a server session exists, creating it if needed. + + This is used for on-demand server startup when executing tools. + """ + if server_name in self._server_sessions: + return self._server_sessions[server_name] + + # Server not running, start it + logger.info(f"Starting MCP server on-demand: {server_name}") + cfg = SessionConfig( + session_name=f"mcp-{server_name}", + backend_type=BackendType.MCP, + connection_params={"server": server_name}, + ) + + try: + session = await self.create_session(cfg) + return session + except Exception as e: + logger.error(f"Failed to start server '{server_name}': {e}") + return None + + async def _lazy_create(self, server: str) -> None: + """Internal helper for lazy session creation. + + Args: + server: Server name to create session for + + Raises: + Exception: Re-raises any exception from session creation for error tracking + """ + # Double-check to avoid race conditions + if server in self._server_sessions: + logger.debug(f"Session for server '{server}' already exists, skipping lazy creation") + return + + cfg = SessionConfig( + session_name=f"mcp-{server}", + backend_type=BackendType.MCP, + connection_params={"server": server}, + ) + + try: + await self.create_session(cfg) + logger.debug(f"Lazily created session for server '{server}'") + except MCPDependencyError as e: + # Dependency errors already shown to user + logger.debug(f"Dependency error for server '{server}': {type(e).__name__}") + # Re-raise so that asyncio.gather can track the error + raise + except Exception as e: + logger.error(f"Failed to lazily create session for server '{server}': {e}") + # Re-raise so that asyncio.gather can track the error + raise \ No newline at end of file diff --git a/openspace/grounding/backends/mcp/session.py b/openspace/grounding/backends/mcp/session.py new file mode 100644 index 0000000000000000000000000000000000000000..817d99bd5d9eda831d1cbae64bd023f6f37af5f4 --- /dev/null +++ b/openspace/grounding/backends/mcp/session.py @@ -0,0 +1,75 @@ +""" +Session manager for MCP connections. + +This module provides a session manager for MCP connections, +which handles authentication, initialization, and tool discovery. +""" + +from typing import Any, Dict + +from openspace.grounding.backends.mcp.transport.connectors import MCPBaseConnector +from openspace.grounding.backends.mcp.tool_converter import convert_mcp_tool_to_base_tool +from openspace.grounding.core.session import BaseSession +from openspace.grounding.core.types import BackendType +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +class MCPSession(BaseSession): + """Session manager for MCP connections. + + This class manages the lifecycle of an MCP connection, including + authentication, initialization, and tool discovery. + """ + + def __init__( + self, + connector: MCPBaseConnector, + *, + session_id: str = "", + auto_connect: bool = True, + auto_initialize: bool = True, + ) -> None: + """Initialize a new MCP session. + + Args: + connector: The connector to use for communicating with the MCP implementation. + session_id: Unique identifier for this session + auto_connect: Whether to automatically connect to the MCP implementation. + auto_initialize: Whether to automatically initialize the session. + """ + super().__init__( + connector=connector, + session_id=session_id, + backend_type=BackendType.MCP, + auto_connect=auto_connect, + auto_initialize=auto_initialize, + ) + + async def initialize(self) -> Dict[str, Any]: + """Initialize the MCP session and discover available tools. + + Returns: + The session information returned by the MCP implementation. + """ + # Make sure we're connected + if not self.is_connected and self.auto_connect: + await self.connect() + + # Initialize the session through connector + logger.debug(f"Initializing MCP session {self.session_id}") + session_info = await self.connector.initialize() + + # List tools from MCP server and convert to BaseTool + mcp_tools = self.connector.tools # MCPBaseConnector caches tools after initialize + logger.debug(f"Converting {len(mcp_tools)} MCP tools to BaseTool") + + self.tools = [ + convert_mcp_tool_to_base_tool(mcp_tool, self.connector) + for mcp_tool in mcp_tools + ] + + logger.debug(f"MCP session {self.session_id} initialized with {len(self.tools)} tools") + + return session_info \ No newline at end of file diff --git a/openspace/grounding/backends/mcp/tool_cache.py b/openspace/grounding/backends/mcp/tool_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..64abc8cb6d57c4d69a9b046ca8cbed268d3b8c4e --- /dev/null +++ b/openspace/grounding/backends/mcp/tool_cache.py @@ -0,0 +1,254 @@ +import json +from pathlib import Path +from datetime import datetime +from typing import Any, Dict, List, Optional + +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +# Cache path in project root directory (OpenSpace/) +# __file__ = .../OpenSpace/openspace/grounding/backends/mcp/tool_cache.py +# parent x5 = .../OpenSpace/ +DEFAULT_CACHE_PATH = Path(__file__).parent.parent.parent.parent.parent / "mcp_tool_cache.json" +# Sanitized cache path (Claude API compatible JSON Schema) +DEFAULT_SANITIZED_CACHE_PATH = Path(__file__).parent.parent.parent.parent.parent / "mcp_tool_cache_sanitized.json" + + +class MCPToolCache: + """Simple file-based cache for MCP tool metadata.""" + + CACHE_VERSION = 1 + + def __init__(self, cache_path: Optional[Path] = None, sanitized_cache_path: Optional[Path] = None): + self.cache_path = cache_path or DEFAULT_CACHE_PATH + self.sanitized_cache_path = sanitized_cache_path or DEFAULT_SANITIZED_CACHE_PATH + self._cache: Optional[Dict] = None + self._sanitized_cache: Optional[Dict] = None + self._server_order: Optional[List[str]] = None + + def set_server_order(self, order: List[str]): + """Set expected server order (from config). Used when saving to disk.""" + self._server_order = order + + def _reorder_servers(self, servers: Dict[str, List[Dict]]) -> Dict[str, List[Dict]]: + """Reorder servers dict according to _server_order.""" + if not self._server_order: + return servers + + ordered = {} + # First add servers in config order + for name in self._server_order: + if name in servers: + ordered[name] = servers[name] + # Then add any remaining servers (not in config) + for name in servers: + if name not in ordered: + ordered[name] = servers[name] + return ordered + + def _ensure_dir(self): + """Ensure cache directory exists.""" + self.cache_path.parent.mkdir(parents=True, exist_ok=True) + + def load(self) -> Dict[str, Any]: + """Load cache from disk. Returns empty dict if not exists.""" + if self._cache is not None: + return self._cache + + if not self.cache_path.exists(): + self._cache = {"version": self.CACHE_VERSION, "servers": {}} + return self._cache + + try: + with open(self.cache_path, "r", encoding="utf-8") as f: + self._cache = json.load(f) + logger.info(f"Loaded MCP tool cache: {len(self._cache.get('servers', {}))} servers") + return self._cache + except Exception as e: + logger.warning(f"Failed to load cache: {e}") + self._cache = {"version": self.CACHE_VERSION, "servers": {}} + return self._cache + + def save(self, servers: Dict[str, List[Dict]]): + """ + Save tool metadata to disk (overwrites existing cache). + + Args: + servers: Dict mapping server_name -> list of tool metadata dicts + Each tool dict should have: name, description, parameters + """ + self._ensure_dir() + + cache_data = { + "version": self.CACHE_VERSION, + "updated_at": datetime.now().isoformat(), + "servers": servers, + } + + try: + with open(self.cache_path, "w", encoding="utf-8") as f: + json.dump(cache_data, f, indent=2, ensure_ascii=False) + self._cache = cache_data + logger.info(f"Saved MCP tool cache: {len(servers)} servers") + except Exception as e: + logger.error(f"Failed to save cache: {e}") + + def save_server(self, server_name: str, tools: List[Dict]): + """ + Save/update a single server's tools to cache (incremental append). + + Args: + server_name: Name of the MCP server + tools: List of tool metadata dicts for this server + """ + self._ensure_dir() + + # Load existing cache + cache = self.load() + + # Update server entry + if "servers" not in cache: + cache["servers"] = {} + cache["servers"][server_name] = tools + cache["servers"] = self._reorder_servers(cache["servers"]) + cache["updated_at"] = datetime.now().isoformat() + + # Save back + try: + with open(self.cache_path, "w", encoding="utf-8") as f: + json.dump(cache, f, indent=2, ensure_ascii=False) + self._cache = cache + logger.debug(f"Saved {len(tools)} tools for server '{server_name}'") + except Exception as e: + logger.error(f"Failed to save cache for server '{server_name}': {e}") + + def get_server_tools(self, server_name: str) -> Optional[List[Dict]]: + """Get cached tools for a specific server.""" + cache = self.load() + return cache.get("servers", {}).get(server_name) + + def get_all_tools(self) -> Dict[str, List[Dict]]: + """Get all cached tools, grouped by server.""" + cache = self.load() + return cache.get("servers", {}) + + def has_cache(self) -> bool: + """Check if cache exists and has data.""" + cache = self.load() + return bool(cache.get("servers")) + + def clear(self): + """Clear the cache.""" + if self.cache_path.exists(): + self.cache_path.unlink() + self._cache = None + logger.info("MCP tool cache cleared") + + def save_failed_server(self, server_name: str, error: str): + """ + Record a failed server to cache. + + Args: + server_name: Name of the failed MCP server + error: Error message + """ + self._ensure_dir() + + # Load existing cache + cache = self.load() + + # Add to failed_servers list + if "failed_servers" not in cache: + cache["failed_servers"] = {} + cache["failed_servers"][server_name] = { + "error": error, + "failed_at": datetime.now().isoformat(), + } + cache["updated_at"] = datetime.now().isoformat() + + # Save back + try: + with open(self.cache_path, "w", encoding="utf-8") as f: + json.dump(cache, f, indent=2, ensure_ascii=False) + self._cache = cache + except Exception as e: + logger.error(f"Failed to save failed server '{server_name}': {e}") + + def get_failed_servers(self) -> Dict[str, Dict]: + """Get list of failed servers from cache.""" + cache = self.load() + return cache.get("failed_servers", {}) + + def load_sanitized(self) -> Dict[str, Any]: + """Load sanitized cache from disk. Returns empty dict if not exists.""" + if self._sanitized_cache is not None: + return self._sanitized_cache + + if not self.sanitized_cache_path.exists(): + self._sanitized_cache = {"version": self.CACHE_VERSION, "servers": {}} + return self._sanitized_cache + + try: + with open(self.sanitized_cache_path, "r", encoding="utf-8") as f: + self._sanitized_cache = json.load(f) + logger.info(f"Loaded sanitized MCP tool cache: {len(self._sanitized_cache.get('servers', {}))} servers") + return self._sanitized_cache + except Exception as e: + logger.warning(f"Failed to load sanitized cache: {e}") + self._sanitized_cache = {"version": self.CACHE_VERSION, "servers": {}} + return self._sanitized_cache + + def save_sanitized(self, servers: Dict[str, List[Dict]]): + """ + Save sanitized tool metadata to disk. + + Args: + servers: Dict mapping server_name -> list of sanitized tool metadata dicts + """ + self._ensure_dir() + + cache_data = { + "version": self.CACHE_VERSION, + "updated_at": datetime.now().isoformat(), + "sanitized": True, + "servers": servers, + } + + try: + with open(self.sanitized_cache_path, "w", encoding="utf-8") as f: + json.dump(cache_data, f, indent=2, ensure_ascii=False) + self._sanitized_cache = cache_data + logger.info(f"Saved sanitized MCP tool cache: {len(servers)} servers") + except Exception as e: + logger.error(f"Failed to save sanitized cache: {e}") + + def get_all_sanitized_tools(self) -> Dict[str, List[Dict]]: + """Get all sanitized cached tools, grouped by server.""" + cache = self.load_sanitized() + return cache.get("servers", {}) + + def has_sanitized_cache(self) -> bool: + """Check if sanitized cache exists and has data.""" + cache = self.load_sanitized() + return bool(cache.get("servers")) + + def clear_sanitized(self): + """Clear the sanitized cache.""" + if self.sanitized_cache_path.exists(): + self.sanitized_cache_path.unlink() + self._sanitized_cache = None + logger.info("Sanitized MCP tool cache cleared") + + +# Global instance +_tool_cache: Optional[MCPToolCache] = None + + +def get_tool_cache() -> MCPToolCache: + """Get global tool cache instance.""" + global _tool_cache + if _tool_cache is None: + _tool_cache = MCPToolCache() + return _tool_cache + diff --git a/openspace/grounding/backends/mcp/tool_converter.py b/openspace/grounding/backends/mcp/tool_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..8bce08fd0fd50a5a0edc403cfc45b25f63085bb5 --- /dev/null +++ b/openspace/grounding/backends/mcp/tool_converter.py @@ -0,0 +1,194 @@ +""" +Tool converter for MCP. + +This module provides utilities to convert MCP tools to BaseTool instances. +""" + +import copy +from typing import Any, Dict +from mcp.types import Tool as MCPTool + +from openspace.grounding.core.tool import BaseTool, RemoteTool +from openspace.grounding.core.types import BackendType, ToolSchema +from openspace.grounding.core.transport.connectors import BaseConnector +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +def _sanitize_mcp_schema(params: Dict[str, Any]) -> Dict[str, Any]: + """ + Sanitize MCP tool schema to ensure Claude API compatibility (JSON Schema draft 2020-12). + + Fixes: + - Empty schemas -> valid object schema + - Missing required fields (type, properties, required) + - Removes non-standard fields (title, examples, nullable, default, etc.) + - Recursively cleans nested properties and items + - Ensures every property has a valid type + - Ensures top-level type is 'object' (Anthropic API requirement) + """ + if not params: + return {"type": "object", "properties": {}, "required": []} + + sanitized = copy.deepcopy(params) + sanitized = _deep_sanitize(sanitized) + + # Anthropic API requires top-level type to be 'object' + # If it's not an object, wrap the schema as a property of an object + top_level_type = sanitized.get("type") + if top_level_type and top_level_type != "object": + logger.debug(f"[MCP_SCHEMA_SANITIZE] Wrapping non-object schema (type={top_level_type}) into object") + wrapped = { + "type": "object", + "properties": { + "value": sanitized # The original schema becomes a property + }, + "required": ["value"] # Make it required + } + sanitized = wrapped + + return sanitized + + +def _deep_sanitize(schema: Dict[str, Any]) -> Dict[str, Any]: + """ + Recursively sanitize a JSON schema to conform to JSON Schema draft 2020-12. + Removes non-standard fields and ensures valid structure. + """ + if not isinstance(schema, dict): + return {"type": "string"} + + # Allowed top-level keys for Claude API compatibility + allowed_keys = { + "type", "properties", "required", "items", + "description", "enum", "const", + "minimum", "maximum", "minLength", "maxLength", + "minItems", "maxItems", "pattern", + "additionalProperties", "anyOf", "oneOf", "allOf" + } + + # Remove disallowed keys + keys_to_remove = [k for k in schema if k not in allowed_keys] + for k in keys_to_remove: + schema.pop(k, None) + + # Ensure type exists + if "type" not in schema: + # Type is defined via anyOf/oneOf/allOf - don't add default type + # These combination keywords define the type themselves + if "anyOf" in schema or "oneOf" in schema or "allOf" in schema: + pass # Type is defined through combination keywords, do not add default type + # Try to infer type + elif "properties" in schema: + schema["type"] = "object" + elif "items" in schema: + schema["type"] = "array" + elif "enum" in schema: + # For enum, try to infer from values + enum_vals = schema.get("enum", []) + if enum_vals and all(isinstance(v, str) for v in enum_vals): + schema["type"] = "string" + elif enum_vals and all(isinstance(v, (int, float)) for v in enum_vals): + schema["type"] = "number" + else: + schema["type"] = "string" + elif not schema: + # Empty schema (e.g., only had $schema which was removed) -> no parameters needed + schema["type"] = "object" + schema["properties"] = {} + schema["required"] = [] + else: + schema["type"] = "object" + + # Handle object type + if schema.get("type") == "object": + if "properties" not in schema: + schema["properties"] = {} + if "required" not in schema: + schema["required"] = [] + + # Recursively sanitize properties + if isinstance(schema.get("properties"), dict): + for prop_name, prop_schema in list(schema["properties"].items()): + if isinstance(prop_schema, dict): + schema["properties"][prop_name] = _deep_sanitize(prop_schema) + else: + # Invalid property schema, replace with string + schema["properties"][prop_name] = {"type": "string"} + + # Sanitize additionalProperties if present + if "additionalProperties" in schema and isinstance(schema["additionalProperties"], dict): + schema["additionalProperties"] = _deep_sanitize(schema["additionalProperties"]) + + # Handle array type + elif schema.get("type") == "array": + if "items" in schema: + if isinstance(schema["items"], dict): + schema["items"] = _deep_sanitize(schema["items"]) + elif isinstance(schema["items"], list): + # Tuple validation - sanitize each item + schema["items"] = [_deep_sanitize(item) if isinstance(item, dict) else {"type": "string"} for item in schema["items"]] + else: + schema["items"] = {"type": "string"} + else: + # Default items to string if not specified + schema["items"] = {"type": "string"} + + # Handle anyOf/oneOf/allOf + for combo_key in ["anyOf", "oneOf", "allOf"]: + if combo_key in schema and isinstance(schema[combo_key], list): + schema[combo_key] = [ + _deep_sanitize(sub) if isinstance(sub, dict) else {"type": "string"} + for sub in schema[combo_key] + ] + + return schema + + +def convert_mcp_tool_to_base_tool( + mcp_tool: MCPTool, + connector: BaseConnector +) -> BaseTool: + """ + Convert an MCP Tool to a BaseTool (RemoteTool) instance. + + This function extracts the tool schema from an MCP tool object and creates + a RemoteTool that can be used within the grounding framework. + + Args: + mcp_tool: MCP Tool object from the MCP SDK + connector: Connector instance for communicating with the MCP server + + Returns: + RemoteTool instance wrapping the MCP tool + """ + # Extract tool metadata + tool_name = mcp_tool.name + tool_description = getattr(mcp_tool, 'description', None) or "" + + # Convert MCP input schema to our parameter schema format (with sanitization) + input_schema: Dict[str, Any] = {} + if hasattr(mcp_tool, 'inputSchema') and mcp_tool.inputSchema: + input_schema = _sanitize_mcp_schema(mcp_tool.inputSchema) + else: + input_schema = {"type": "object", "properties": {}, "required": []} + + # Create ToolSchema + schema = ToolSchema( + name=tool_name, + description=tool_description, + parameters=input_schema, + backend_type=BackendType.MCP, + ) + + # Create and return RemoteTool + remote_tool = RemoteTool( + connector=connector, + remote_name=tool_name, + schema=schema, + backend=BackendType.MCP, + ) + + logger.debug(f"Converted MCP tool '{tool_name}' to RemoteTool") + return remote_tool \ No newline at end of file diff --git a/openspace/grounding/backends/mcp/transport/connectors/__init__.py b/openspace/grounding/backends/mcp/transport/connectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2e2fcc4fd000a64a8cbeacd1ac08716aa590804 --- /dev/null +++ b/openspace/grounding/backends/mcp/transport/connectors/__init__.py @@ -0,0 +1,20 @@ +""" +Connectors for various MCP transports. + +This module provides interfaces for connecting to MCP implementations +through different transport mechanisms. +""" + +from .base import MCPBaseConnector # noqa: F401 +from .http import HttpConnector # noqa: F401 +from .sandbox import SandboxConnector # noqa: F401 +from .stdio import StdioConnector # noqa: F401 +from .websocket import WebSocketConnector # noqa: F401 + +__all__ = [ + "MCPBaseConnector", + "StdioConnector", + "HttpConnector", + "WebSocketConnector", + "SandboxConnector", +] diff --git a/openspace/grounding/backends/mcp/transport/connectors/base.py b/openspace/grounding/backends/mcp/transport/connectors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c7fdff59701ba7bfaf612a8692ad904c302dff11 --- /dev/null +++ b/openspace/grounding/backends/mcp/transport/connectors/base.py @@ -0,0 +1,374 @@ +""" +Base connector for MCP implementations. + +This module provides the base connector interface that all MCP connectors must implement. +""" + +import asyncio +from abc import abstractmethod +from typing import Any + +from mcp import ClientSession +from mcp.shared.exceptions import McpError +from mcp.types import CallToolResult, GetPromptResult, Prompt, ReadResourceResult, Resource, Tool + +from openspace.grounding.core.transport.task_managers import BaseConnectionManager +from openspace.grounding.core.transport.connectors import BaseConnector +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +# Default retry settings for tool calls +DEFAULT_TOOL_CALL_MAX_RETRIES = 3 +DEFAULT_TOOL_CALL_RETRY_DELAY = 1.0 + + +class MCPBaseConnector(BaseConnector[ClientSession]): + """Base class for MCP connectors. + + This class defines the interface that all MCP connectors must implement. + """ + + def __init__( + self, + connection_manager: BaseConnectionManager[ClientSession], + tool_call_max_retries: int = DEFAULT_TOOL_CALL_MAX_RETRIES, + tool_call_retry_delay: float = DEFAULT_TOOL_CALL_RETRY_DELAY, + ): + """Initialize base connector with common attributes. + + Args: + connection_manager: The connection manager to use for the connection. + tool_call_max_retries: Maximum number of retries for tool calls (default: 3) + tool_call_retry_delay: Initial delay between retries in seconds (default: 1.0) + """ + super().__init__(connection_manager) + self.client_session: ClientSession | None = None + self._tools: list[Tool] | None = None + self._resources: list[Resource] | None = None + self._prompts: list[Prompt] | None = None + self.auto_reconnect = True # Whether to automatically reconnect on connection loss (not configurable for now) + self.tool_call_max_retries = tool_call_max_retries + self.tool_call_retry_delay = tool_call_retry_delay + + @property + @abstractmethod + def public_identifier(self) -> str: + """Get the identifier for the connector.""" + pass + + async def _get_streams_from_connection(self): + """Get read and write streams from the connection. Override in subclasses if needed.""" + # Default implementation for most MCP connectors (stdio, HTTP) + # Returns the connection directly as it should be a tuple of (read_stream, write_stream) + return self._connection + + async def _after_connect(self) -> None: + """Create ClientSession after connection is established. + + Some connectors (like WebSocket) don't use ClientSession and may override this method. + """ + # Get streams from the connection + streams = await self._get_streams_from_connection() + + if streams is None: + # Some connectors (like WebSocket) don't use ClientSession + # They should override this method to set up their own resources + logger.debug("No streams returned, ClientSession creation skipped") + return + + if isinstance(streams, tuple) and len(streams) == 2: + read_stream, write_stream = streams + # Create the client session + self.client_session = ClientSession(read_stream, write_stream, sampling_callback=None) + await self.client_session.__aenter__() + logger.debug("MCP ClientSession created successfully") + else: + raise RuntimeError(f"Invalid streams format: expected tuple of 2 elements, got {type(streams)}") + + async def _before_disconnect(self) -> None: + """Clean up MCP-specific resources before disconnection.""" + errors = [] + + # Close the client session + if self.client_session: + try: + logger.debug("Closing MCP client session") + await self.client_session.__aexit__(None, None, None) + except Exception as e: + error_msg = f"Error closing client session: {e}" + logger.warning(error_msg) + errors.append(error_msg) + finally: + self.client_session = None + + # Reset tools, resources, and prompts + self._tools = None + self._resources = None + self._prompts = None + + if errors: + logger.warning(f"Encountered {len(errors)} errors during MCP resource cleanup") + + async def _cleanup_on_connect_failure(self) -> None: + """Override to add MCP-specific cleanup on connection failure.""" + # Clean up client session if it was created + if self.client_session: + try: + await self.client_session.__aexit__(None, None, None) + except Exception: + pass + finally: + self.client_session = None + + # Call parent cleanup + await super()._cleanup_on_connect_failure() + + async def initialize(self) -> dict[str, Any]: + """Initialize the MCP session and return session information.""" + if not self.client_session: + raise RuntimeError("MCP client is not connected") + + logger.debug("Initializing MCP session") + + # Initialize the session + result = await self.client_session.initialize() + + server_capabilities = result.capabilities + + if server_capabilities.tools: + # Get available tools + tools_result = await self.list_tools() + self._tools = tools_result or [] + else: + self._tools = [] + + if server_capabilities.resources: + # Get available resources + resources_result = await self.list_resources() + self._resources = resources_result or [] + else: + self._resources = [] + + if server_capabilities.prompts: + # Get available prompts + prompts_result = await self.list_prompts() + self._prompts = prompts_result or [] + else: + self._prompts = [] + + logger.debug( + f"MCP session initialized with {len(self._tools)} tools, " + f"{len(self._resources)} resources, " + f"and {len(self._prompts)} prompts" + ) + + return result + + @property + def tools(self) -> list[Tool]: + """Get the list of available tools.""" + if self._tools is None: + raise RuntimeError("MCP client is not initialized") + return self._tools + + @property + def resources(self) -> list[Resource]: + """Get the list of available resources.""" + if self._resources is None: + raise RuntimeError("MCP client is not initialized") + return self._resources + + @property + def prompts(self) -> list[Prompt]: + """Get the list of available prompts.""" + if self._prompts is None: + raise RuntimeError("MCP client is not initialized") + return self._prompts + + @property + def is_connected(self) -> bool: + """Check if the connector is actually connected and the connection is alive. + + This property checks not only the connected flag but also verifies that + the client session exists and the underlying connection is still active. + + Returns: + True if the connector is connected and the connection is alive, False otherwise. + """ + # First check the basic connected flag + if not self._connected: + return False + + # Check if we have a client session + if not self.client_session: + self._connected = False + return False + + # Check if connection manager task is still running (if applicable) + if self._connection_manager and hasattr(self._connection_manager, "_task"): + task = self._connection_manager._task + if task and task.done(): + logger.debug("Connection manager task is done, marking as disconnected") + self._connected = False + return False + + return True + + async def _ensure_connected(self) -> None: + """Ensure the connector is connected, reconnecting if necessary. + + Raises: + RuntimeError: If connection cannot be established and auto_reconnect is False. + """ + if not self.client_session: + raise RuntimeError("MCP client is not connected") + + if not self.is_connected: + if self.auto_reconnect: + logger.debug("Connection lost, attempting to reconnect...") + try: + await self.connect() + logger.debug("Reconnection successful") + except Exception as e: + raise RuntimeError(f"Failed to reconnect to MCP server: {e}") from e + else: + raise RuntimeError( + "Connection to MCP server has been lost. Auto-reconnection is disabled. Please reconnect manually." + ) + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> CallToolResult: + """Call an MCP tool with automatic reconnection handling and retry logic. + + Args: + name: The name of the tool to call. + arguments: The arguments to pass to the tool. + + Returns: + The result of the tool call. + + Raises: + RuntimeError: If the connection is lost and cannot be reestablished. + Exception: If the tool call fails after all retries. + """ + last_error: Exception | None = None + + for attempt in range(self.tool_call_max_retries): + # Ensure we're connected + await self._ensure_connected() + + logger.debug(f"Calling tool '{name}' with arguments: {arguments} (attempt {attempt + 1}/{self.tool_call_max_retries})") + try: + result = await self.client_session.call_tool(name, arguments) + logger.debug(f"Tool '{name}' called successfully") + return result + except Exception as e: + last_error = e + error_str = str(e).lower() + + # Check if the error might be due to connection loss + if not self.is_connected: + logger.warning(f"Tool call '{name}' failed due to connection loss: {e}") + # Try to reconnect on next iteration + continue + + # Check for retryable HTTP errors (400, 500, 502, 503, 504) + is_retryable = any(code in error_str for code in ['400', '500', '502', '503', '504', 'bad request', 'internal server error', 'service unavailable', 'gateway timeout']) + + if is_retryable and attempt < self.tool_call_max_retries - 1: + delay = self.tool_call_retry_delay * (2 ** attempt) # Exponential backoff + logger.warning( + f"Tool call '{name}' failed with retryable error: {e}, " + f"retrying in {delay:.1f}s (attempt {attempt + 1}/{self.tool_call_max_retries})" + ) + await asyncio.sleep(delay) + continue + + # Non-retryable error or max retries reached, re-raise + raise + + # All retries exhausted + error_msg = f"Tool call '{name}' failed after {self.tool_call_max_retries} retries" + logger.error(error_msg) + raise RuntimeError(error_msg) from last_error + + async def list_tools(self) -> list[Tool]: + """List all available tools from the MCP implementation.""" + + # Ensure we're connected + await self._ensure_connected() + + logger.debug("Listing tools") + try: + result = await self.client_session.list_tools() + return result.tools + except McpError as e: + logger.error(f"Error listing tools: {e}") + return [] + + async def list_resources(self) -> list[Resource]: + """List all available resources from the MCP implementation.""" + # Ensure we're connected + await self._ensure_connected() + + logger.debug("Listing resources") + try: + result = await self.client_session.list_resources() + return result.resources + except McpError as e: + logger.error(f"Error listing resources: {e}") + return [] + + async def read_resource(self, uri: str) -> ReadResourceResult: + """Read a resource by URI.""" + if not self.client_session: + raise RuntimeError("MCP client is not connected") + + logger.debug(f"Reading resource: {uri}") + result = await self.client_session.read_resource(uri) + return result + + async def list_prompts(self) -> list[Prompt]: + """List all available prompts from the MCP implementation.""" + # Ensure we're connected + await self._ensure_connected() + + logger.debug("Listing prompts") + try: + result = await self.client_session.list_prompts() + return result.prompts + except McpError as e: + logger.error(f"Error listing prompts: {e}") + return [] + + async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult: + """Get a prompt by name.""" + # Ensure we're connected + await self._ensure_connected() + + logger.debug(f"Getting prompt: {name}") + result = await self.client_session.get_prompt(name, arguments) + return result + + async def request(self, method: str, params: dict[str, Any] | None = None) -> Any: + """Send a raw request to the MCP implementation.""" + # Ensure we're connected + await self._ensure_connected() + + logger.debug(f"Sending request: {method} with params: {params}") + return await self.client_session.request({"method": method, "params": params or {}}) + + async def invoke(self, name: str, params: dict[str, Any]) -> Any: + await self._ensure_connected() + + if not name.startswith("__"): + return await self.call_tool(name, params) + + if name == "__read_resource__": + return await self.read_resource(params["uri"]) + if name == "__list_prompts__": + return await self.list_prompts() + if name == "__get_prompt__": + return await self.get_prompt(params["name"], params.get("args")) + + raise ValueError(f"Unsupported MCP invoke name: {name}") \ No newline at end of file diff --git a/openspace/grounding/backends/mcp/transport/connectors/http.py b/openspace/grounding/backends/mcp/transport/connectors/http.py new file mode 100644 index 0000000000000000000000000000000000000000..099f112f21405daf82444dd48192aa2c4f198d2e --- /dev/null +++ b/openspace/grounding/backends/mcp/transport/connectors/http.py @@ -0,0 +1,705 @@ +""" +HTTP connector for MCP implementations. + +This module provides a connector for communicating with MCP implementations +through HTTP APIs with SSE, Streamable HTTP, or simple JSON-RPC for transport. +""" + +import asyncio +import anyio +import httpx +from typing import Any, Dict, List +from mcp import ClientSession +from mcp.types import ( + CallToolResult, + TextContent, + ImageContent, + EmbeddedResource, + Tool, + Resource, + Prompt, + GetPromptResult, + ReadResourceResult, +) + +from openspace.utils.logging import Logger +from openspace.grounding.core.transport.task_managers.base import BaseConnectionManager +from openspace.grounding.backends.mcp.transport.task_managers import SseConnectionManager, StreamableHttpConnectionManager +from openspace.grounding.backends.mcp.transport.connectors.base import MCPBaseConnector, DEFAULT_TOOL_CALL_MAX_RETRIES, DEFAULT_TOOL_CALL_RETRY_DELAY + +logger = Logger.get_logger(__name__) + + +class HttpConnector(MCPBaseConnector): + """Connector for MCP implementations using HTTP transport. + + This connector uses HTTP/SSE or streamable HTTP to communicate with remote MCP implementations, + using a connection manager to handle the proper lifecycle management. + """ + + def __init__( + self, + base_url: str, + auth_token: str | None = None, + headers: dict[str, str] | None = None, + timeout: float = 5, + sse_read_timeout: float = 60 * 5, + tool_call_max_retries: int = DEFAULT_TOOL_CALL_MAX_RETRIES, + tool_call_retry_delay: float = DEFAULT_TOOL_CALL_RETRY_DELAY, + ): + """Initialize a new HTTP connector. + + Args: + base_url: The base URL of the MCP HTTP API. + auth_token: Optional authentication token. + headers: Optional additional headers. + timeout: Timeout for HTTP operations in seconds. + sse_read_timeout: Timeout for SSE read operations in seconds. + tool_call_max_retries: Maximum number of retries for tool calls (default: 3) + tool_call_retry_delay: Initial delay between retries in seconds (default: 1.0) + """ + self.base_url = base_url.rstrip("/") + self.auth_token = auth_token + self.headers = headers or {} + if auth_token: + self.headers["Authorization"] = f"Bearer {auth_token}" + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout + + # JSON-RPC HTTP mode fields + self._use_jsonrpc = False + self._jsonrpc_client: httpx.AsyncClient | None = None + self._jsonrpc_request_id = 0 + + # Create a placeholder connection manager (will be set up later in connect()) + # We use a placeholder here because the actual transport type (SSE vs Streamable HTTP) + # can only be determined at runtime through server negotiation as per MCP specification + from openspace.grounding.core.transport.task_managers import PlaceholderConnectionManager + connection_manager = PlaceholderConnectionManager() + super().__init__( + connection_manager, + tool_call_max_retries=tool_call_max_retries, + tool_call_retry_delay=tool_call_retry_delay, + ) + + async def connect(self) -> None: + """Create the underlying session/connection. + + For JSON-RPC mode, we don't use a connection manager. + """ + if self._connected: + return + + try: + # Hook: before connection - this sets up transport type + await self._before_connect() + + if self._use_jsonrpc: + # JSON-RPC mode doesn't use connection manager + # Just call _after_connect to set up the HTTP client + await self._after_connect() + self._connected = True + else: + # Use normal connection flow with connection manager + # If _before_connect() already established a connection, reuse it + if self._connection is None: + self._connection = await self._connection_manager.start() + await self._after_connect() + self._connected = True + except Exception: + await self._cleanup_on_connect_failure() + raise + + async def disconnect(self) -> None: + """Close the session/connection and reset state.""" + if not self._connected: + return + + # Hook: before disconnection + await self._before_disconnect() + + if not self._use_jsonrpc: + # Stop the connection manager only for non-JSON-RPC modes + if self._connection_manager: + await self._connection_manager.stop() + self._connection = None + + # Hook: after disconnection + await self._after_disconnect() + + self._connected = False + + async def _before_connect(self) -> None: + """Negotiate transport type and set up the appropriate connection manager. + + Tries transports in order: + 1. Streamable HTTP (new MCP transport) + 2. SSE (legacy MCP transport) + 3. Simple JSON-RPC HTTP (for custom servers) + + This implements backwards compatibility per MCP specification. + """ + self.transport_type = None + self._use_jsonrpc = False + connection_manager = None + streamable_error = None + sse_error = None + + # First, try the new streamable HTTP transport + try: + logger.debug(f"Attempting streamable HTTP connection to: {self.base_url}") + connection_manager = StreamableHttpConnectionManager( + self.base_url, self.headers, self.timeout, self.sse_read_timeout + ) + + # Test the connection by starting it with built-in timeout + read_stream, write_stream = await connection_manager.start(timeout=self.timeout) + + # Create and verify ClientSession + test_client = ClientSession(read_stream, write_stream, sampling_callback=None) + + # Add timeout to __aenter__ - use asyncio.wait_for instead of anyio.fail_after + # to avoid cancel scope conflicts with background tasks + try: + await asyncio.wait_for(test_client.__aenter__(), timeout=self.timeout) + except asyncio.TimeoutError: + raise TimeoutError(f"ClientSession enter timed out after {self.timeout}s") + + try: + # Add timeout to initialize() using asyncio.wait_for to prevent hanging + try: + await asyncio.wait_for(test_client.initialize(), timeout=self.timeout) + except asyncio.TimeoutError: + raise TimeoutError(f"initialize() timed out after {self.timeout}s") + + try: + await asyncio.wait_for(test_client.list_tools(), timeout=self.timeout) + except asyncio.TimeoutError: + raise TimeoutError(f"list_tools() timed out after {self.timeout}s") + + # SUCCESS! Keep the client session (don't close it, closing destroys the streams) + # Store it directly as the client_session for later use + self.transport_type = "streamable HTTP" + self._connection_manager = connection_manager + self._connection = connection_manager.get_streams() + self.client_session = test_client # Reuse the working session + logger.debug("Streamable HTTP transport selected") + return + except TimeoutError: + try: + await asyncio.wait_for(test_client.__aexit__(None, None, None), timeout=2) + except (asyncio.TimeoutError, Exception): + pass + raise + except Exception as init_error: + # Clean up the test client only on error + try: + await asyncio.wait_for(test_client.__aexit__(None, None, None), timeout=2) + except (asyncio.TimeoutError, Exception): + pass + raise init_error + + except Exception as e: + streamable_error = e + logger.debug(f"Streamable HTTP failed: {e}") + + # Clean up the failed connection manager + if connection_manager: + try: + await asyncio.wait_for(connection_manager.stop(), timeout=2) + except (asyncio.TimeoutError, Exception): + pass + + # Try SSE fallback + try: + logger.debug(f"Attempting SSE fallback connection to: {self.base_url}") + connection_manager = SseConnectionManager( + self.base_url, self.headers, self.timeout, self.sse_read_timeout + ) + + # Test the connection by starting it with built-in timeout + read_stream, write_stream = await connection_manager.start(timeout=self.timeout) + + # Create and verify ClientSession + test_client = ClientSession(read_stream, write_stream, sampling_callback=None) + + # Add timeout to __aenter__ - use asyncio.wait_for instead of anyio.fail_after + # to avoid cancel scope conflicts with background tasks + try: + await asyncio.wait_for(test_client.__aenter__(), timeout=self.timeout) + except asyncio.TimeoutError: + raise TimeoutError(f"ClientSession enter timed out after {self.timeout}s") + + try: + try: + await asyncio.wait_for(test_client.initialize(), timeout=self.timeout) + except asyncio.TimeoutError: + raise TimeoutError(f"initialize() timed out after {self.timeout}s") + + try: + await asyncio.wait_for(test_client.list_tools(), timeout=self.timeout) + except asyncio.TimeoutError: + raise TimeoutError(f"list_tools() timed out after {self.timeout}s") + + # SUCCESS! Keep the client session (don't close it, closing destroys the streams) + # Store it directly as the client_session for later use + self.transport_type = "SSE" + self._connection_manager = connection_manager + self._connection = connection_manager.get_streams() + self.client_session = test_client # Reuse the working session + logger.debug("SSE transport selected") + return + except TimeoutError: + try: + await asyncio.wait_for(test_client.__aexit__(None, None, None), timeout=2) + except (asyncio.TimeoutError, Exception): + pass + raise + except Exception as init_error: + # Clean up the test client only on error + try: + await asyncio.wait_for(test_client.__aexit__(None, None, None), timeout=2) + except (asyncio.TimeoutError, Exception): + pass + raise init_error + + except Exception as e: + sse_error = e + logger.debug(f"SSE failed: {e}") + + # Clean up the failed connection manager + if connection_manager: + try: + await asyncio.wait_for(connection_manager.stop(), timeout=2) + except (asyncio.TimeoutError, Exception): + pass + + # Both MCP transports failed, try simple JSON-RPC HTTP as last resort + # This is useful for custom MCP servers that don't implement proper MCP transports + logger.debug(f"Attempting JSON-RPC HTTP fallback to: {self.base_url}") + try: + # Test JSON-RPC connection + await self._try_jsonrpc_connection() + + self.transport_type = "JSON-RPC HTTP" + self._use_jsonrpc = True + logger.info(f"JSON-RPC HTTP transport selected for: {self.base_url}") + return + + except Exception as jsonrpc_error: + # All transports failed + logger.error( + f"All transport methods failed for {self.base_url}. " + f"Streamable HTTP: {streamable_error}, SSE: {sse_error}, JSON-RPC: {jsonrpc_error}" + ) + # Raise the most relevant error - prefer the original streamable error + raise streamable_error or sse_error or jsonrpc_error + + async def _try_jsonrpc_connection(self) -> None: + """Test JSON-RPC HTTP connection by sending an initialize request.""" + headers = {**self.headers, "Content-Type": "application/json"} + + async with httpx.AsyncClient(timeout=httpx.Timeout(self.timeout), headers=headers) as client: + payload = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "OpenSpace", "version": "1.0.0"}, + } + } + + response = await client.post(self.base_url, json=payload) + response.raise_for_status() + + data = response.json() + + # Check for JSON-RPC error + if "error" in data: + error = data["error"] + raise RuntimeError(f"JSON-RPC error: {error.get('message', str(error))}") + + # Success - server supports JSON-RPC + logger.debug(f"JSON-RPC test succeeded: {data.get('result', {})}") + + async def _after_connect(self) -> None: + """Create ClientSession (or set up JSON-RPC client) and log success.""" + if self._use_jsonrpc: + # Set up JSON-RPC HTTP client + headers = {**self.headers, "Content-Type": "application/json"} + self._jsonrpc_client = httpx.AsyncClient( + timeout=httpx.Timeout(self.timeout), + headers=headers, + ) + logger.debug(f"JSON-RPC HTTP client set up for: {self.base_url}") + else: + # Skip creating ClientSession if _before_connect() already created one + if self.client_session is None: + await super()._after_connect() + else: + logger.debug("Reusing ClientSession from _before_connect()") + + logger.debug(f"Successfully connected to MCP implementation via {self.transport_type}: {self.base_url}") + + async def _before_disconnect(self) -> None: + """Clean up resources before disconnection.""" + # Clean up JSON-RPC client if used + if self._jsonrpc_client: + try: + await self._jsonrpc_client.aclose() + except Exception as e: + logger.warning(f"Error closing JSON-RPC client: {e}") + finally: + self._jsonrpc_client = None + + # Call parent cleanup for MCP resources + await super()._before_disconnect() + + @property + def public_identifier(self) -> str: + """Get the identifier for the connector.""" + return {"type": self.transport_type, "base_url": self.base_url} + + # ===================== + # JSON-RPC HTTP Methods + # ===================== + + def _next_jsonrpc_id(self) -> int: + """Get next JSON-RPC request ID.""" + self._jsonrpc_request_id += 1 + return self._jsonrpc_request_id + + async def _send_jsonrpc_request( + self, + method: str, + params: Dict[str, Any] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + ) -> Any: + """Send a JSON-RPC request and return the result. + + Args: + method: The JSON-RPC method name (e.g., "tools/list", "tools/call") + params: The method parameters + max_retries: Maximum number of retries for transient errors (400, 503, etc.) + retry_delay: Initial delay between retries (doubles each retry) + + Returns: + The result field from the JSON-RPC response + """ + if not self._jsonrpc_client: + raise RuntimeError("JSON-RPC client not initialized") + + last_error = None + + for attempt in range(max_retries): + request_id = self._next_jsonrpc_id() + payload = { + "jsonrpc": "2.0", + "id": request_id, + "method": method, + "params": params or {}, + } + + logger.debug(f"Sending JSON-RPC request: {method} (id={request_id}, attempt {attempt + 1}/{max_retries})") + + try: + response = await self._jsonrpc_client.post(self.base_url, json=payload) + response.raise_for_status() + + data = response.json() + + if "error" in data: + error = data["error"] + error_msg = error.get("message", str(error)) + raise RuntimeError(f"JSON-RPC error: {error_msg}") + + return data.get("result", {}) + + except httpx.HTTPStatusError as e: + last_error = e + status_code = e.response.status_code + + # Retry on 400 (Bad Request) and 5xx errors + # 400 can happen when MCP server is temporarily not ready + if status_code in (400, 500, 502, 503, 504) and attempt < max_retries - 1: + delay = retry_delay * (2 ** attempt) + logger.warning( + f"HTTP {status_code} error on {method}, retrying in {delay:.1f}s " + f"(attempt {attempt + 1}/{max_retries})" + ) + await asyncio.sleep(delay) + continue + + raise RuntimeError(f"HTTP error: {status_code}") from e + + except httpx.RequestError as e: + last_error = e + # Retry on connection errors + if attempt < max_retries - 1: + delay = retry_delay * (2 ** attempt) + logger.warning( + f"Request error on {method}: {e}, retrying in {delay:.1f}s " + f"(attempt {attempt + 1}/{max_retries})" + ) + await asyncio.sleep(delay) + continue + + raise RuntimeError(f"Request error: {e}") from e + + # Should not reach here, but just in case + raise RuntimeError(f"Max retries exceeded for {method}") from last_error + + def _parse_tools_from_json(self, tools_data: List[Dict]) -> List[Tool]: + """Parse tool data into Tool objects.""" + tools = [] + for tool_dict in tools_data: + try: + tool = Tool( + name=tool_dict.get("name", ""), + description=tool_dict.get("description", ""), + inputSchema=tool_dict.get("inputSchema", {}), + ) + tools.append(tool) + except Exception as e: + logger.warning(f"Failed to parse tool: {e}") + return tools + + def _parse_resources_from_json(self, resources_data: List[Dict]) -> List[Resource]: + """Parse resource data into Resource objects.""" + resources = [] + for res_dict in resources_data: + try: + resource = Resource( + uri=res_dict.get("uri", ""), + name=res_dict.get("name", ""), + description=res_dict.get("description"), + mimeType=res_dict.get("mimeType"), + ) + resources.append(resource) + except Exception as e: + logger.warning(f"Failed to parse resource: {e}") + return resources + + def _parse_prompts_from_json(self, prompts_data: List[Dict]) -> List[Prompt]: + """Parse prompt data into Prompt objects.""" + prompts = [] + for prompt_dict in prompts_data: + try: + prompt = Prompt( + name=prompt_dict.get("name", ""), + description=prompt_dict.get("description"), + arguments=prompt_dict.get("arguments"), + ) + prompts.append(prompt) + except Exception as e: + logger.warning(f"Failed to parse prompt: {e}") + return prompts + + # ===================== + # Override MCP Methods for JSON-RPC Support + # ===================== + + async def initialize(self) -> Dict[str, Any]: + """Initialize the MCP session.""" + if not self._use_jsonrpc: + return await super().initialize() + + # JSON-RPC mode + logger.debug("Initializing JSON-RPC HTTP MCP session") + + result = await self._send_jsonrpc_request("initialize", { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "OpenSpace", "version": "1.0.0"}, + }) + + capabilities = result.get("capabilities", {}) + + # List tools + if capabilities.get("tools"): + try: + tools_result = await self._send_jsonrpc_request("tools/list", {}) + self._tools = self._parse_tools_from_json(tools_result.get("tools", [])) + except Exception: + self._tools = [] + else: + # Try anyway - some servers don't advertise capabilities correctly + try: + tools_result = await self._send_jsonrpc_request("tools/list", {}) + self._tools = self._parse_tools_from_json(tools_result.get("tools", [])) + except Exception: + self._tools = [] + + # List resources + if capabilities.get("resources"): + try: + resources_result = await self._send_jsonrpc_request("resources/list", {}) + self._resources = self._parse_resources_from_json(resources_result.get("resources", [])) + except Exception: + self._resources = [] + else: + self._resources = [] + + # List prompts + if capabilities.get("prompts"): + try: + prompts_result = await self._send_jsonrpc_request("prompts/list", {}) + self._prompts = self._parse_prompts_from_json(prompts_result.get("prompts", [])) + except Exception: + self._prompts = [] + else: + self._prompts = [] + + logger.info( + f"JSON-RPC HTTP MCP session initialized with {len(self._tools)} tools, " + f"{len(self._resources)} resources, {len(self._prompts)} prompts" + ) + + return result + + @property + def is_connected(self) -> bool: + """Check if the connector is connected.""" + if self._use_jsonrpc: + return self._connected and self._jsonrpc_client is not None + return super().is_connected + + async def _ensure_connected(self) -> None: + """Ensure the connector is connected.""" + if self._use_jsonrpc: + if not self._connected or not self._jsonrpc_client: + raise RuntimeError("JSON-RPC HTTP connector is not connected") + else: + await super()._ensure_connected() + + async def list_tools(self) -> List[Tool]: + """List all available tools.""" + if not self._use_jsonrpc: + return await super().list_tools() + + await self._ensure_connected() + try: + tools_result = await self._send_jsonrpc_request("tools/list", {}) + self._tools = self._parse_tools_from_json(tools_result.get("tools", [])) + return self._tools + except Exception as e: + logger.error(f"Error listing tools: {e}") + return [] + + async def call_tool(self, name: str, arguments: Dict[str, Any]) -> CallToolResult: + """Call an MCP tool.""" + if not self._use_jsonrpc: + return await super().call_tool(name, arguments) + + await self._ensure_connected() + logger.debug(f"Calling tool '{name}' with arguments: {arguments}") + + result = await self._send_jsonrpc_request("tools/call", { + "name": name, + "arguments": arguments, + }) + + # Parse the result into CallToolResult + content = [] + for item in result.get("content", []): + item_type = item.get("type", "text") + if item_type == "text": + content.append(TextContent(type="text", text=item.get("text", ""))) + elif item_type == "image": + content.append(ImageContent( + type="image", + data=item.get("data", ""), + mimeType=item.get("mimeType", "image/png"), + )) + elif item_type == "resource": + content.append(EmbeddedResource( + type="resource", + resource=item.get("resource", {}), + )) + + if not content and result: + content.append(TextContent(type="text", text=str(result))) + + return CallToolResult( + content=content, + isError=result.get("isError", False), + ) + + async def list_resources(self) -> List[Resource]: + """List all available resources.""" + if not self._use_jsonrpc: + return await super().list_resources() + + await self._ensure_connected() + try: + resources_result = await self._send_jsonrpc_request("resources/list", {}) + self._resources = self._parse_resources_from_json(resources_result.get("resources", [])) + return self._resources + except Exception as e: + logger.error(f"Error listing resources: {e}") + return [] + + async def read_resource(self, uri: str) -> ReadResourceResult: + """Read a resource by URI.""" + if not self._use_jsonrpc: + return await super().read_resource(uri) + + await self._ensure_connected() + result = await self._send_jsonrpc_request("resources/read", {"uri": uri}) + return ReadResourceResult(**result) + + async def list_prompts(self) -> List[Prompt]: + """List all available prompts.""" + if not self._use_jsonrpc: + return await super().list_prompts() + + await self._ensure_connected() + try: + prompts_result = await self._send_jsonrpc_request("prompts/list", {}) + self._prompts = self._parse_prompts_from_json(prompts_result.get("prompts", [])) + return self._prompts + except Exception as e: + logger.error(f"Error listing prompts: {e}") + return [] + + async def get_prompt(self, name: str, arguments: Dict[str, Any] | None = None) -> GetPromptResult: + """Get a prompt by name.""" + if not self._use_jsonrpc: + return await super().get_prompt(name, arguments) + + await self._ensure_connected() + result = await self._send_jsonrpc_request("prompts/get", { + "name": name, + "arguments": arguments or {}, + }) + return GetPromptResult(**result) + + async def request(self, method: str, params: Dict[str, Any] | None = None) -> Any: + """Send a raw request to the MCP implementation.""" + if not self._use_jsonrpc: + return await super().request(method, params) + + await self._ensure_connected() + return await self._send_jsonrpc_request(method, params or {}) + + async def invoke(self, name: str, params: Dict[str, Any]) -> Any: + """Invoke a tool or special method.""" + if not self._use_jsonrpc: + return await super().invoke(name, params) + + await self._ensure_connected() + + if not name.startswith("__"): + return await self.call_tool(name, params) + + if name == "__read_resource__": + return await self.read_resource(params["uri"]) + if name == "__list_prompts__": + return await self.list_prompts() + if name == "__get_prompt__": + return await self.get_prompt(params["name"], params.get("args")) + + raise ValueError(f"Unsupported MCP invoke name: {name}") diff --git a/openspace/grounding/backends/mcp/transport/connectors/sandbox.py b/openspace/grounding/backends/mcp/transport/connectors/sandbox.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b6f4c532485caa714e07f66b30ec324cb773ac --- /dev/null +++ b/openspace/grounding/backends/mcp/transport/connectors/sandbox.py @@ -0,0 +1,251 @@ +""" +Sandbox connector for MCP implementations. + +This module provides a connector for communicating with MCP implementations +that are executed inside a sandbox environment (supports any BaseSandbox implementation). +""" + +import asyncio +import sys +import time + +import aiohttp +from mcp import ClientSession + +from openspace.utils.logging import Logger +from openspace.grounding.backends.mcp.transport.task_managers import SseConnectionManager +from openspace.grounding.core.security import BaseSandbox +from openspace.grounding.backends.mcp.transport.connectors.base import MCPBaseConnector + +logger = Logger.get_logger(__name__) + + +class SandboxConnector(MCPBaseConnector): + """Connector for MCP implementations running in a sandbox environment. + + This connector runs a user-defined stdio command within a sandbox environment + through a BaseSandbox implementation (e.g., E2BSandbox), potentially wrapped + by a utility like 'supergateway' to expose its stdio. + """ + + def __init__( + self, + sandbox: BaseSandbox, + command: str, + args: list[str], + env: dict[str, str] | None = None, + supergateway_command: str = "npx -y supergateway", + port: int = 3000, + timeout: float = 5, + sse_read_timeout: float = 60 * 5, + ): + """Initialize a new sandbox connector. + + Args: + sandbox: A BaseSandbox implementation (e.g., E2BSandbox) to run commands in. + command: The user's MCP server command to execute in the sandbox. + args: Command line arguments for the user's MCP server command. + env: Environment variables for the user's MCP server command. + supergateway_command: Command to run supergateway (default: "npx -y supergateway"). + port: Port number for the sandbox server (default: 3000). + timeout: Timeout for the sandbox process in seconds. + sse_read_timeout: Timeout for the SSE connection in seconds. + """ + # Store user command configuration + self.user_command = command + self.user_args = args or [] + self.user_env = env or {} + self.port = port + + # Create a placeholder connection manager (will be set up in connect()) + # We need the sandbox to start first to get the base_url, so we can't create + # the real SseConnectionManager until connect() is called + from openspace.grounding.core.transport.task_managers import PlaceholderConnectionManager + connection_manager = PlaceholderConnectionManager() + super().__init__(connection_manager) + + # Sandbox configuration + self._sandbox = sandbox + self.supergateway_cmd_parts = supergateway_command + + # Runtime state + self.process = None + self.client_session: ClientSession | None = None + self.errlog = sys.stderr + self.base_url: str | None = None + self._connected = False + self._connection_manager: SseConnectionManager | None = None + + # SSE connection parameters + self.headers = {} + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout + + self.stdout_lines: list[str] = [] + self.stderr_lines: list[str] = [] + self._server_ready = asyncio.Event() + + def _handle_stdout(self, data: str) -> None: + """Handle stdout data from the sandbox process.""" + self.stdout_lines.append(data) + logger.debug(f"[SANDBOX STDOUT] {data}", end="", flush=True) + + def _handle_stderr(self, data: str) -> None: + """Handle stderr data from the sandbox process.""" + self.stderr_lines.append(data) + logger.debug(f"[SANDBOX STDERR] {data}", file=self.errlog, end="", flush=True) + + async def wait_for_server_response(self, base_url: str, timeout: int = 30) -> bool: + """Wait for the server to respond to HTTP requests. + + Args: + base_url: The base URL to check for server readiness + timeout: Maximum time to wait in seconds + + Returns: + True if server is responding, raises TimeoutError otherwise + """ + logger.info(f"Waiting for server at {base_url} to respond...") + sys.stdout.flush() + + start_time = time.time() + ping_url = f"{base_url}/sse" + + # Try to connect to the server + while time.time() - start_time < timeout: + try: + async with aiohttp.ClientSession() as session: + try: + # First try the endpoint + async with session.get(ping_url, timeout=2) as response: + if response.status == 200: + elapsed = time.time() - start_time + logger.info(f"Server is ready! SSE endpoint responded with 200 after {elapsed:.1f}s") + return True + except Exception: + # If sse endpoint doesn't work, try the base URL + async with session.get(base_url, timeout=2) as response: + if response.status < 500: # Accept any non-server error + elapsed = time.time() - start_time + logger.info( + f"Server is ready! Base URL responded with {response.status} after {elapsed:.1f}s" + ) + return True + except Exception: + # Wait a bit before trying again + await asyncio.sleep(0.5) + continue + + # If we get here, the request failed + await asyncio.sleep(0.5) + + # Log status every 5 seconds + elapsed = time.time() - start_time + if int(elapsed) % 5 == 0: + logger.info(f"Still waiting for server to respond... ({elapsed:.1f}s elapsed)") + sys.stdout.flush() + + # If we get here, we timed out + raise TimeoutError(f"Timeout waiting for server to respond (waited {timeout} seconds)") + + async def _before_connect(self) -> None: + """Set up the sandbox and prepare the connection manager.""" + logger.debug("Connecting to MCP implementation in sandbox") + + # Start the sandbox if not already active + if not self._sandbox.is_active: + logger.debug("Starting sandbox...") + await self._sandbox.start() + + # Get the host for the sandbox + # Note: This assumes the sandbox implementation has a get_host method + # For E2BSandbox, this is available + host = self._sandbox.get_host(self.port) + self.base_url = f"https://{host}".rstrip("/") + + # Append command with args + command = f"{self.user_command} {' '.join(self.user_args)}" + + # Construct the full command with supergateway + full_command = f'{self.supergateway_cmd_parts} \ + --base-url {self.base_url} \ + --port {self.port} \ + --cors \ + --stdio "{command}"' + + logger.debug(f"Full command: {full_command}") + + # Execute the command in the sandbox + self.process = await self._sandbox.execute_safe( + full_command, + envs=self.user_env, + timeout=1000 * 60 * 10, # 10 minutes timeout + background=True, + on_stdout=self._handle_stdout, + on_stderr=self._handle_stderr, + ) + + # Wait for the server to be ready + await self.wait_for_server_response(self.base_url, timeout=30) + logger.debug("Initializing connection manager...") + + # Create the SSE connection URL + sse_url = f"{self.base_url}/sse" + + # Create and set up the connection manager + self._connection_manager = SseConnectionManager(sse_url, self.headers, self.timeout, self.sse_read_timeout) + + async def _after_connect(self) -> None: + """Create ClientSession and log success.""" + await super()._after_connect() + logger.debug(f"Successfully connected to MCP implementation via HTTP/SSE in sandbox: {self.base_url}") + + async def _before_disconnect(self) -> None: + """Clean up sandbox-specific resources before disconnection.""" + logger.debug("Cleaning up sandbox resources") + + # Stop the sandbox (which will clean up processes) + if self._sandbox and self._sandbox.is_active: + try: + logger.debug("Stopping sandbox instance") + await self._sandbox.stop() + logger.debug("Sandbox instance stopped successfully") + except Exception as e: + logger.warning(f"Error stopping sandbox: {e}") + + self.process = None + + # Call the parent method to clean up MCP resources + await super()._before_disconnect() + + # Clear any collected output + self.stdout_lines = [] + self.stderr_lines = [] + self.base_url = None + + async def _cleanup_on_connect_failure(self) -> None: + """Clean up sandbox resources on connection failure.""" + # Stop the sandbox if it was started + if self._sandbox and self._sandbox.is_active: + try: + await self._sandbox.stop() + except Exception as e: + logger.warning(f"Error stopping sandbox during cleanup: {e}") + + self.process = None + self.stdout_lines = [] + self.stderr_lines = [] + self.base_url = None + + # Call parent cleanup + await super()._cleanup_on_connect_failure() + + @property + def sandbox(self) -> BaseSandbox: + """Get the underlying sandbox instance.""" + return self._sandbox + + @property + def public_identifier(self) -> str: + """Get the identifier for the connector.""" + return {"type": "sandbox", "command": self.user_command, "args": self.user_args} diff --git a/openspace/grounding/backends/mcp/transport/connectors/stdio.py b/openspace/grounding/backends/mcp/transport/connectors/stdio.py new file mode 100644 index 0000000000000000000000000000000000000000..cf5a52ec1027408a6ae6cb59f8e2c910176bfcbd --- /dev/null +++ b/openspace/grounding/backends/mcp/transport/connectors/stdio.py @@ -0,0 +1,76 @@ +""" +StdIO connector for MCP implementations. + +This module provides a connector for communicating with MCP implementations +through the standard input/output streams. +""" + +import sys + +from mcp import ClientSession, StdioServerParameters + +from openspace.utils.logging import Logger +from ..task_managers import StdioConnectionManager +from .base import MCPBaseConnector + +logger = Logger.get_logger(__name__) + + +class StdioConnector(MCPBaseConnector): + """Connector for MCP implementations using stdio transport. + + This connector uses the stdio transport to communicate with MCP implementations + that are executed as child processes. It uses a connection manager to handle + the proper lifecycle management of the stdio client. + """ + + def __init__( + self, + command: str = "npx", + args: list[str] | None = None, + env: dict[str, str] | None = None, + errlog=None, + ): + """Initialize a new stdio connector. + + Args: + command: The command to execute. + args: Optional command line arguments. + env: Optional environment variables. + errlog: Stream to write error output to (defaults to filtered stderr). + StdioConnectionManager will wrap this to filter harmless errors. + """ + self.command = command + self.args = args or [] # Ensure args is never None + + # Ensure env is not None and add settings to suppress non-JSON output from servers + self.env = env or {} + # Add environment variables to encourage MCP servers to suppress non-JSON output + # Many Node.js-based servers respect NODE_ENV=production + if "NODE_ENV" not in self.env: + self.env["NODE_ENV"] = "production" + # Add flag to suppress informational messages (some servers respect this) + if "MCP_SILENT" not in self.env: + self.env["MCP_SILENT"] = "true" + + self.errlog = errlog + + # Create server parameters and connection manager + # StdioConnectionManager will wrap errlog in FilteredStderrWrapper + server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env) + connection_manager = StdioConnectionManager(server_params, self.errlog) + super().__init__(connection_manager) + + async def _before_connect(self) -> None: + """Log connection attempt.""" + logger.debug(f"Connecting to MCP implementation: {self.command}") + + async def _after_connect(self) -> None: + """Create ClientSession and log success.""" + # Call parent's _after_connect to create the ClientSession + await super()._after_connect() + logger.debug(f"Successfully connected to MCP implementation: {self.command}") + + @property + def public_identifier(self) -> dict[str, str]: + return {"type": "stdio", "command&args": f"{self.command} {' '.join(self.args)}"} \ No newline at end of file diff --git a/openspace/grounding/backends/mcp/transport/connectors/utils.py b/openspace/grounding/backends/mcp/transport/connectors/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2e9ef9253e9dd939ea9e57e767b52c35285cdbd8 --- /dev/null +++ b/openspace/grounding/backends/mcp/transport/connectors/utils.py @@ -0,0 +1,13 @@ +from typing import Any + + +def is_stdio_server(server_config: dict[str, Any]) -> bool: + """Check if the server configuration is for a stdio server. + + Args: + server_config: The server configuration section + + Returns: + True if the server is a stdio server, False otherwise + """ + return "command" in server_config and "args" in server_config \ No newline at end of file diff --git a/openspace/grounding/backends/mcp/transport/connectors/websocket.py b/openspace/grounding/backends/mcp/transport/connectors/websocket.py new file mode 100644 index 0000000000000000000000000000000000000000..5408935cc1453bbbeda90e2e856859f1cb592bd1 --- /dev/null +++ b/openspace/grounding/backends/mcp/transport/connectors/websocket.py @@ -0,0 +1,245 @@ +""" +WebSocket connector for MCP implementations. + +This module provides a connector for communicating with MCP implementations +through WebSocket connections. +""" + +import asyncio +import json +import uuid +from typing import Any + +from mcp.types import Tool +from websockets import ClientConnection + +from openspace.utils.logging import Logger +from openspace.grounding.core.transport.task_managers.base import BaseConnectionManager +from ..task_managers import WebSocketConnectionManager +from .base import MCPBaseConnector + +logger = Logger.get_logger(__name__) + + +class WebSocketConnector(MCPBaseConnector): + """Connector for MCP implementations using WebSocket transport. + + This connector uses WebSockets to communicate with remote MCP implementations, + using a connection manager to handle the proper lifecycle management. + """ + + def __init__( + self, + url: str, + auth_token: str | None = None, + headers: dict[str, str] | None = None, + ): + """Initialize a new WebSocket connector. + + Args: + url: The WebSocket URL to connect to. + auth_token: Optional authentication token. + headers: Optional additional headers. + """ + self.url = url + self.auth_token = auth_token + self.headers = headers or {} + if auth_token: + self.headers["Authorization"] = f"Bearer {auth_token}" + + self.ws: ClientConnection | None = None + self._receiver_task: asyncio.Task | None = None + self.pending_requests: dict[str, asyncio.Future] = {} + self._tools: list[Tool] | None = None + + # Create connection manager with actual parameters + connection_manager = WebSocketConnectionManager(self.url, self.headers) + super().__init__(connection_manager) + self._connected = False + + async def _get_streams_from_connection(self): + """WebSocket doesn't use streams, return None to skip ClientSession creation.""" + return None + + async def _after_connect(self) -> None: + """Set up WebSocket-specific resources after connection. + + WebSocket doesn't use ClientSession, so we skip the parent's implementation + and set up WebSocket-specific resources instead. + """ + # Store the WebSocket connection + self.ws = self._connection + + # Start the message receiver task + self._receiver_task = asyncio.create_task(self._receive_messages(), name="websocket_receiver_task") + + logger.debug(f"Successfully connected to MCP implementation via WebSocket: {self.url}") + + async def _receive_messages(self) -> None: + """Continuously receive and process messages from the WebSocket.""" + if not self.ws: + raise RuntimeError("WebSocket is not connected") + + try: + async for message in self.ws: + # Parse the message + data = json.loads(message) + + # Check if this is a response to a pending request + request_id = data.get("id") + if request_id and request_id in self.pending_requests: + future = self.pending_requests.pop(request_id) + if "result" in data: + future.set_result(data["result"]) + elif "error" in data: + future.set_exception(Exception(data["error"])) + + logger.debug(f"Received response for request {request_id}") + else: + logger.debug(f"Received message: {data}") + except Exception as e: + logger.error(f"Error in WebSocket message receiver: {e}") + # If the websocket connection was closed or errored, + # reject all pending requests + for future in self.pending_requests.values(): + if not future.done(): + future.set_exception(e) + + async def _before_disconnect(self) -> None: + """Clean up WebSocket-specific resources before disconnection.""" + errors = [] + + # First cancel the receiver task + if self._receiver_task and not self._receiver_task.done(): + try: + logger.debug("Cancelling WebSocket receiver task") + self._receiver_task.cancel() + try: + await self._receiver_task + except asyncio.CancelledError: + logger.debug("WebSocket receiver task cancelled successfully") + except Exception as e: + logger.warning(f"Error during WebSocket receiver task cancellation: {e}") + except Exception as e: + error_msg = f"Error cancelling WebSocket receiver task: {e}" + logger.warning(error_msg) + errors.append(error_msg) + finally: + self._receiver_task = None + + # Reject any pending requests + if self.pending_requests: + logger.debug(f"Rejecting {len(self.pending_requests)} pending requests") + for future in self.pending_requests.values(): + if not future.done(): + future.set_exception(ConnectionError("WebSocket disconnected")) + self.pending_requests.clear() + + # Reset WebSocket and tools + self.ws = None + self._tools = None + + if errors: + logger.warning(f"Encountered {len(errors)} errors during WebSocket resource cleanup") + + async def _cleanup_on_connect_failure(self) -> None: + """Clean up WebSocket resources on connection failure.""" + # Cancel receiver task if it was started + if self._receiver_task and not self._receiver_task.done(): + try: + self._receiver_task.cancel() + await self._receiver_task + except asyncio.CancelledError: + pass + except Exception: + pass + finally: + self._receiver_task = None + + # Reject pending requests + for future in self.pending_requests.values(): + if not future.done(): + future.set_exception(ConnectionError("Connection failed")) + self.pending_requests.clear() + + # Call parent cleanup + await super()._cleanup_on_connect_failure() + self.ws = None + + async def _send_request(self, method: str, params: dict[str, Any] | None = None) -> Any: + """Send a request and wait for a response.""" + if not self.ws: + raise RuntimeError("WebSocket is not connected") + + # Create a request ID + request_id = str(uuid.uuid4()) + + # Create a future to receive the response + future = asyncio.Future() + self.pending_requests[request_id] = future + + # Send the request + await self.ws.send(json.dumps({"id": request_id, "method": method, "params": params or {}})) + + logger.debug(f"Sent request {request_id} method: {method}") + + # Wait for the response + try: + return await future + except Exception as e: + # Remove the request from pending requests + self.pending_requests.pop(request_id, None) + logger.error(f"Error waiting for response to request {request_id}: {e}") + raise + + async def initialize(self) -> dict[str, Any]: + """Initialize the MCP session and return session information.""" + logger.debug("Initializing MCP session") + result = await self._send_request("initialize") + + # Get available tools + tools_result = await self.list_tools() + self._tools = [Tool(**tool) for tool in tools_result] + + logger.debug(f"MCP session initialized with {len(self._tools)} tools") + return result + + async def list_tools(self) -> list[dict[str, Any]]: + """List all available tools from the MCP implementation.""" + logger.debug("Listing tools") + result = await self._send_request("tools/list") + return result.get("tools", []) + + @property + def tools(self) -> list[Tool]: + """Get the list of available tools.""" + if not self._tools: + raise RuntimeError("MCP client is not initialized") + return self._tools + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + """Call an MCP tool with the given arguments.""" + logger.debug(f"Calling tool '{name}' with arguments: {arguments}") + return await self._send_request("tools/call", {"name": name, "arguments": arguments}) + + async def list_resources(self) -> list[dict[str, Any]]: + """List all available resources from the MCP implementation.""" + logger.debug("Listing resources") + result = await self._send_request("resources/list") + return result + + async def read_resource(self, uri: str) -> tuple[bytes, str]: + """Read a resource by URI.""" + logger.debug(f"Reading resource: {uri}") + result = await self._send_request("resources/read", {"uri": uri}) + return result.get("content", b""), result.get("mimeType", "") + + async def request(self, method: str, params: dict[str, Any] | None = None) -> Any: + """Send a raw request to the MCP implementation.""" + logger.debug(f"Sending request: {method} with params: {params}") + return await self._send_request(method, params) + + @property + def public_identifier(self) -> str: + """Get the identifier for the connector.""" + return {"type": "websocket", "url": self.url} diff --git a/openspace/grounding/backends/mcp/transport/task_managers/__init__.py b/openspace/grounding/backends/mcp/transport/task_managers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9def05b27802fee698ce912a1aad99b5711e700f --- /dev/null +++ b/openspace/grounding/backends/mcp/transport/task_managers/__init__.py @@ -0,0 +1,18 @@ +""" +Connectors for various MCP transports. + +This module provides interfaces for connecting to MCP implementations +through different transport mechanisms. +""" + +from .sse import SseConnectionManager +from .stdio import StdioConnectionManager +from .streamable_http import StreamableHttpConnectionManager +from .websocket import WebSocketConnectionManager + +__all__ = [ + "StdioConnectionManager", + "WebSocketConnectionManager", + "SseConnectionManager", + "StreamableHttpConnectionManager", +] \ No newline at end of file diff --git a/openspace/grounding/backends/mcp/transport/task_managers/sse.py b/openspace/grounding/backends/mcp/transport/task_managers/sse.py new file mode 100644 index 0000000000000000000000000000000000000000..afe37e1d7ac00e6a989625ec3702b85e162bd7cd --- /dev/null +++ b/openspace/grounding/backends/mcp/transport/task_managers/sse.py @@ -0,0 +1,50 @@ +""" +SSE connection management for MCP implementations. + +This module provides a connection manager for SSE-based MCP connections +that ensures proper task isolation and resource cleanup. +""" + +from typing import Any, Tuple +from mcp.client.sse import sse_client +from openspace.utils.logging import Logger +from openspace.grounding.core.transport.task_managers import ( + AsyncContextConnectionManager, +) + +logger = Logger.get_logger(__name__) + + +class SseConnectionManager(AsyncContextConnectionManager[Tuple[Any, Any], ...]): + """Connection manager for SSE-based MCP connections. + + This class handles the proper task isolation for sse_client context managers + to prevent the "cancel scope in different task" error. It runs the sse_client + in a dedicated task and manages its lifecycle. + """ + + def __init__( + self, + url: str, + headers: dict[str, str] | None = None, + timeout: float = 5, + sse_read_timeout: float = 60 * 5, + ): + """Initialize a new SSE connection manager. + + Args: + url: The SSE endpoint URL + headers: Optional HTTP headers + timeout: Timeout for HTTP operations in seconds + sse_read_timeout: Timeout for SSE read operations in seconds + """ + super().__init__( + sse_client, + url=url, + headers=headers or {}, + timeout=timeout, + sse_read_timeout=sse_read_timeout, + ) + self.url = url + self.headers = headers or {} + logger.debug("SseConnectionManager init url=%s", url) diff --git a/openspace/grounding/backends/mcp/transport/task_managers/stdio.py b/openspace/grounding/backends/mcp/transport/task_managers/stdio.py new file mode 100644 index 0000000000000000000000000000000000000000..533bd1f7c5174beef312411ef64b1980ca489553 --- /dev/null +++ b/openspace/grounding/backends/mcp/transport/task_managers/stdio.py @@ -0,0 +1,351 @@ +""" +StdIO connection management for MCP implementations. + +This module provides a connection manager for stdio-based MCP connections +that ensures proper task isolation and resource cleanup. +""" + +import asyncio +import io +import logging +import sys +from typing import Any, TextIO, Tuple + +from mcp import StdioServerParameters +from mcp.client.stdio import stdio_client + +from openspace.utils.logging import Logger +from openspace.grounding.core.transport.task_managers import ( + AsyncContextConnectionManager, +) + +logger = Logger.get_logger(__name__) + + +class FilteredStderrWrapper(io.TextIOBase): + """Wrapper for stderr that filters out harmless MCP server shutdown messages. + + This wrapper suppresses error messages from MCP servers during shutdown + that are harmless but create noise in the logs. + """ + + def __init__(self, wrapped_stream: TextIO): + """Initialize the wrapper. + + Args: + wrapped_stream: The underlying stderr stream + """ + self._stream = wrapped_stream + self._buffer = "" + self._in_traceback = False + self._traceback_lines = [] + self._in_rich_traceback = False # Track rich-formatted tracebacks + self._rich_traceback_needs_error_line = False # After ╰, need one more line + + def write(self, s: str) -> int: + """Write to stderr, filtering out harmless error messages. + + Args: + s: The string to write + + Returns: + Number of characters written + """ + # Buffer the input for line-by-line processing + self._buffer += s + + # Process complete lines + while '\n' in self._buffer: + line, self._buffer = self._buffer.split('\n', 1) + self._process_line(line + '\n') + + return len(s) + + def _process_line(self, line: str): + """Process a single line and decide whether to output it.""" + # Detect start of traceback or exception group + if line.lstrip().startswith(("╭", "┏")): + self._in_traceback = True + self._in_rich_traceback = True + self._rich_traceback_needs_error_line = False + self._traceback_lines = [line] + return + + if (line.strip().startswith('Traceback (most recent call last)') or + line.strip().startswith('Exception Group Traceback (most recent call last)') or + line.strip().startswith('BaseExceptionGroup:') or + line.strip().startswith('ExceptionGroup:')): + self._in_traceback = True + self._traceback_lines = [line] + self._in_rich_traceback = False + self._rich_traceback_needs_error_line = False + return + + # Collect traceback lines + if self._in_traceback: + self._traceback_lines.append(line) + + # If not in rich traceback mode, but current line contains rich border characters, switch to rich mode + if not self._in_rich_traceback and any(ch in line for ch in ("╭", "┏")): + self._in_rich_traceback = True + + # Check for end of rich-formatted traceback (line with ╰) + if self._in_rich_traceback and '╰' in line: + # Rich traceback box ended, but we need to collect the error line that follows + self._rich_traceback_needs_error_line = True + return + + # If we just ended a rich traceback, this should be the error line + if self._rich_traceback_needs_error_line: + # Now we have the complete rich traceback including the error line + if self._is_harmless_error(): + logger.debug(f"Suppressed harmless rich-formatted MCP server error") + else: + # Output the full traceback + for tb_line in self._traceback_lines: + self._stream.write(tb_line) + self._stream.flush() + + # Reset traceback collection + self._in_traceback = False + self._in_rich_traceback = False + self._rich_traceback_needs_error_line = False + self._traceback_lines = [] + return + + # For exception groups, we need to collect more lines + # Check if we've collected enough to determine if it's harmless + if len(self._traceback_lines) > 5 and not self._in_rich_traceback: + # Check periodically if this is a harmless error + if self._is_harmless_error(): + # Suppress this traceback + logger.debug(f"Suppressed harmless MCP server shutdown error") + self._in_traceback = False + self._in_rich_traceback = False + self._rich_traceback_needs_error_line = False + self._traceback_lines = [] + return + + # Check if this is the error line (last line of regular traceback) + # But not for rich tracebacks which use box characters + # A final traceback line is typically unindented and contains "ErrorType: message" + if not self._in_rich_traceback and line and not line[0].isspace() and ':' in line: + # Check if this is a harmless cleanup error + if self._is_harmless_error(): + # Suppress this traceback + logger.debug(f"Suppressed harmless MCP server shutdown error") + else: + # Output the full traceback + for tb_line in self._traceback_lines: + self._stream.write(tb_line) + self._stream.flush() + + # Reset traceback collection + self._in_traceback = False + self._in_rich_traceback = False + self._rich_traceback_needs_error_line = False + self._traceback_lines = [] + return + + # If we've collected too many lines without finding the end, output and reset + if len(self._traceback_lines) > 100: + # Output what we have + for tb_line in self._traceback_lines: + self._stream.write(tb_line) + self._stream.flush() + self._in_traceback = False + self._in_rich_traceback = False + self._rich_traceback_needs_error_line = False + self._traceback_lines = [] + return + else: + # Normal line - check if it's a harmless error log + line_lower = line.lower() + harmless_log_patterns = [ + 'an error occurred during closing of asynchronous generator', + 'asyncgen:', + 'service stopped.', + ] + + # Check if this is a harmless log line + is_harmless_log = any(pattern in line_lower for pattern in harmless_log_patterns) + + if not is_harmless_log: + # Output normal lines + self._stream.write(line) + self._stream.flush() + else: + # Suppress harmless log messages + logger.debug(f"Suppressed harmless log line: {line.strip()}") + + def _is_harmless_error(self) -> bool: + """Check if the collected traceback is a harmless error.""" + traceback_text = ''.join(self._traceback_lines).lower() + + # List of harmless error patterns (case-insensitive) + harmless_patterns = [ + 'valueerror: i/o operation on closed file', + 'oserror: [errno 9] bad file descriptor', + 'brokenpipeerror', + 'runtimeerror: attempted to exit cancel scope in a different task', + 'baseexceptiongroup: unhandled errors in a taskgroup', + 'generatorexit', + 'an error occurred during closing of asynchronous generator', + ] + + # Check if any pattern matches and it's related to shutdown + for pattern in harmless_patterns: + if pattern in traceback_text: + # Also check if it's related to shutdown/cleanup + shutdown_keywords = ['finally:', 'stopped', 'cleanup', '__exit__', '__aexit__', 'stdio_client', 'service stopped'] + if any(keyword in traceback_text for keyword in shutdown_keywords): + return True + + return False + + def flush(self): + """Flush any remaining buffered content and the underlying stream.""" + if self._buffer: + self._process_line(self._buffer) + self._buffer = "" + + if self._traceback_lines: + # Flush incomplete traceback + for line in self._traceback_lines: + self._stream.write(line) + self._traceback_lines = [] + + self._stream.flush() + + def fileno(self) -> int: + """Return the file descriptor of the underlying stream.""" + if hasattr(self._stream, 'fileno'): + return self._stream.fileno() + return -1 + + @property + def closed(self) -> bool: + """Check if the stream is closed.""" + return self._stream.closed + + +class StdioConnectionManager(AsyncContextConnectionManager[Tuple[Any, Any], ...]): + """Connection manager for stdio-based MCP connections. + + This class handles the proper task isolation for stdio_client context managers + to prevent the "cancel scope in different task" error. It runs the stdio_client + in a dedicated task and manages its lifecycle. + + Note: Error handling during cleanup (e.g., I/O operations on closed files) is + handled by the parent AsyncContextConnectionManager class in _close_connection(). + """ + + def __init__( + self, + server_params: StdioServerParameters, + errlog: TextIO | None = None, + ): + """Initialize a new stdio connection manager. + + Args: + server_params: The parameters for the stdio server + errlog: The error log stream (defaults to filtered sys.stderr) + """ + # Wrap stderr to filter out harmless shutdown errors + if errlog is None: + errlog = FilteredStderrWrapper(sys.stderr) + elif not isinstance(errlog, FilteredStderrWrapper): + errlog = FilteredStderrWrapper(errlog) + + super().__init__(stdio_client, server_params, errlog) + self.server_params = server_params + self.errlog = errlog + self._mcp_logger_filter = None + self._stop_event: asyncio.Event | None = None # Signal for background task + self._runner_task: asyncio.Task | None = None # Background runner task + self._conn_future: asyncio.Future | None = None # Future for the established connection + logger.debug("StdioConnectionManager init with params=%s", server_params) + + async def _establish_connection(self) -> Tuple[Any, Any]: + """Establish connection in a dedicated task to avoid cancel-scope issues.""" + # Suppress MCP SDK's noisy JSON parse errors **before** starting the runner + self._suppress_mcp_json_errors() + + # Lazily create primitives the first time we connect + if self._stop_event is None: + self._stop_event = asyncio.Event() + if self._conn_future is None or self._conn_future.done(): + self._conn_future = asyncio.get_event_loop().create_future() + + async def _runner(): # Runs in its *own* task (same task for enter/exit) + try: + async with stdio_client(self.server_params, self.errlog) as conn: + # Pass connection back to the caller + if not self._conn_future.done(): + self._conn_future.set_result(conn) + # Wait until close is requested + await self._stop_event.wait() + finally: + # Make sure the future is set even on error so awaiters don’t hang + if not self._conn_future.done(): + self._conn_future.set_exception(RuntimeError("Connection failed")) + + # Start background runner if not already active + if self._runner_task is None or self._runner_task.done(): + self._runner_task = asyncio.create_task(_runner(), name="stdio_client_runner") + + # Wait for the connection tuple from the future + conn: Tuple[Any, Any] = await self._conn_future # type: ignore + return conn + + async def _close_connection(self) -> None: + """Request the background task to exit its context and wait for it.""" + try: + # Restore original logging configuration *before* shutdown + self._restore_mcp_logging() + + # Signal the runner to exit its context manager + if self._stop_event and not self._stop_event.is_set(): + self._stop_event.set() + + # Await the runner task so that __aexit__ executes in *its* task + if self._runner_task: + try: + await asyncio.wait_for(self._runner_task, timeout=2.0) + except asyncio.TimeoutError: + logger.warning("Timeout while waiting for stdio_client to shut down") + finally: + # Clean up helpers so next connect() creates new ones + self._runner_task = None + self._stop_event = None + self._conn_future = None + + def _suppress_mcp_json_errors(self): + """Suppress MCP SDK's JSON parsing error logs. + + The MCP SDK logs errors when it receives non-JSON messages from servers. + These are harmless (the SDK continues working), so we filter them out. + """ + mcp_logger = logging.getLogger("mcp.client.stdio") + + class JSONErrorFilter(logging.Filter): + """Filter out JSON parsing errors from MCP SDK.""" + def filter(self, record): + # Suppress "Failed to parse JSONRPC message" errors + if "Failed to parse JSONRPC message" in str(record.msg): + return False + return True + + self._mcp_logger_filter = JSONErrorFilter() + mcp_logger.addFilter(self._mcp_logger_filter) + + def _restore_mcp_logging(self): + """Restore MCP SDK logging to normal.""" + if self._mcp_logger_filter: + mcp_logger = logging.getLogger("mcp.client.stdio") + mcp_logger.removeFilter(self._mcp_logger_filter) + self._mcp_logger_filter = None + +if not isinstance(sys.stderr, FilteredStderrWrapper): + sys.stderr = FilteredStderrWrapper(sys.stderr) + logger.debug("Applied global FilteredStderrWrapper to sys.stderr") \ No newline at end of file diff --git a/openspace/grounding/backends/mcp/transport/task_managers/streamable_http.py b/openspace/grounding/backends/mcp/transport/task_managers/streamable_http.py new file mode 100644 index 0000000000000000000000000000000000000000..d42535ebd756f7125743b4951d83a5d7039297d6 --- /dev/null +++ b/openspace/grounding/backends/mcp/transport/task_managers/streamable_http.py @@ -0,0 +1,103 @@ +""" +Streamable HTTP connection management for MCP implementations. + +This module provides a connection manager for streamable HTTP-based MCP connections +that ensures proper task isolation and resource cleanup. +""" + +from datetime import timedelta +from typing import Any, Tuple +from contextlib import asynccontextmanager + +from mcp.client.streamable_http import streamablehttp_client +from openspace.utils.logging import Logger +from openspace.grounding.core.transport.task_managers import ( + AsyncContextConnectionManager, +) + +logger = Logger.get_logger(__name__) + + +def _make_shim(): + """ + Create a shim that wraps streamablehttp_client with improved error handling. + """ + @asynccontextmanager + async def _shim(**kw): + client_streams = None + ctx_manager = None + + try: + # Enter the context - this may raise ExceptionGroup during concurrent init + ctx_manager = streamablehttp_client(**kw) + try: + r, w, _sid_cb = await ctx_manager.__aenter__() + client_streams = (r, w) + except Exception as conn_error: + # Handle connection errors during __aenter__ + error_msg = str(conn_error).lower() + if "unhandled errors in a taskgroup" in error_msg: + logger.debug(f"TaskGroup race condition during connection: {type(conn_error).__name__}") + # Clean up and re-raise to trigger retry + if ctx_manager: + try: + await ctx_manager.__aexit__(None, None, None) + except Exception: + pass # Ignore cleanup errors + raise + else: + # Other connection errors - log and re-raise + logger.warning(f"Connection error: {conn_error}") + raise + + # Yield to caller + yield client_streams + + except GeneratorExit: + # Normal generator exit - this happens during cleanup + logger.debug("StreamableHTTP generator exit (normal cleanup)") + + finally: + # Always try to exit the context manager + if ctx_manager is not None: + try: + await ctx_manager.__aexit__(None, None, None) + except (GeneratorExit, RuntimeError, OSError, Exception) as e: + # Cleanup errors are expected during concurrent shutdown + # Log at debug level and suppress + error_type = type(e).__name__ + if "ExceptionGroup" in error_type or "TaskGroup" in str(e): + logger.debug(f"Benign TaskGroup cleanup error: {error_type}") + else: + logger.debug(f"Benign cleanup error: {error_type}") + + return _shim + + +class StreamableHttpConnectionManager( + AsyncContextConnectionManager[Tuple[Any, Any], ...] +): + """ + MCP Streamable-HTTP connection manager based on the generic + AsyncContextConnectionManager. Extra session-id callback returned by the + SDK is discarded by the shim above. + """ + + def __init__( + self, + url: str, + headers: dict[str, str] | None = None, + timeout: float = 5, + read_timeout: float = 60 * 5, + ): + shim = _make_shim() + super().__init__( + shim, + url=url, + headers=headers or {}, + timeout=timedelta(seconds=timeout), + sse_read_timeout=timedelta(seconds=read_timeout), + ) + self.url = url + self.headers = headers or {} + logger.debug("StreamableHttpConnectionManager init url=%s", url) \ No newline at end of file diff --git a/openspace/grounding/backends/mcp/transport/task_managers/websocket.py b/openspace/grounding/backends/mcp/transport/task_managers/websocket.py new file mode 100644 index 0000000000000000000000000000000000000000..db873e07b0a4866555830444b72125b88b2d86dc --- /dev/null +++ b/openspace/grounding/backends/mcp/transport/task_managers/websocket.py @@ -0,0 +1,26 @@ +""" +WebSocket connection management for MCP implementations. + +This module provides a connection manager for WebSocket-based MCP connections. +""" + +from typing import Any, Tuple +from mcp.client.websocket import websocket_client +from openspace.utils.logging import Logger +from openspace.grounding.core.transport.task_managers import ( + AsyncContextConnectionManager, +) + +logger = Logger.get_logger(__name__) + +class WebSocketConnectionManager( + AsyncContextConnectionManager[Tuple[Any, Any], ...] +): + + def __init__(self, url: str, headers: dict[str, str] | None = None): + # Note: The current MCP websocket_client implementation doesn't support headers + # If headers need to be passed, this would need to be updated when MCP supports it + super().__init__(websocket_client, url) + self.url = url + self.headers = headers or {} + logger.debug("WebSocketConnectionManager init url=%s", url) \ No newline at end of file diff --git a/openspace/grounding/backends/shell/__init__.py b/openspace/grounding/backends/shell/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9a1cb1197d82deecaed1602567588870994ebc6 --- /dev/null +++ b/openspace/grounding/backends/shell/__init__.py @@ -0,0 +1,11 @@ +from .provider import ShellProvider +from .session import ShellSession +from .transport.connector import ShellConnector +from .transport.local_connector import LocalShellConnector + +__all__ = [ + "ShellProvider", + "ShellSession", + "ShellConnector", + "LocalShellConnector", +] \ No newline at end of file diff --git a/openspace/grounding/backends/shell/productivity_tools.py b/openspace/grounding/backends/shell/productivity_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..af327fdc120eb03f9735ab68049f06036df02b81 --- /dev/null +++ b/openspace/grounding/backends/shell/productivity_tools.py @@ -0,0 +1,426 @@ +""" +ClawWork-compatible productivity tools for fair benchmark comparison. + +When use_clawwork_productivity is enabled and the livebench package is installed, +these tools are added to the Shell backend so OpenSpace agents have the same +capabilities as ClawWork: search_web, read_webpage, create_file, read_file, +execute_code_sandbox, create_video. +""" + +import asyncio +import json +from pathlib import Path +from typing import Any, Dict, List, Optional + +from openspace.grounding.core.types import BackendType, ToolResult, ToolStatus +from openspace.grounding.core.tool import BaseTool +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +# Lazy import to avoid hard dependency on ClawWork +_LIVEBENCH_AVAILABLE = False +_direct_tools = None +_productivity = None + + +def _ensure_livebench(): + """Import livebench modules if available.""" + global _LIVEBENCH_AVAILABLE, _direct_tools, _productivity + if _direct_tools is not None: + return _LIVEBENCH_AVAILABLE + try: + import livebench.tools.direct_tools as dt + import livebench.tools.productivity as prod + _direct_tools = dt + _productivity = prod + _LIVEBENCH_AVAILABLE = True + except ImportError as e: + logger.debug("ClawWork productivity tools not available: %s", e) + _LIVEBENCH_AVAILABLE = False + return _LIVEBENCH_AVAILABLE + + +def _set_global_state_for_productivity(data_path: str, current_date: str) -> None: + """Set ClawWork global state so productivity tools have data_path/date.""" + if not _direct_tools: + return + _direct_tools.set_global_state( + signature="openspace", + economic_tracker=None, + task_manager=None, + evaluator=None, + current_date=current_date, + current_task=None, + data_path=data_path, + supports_multimodal=True, + ) + + +def _dict_to_tool_result(out: Dict[str, Any]) -> ToolResult: + """Convert ClawWork tool dict to OpenSpace ToolResult.""" + if not isinstance(out, dict): + return ToolResult( + status=ToolStatus.ERROR, + content=str(out), + ) + err = out.get("error") + if err: + return ToolResult( + status=ToolStatus.ERROR, + content=err if isinstance(err, str) else json.dumps(err, ensure_ascii=False), + ) + return ToolResult( + status=ToolStatus.SUCCESS, + content=json.dumps(out, ensure_ascii=False, default=str), + ) + + +def _sync_invoke(tool_any: Any, args: Dict[str, Any]) -> Dict[str, Any]: + """Invoke a LangChain-style tool (sync) from async context.""" + if hasattr(tool_any, "invoke"): + return tool_any.invoke(args) + return tool_any(**args) + + +class _ProductivityToolBase(BaseTool): + """Base for productivity tools that delegate to ClawWork.""" + + backend_type = BackendType.SHELL + + def __init__(self, session: Any, data_path: str, current_date: str): + self._session = session + self._data_path = data_path or "." + self._current_date = current_date or "default" + super().__init__() + + async def _arun(self, **kwargs) -> ToolResult: + raise NotImplementedError("Subclasses must override _arun") + + async def _run_sync_tool(self, tool_obj: Any, args: Dict[str, Any]) -> ToolResult: + data_path = getattr(self._session, "default_working_dir", None) or self._data_path + _set_global_state_for_productivity(data_path, self._current_date) + try: + result = await asyncio.to_thread(_sync_invoke, tool_obj, args) + return _dict_to_tool_result(result) + except Exception as e: + logger.exception("Productivity tool %s failed", self.name) + return ToolResult(status=ToolStatus.ERROR, content=str(e)) + + +class SearchWebTool(_ProductivityToolBase): + _name = "search_web" + _description = ( + "Search the internet using Tavily or Jina. Returns structured results with " + "AI-generated answers. Use for up-to-date information." + ) + + async def _arun(self, query: str, max_results: int = 5) -> ToolResult: + return await self._run_sync_tool( + _productivity.search_web, + {"query": query, "max_results": max_results}, + ) + + +class ReadWebpageTool(_ProductivityToolBase): + _name = "read_webpage" + _description = ( + "Extract and read web page content from URLs using Tavily Extract. " + "Returns cleaned text in markdown format." + ) + + async def _arun(self, urls: str, query: Optional[str] = None) -> ToolResult: + return await self._run_sync_tool( + _productivity.read_webpage, + {"urls": urls, "query": query}, + ) + + +class CreateFileProductivityTool(_ProductivityToolBase): + _name = "create_file" + _description = ( + "Create a file in the current working directory. " + "Supports: txt, md, csv, json, xlsx, docx, pdf. " + "The file is created directly in your workspace." + ) + + async def _arun( + self, + filename: str, + content: str, + file_type: str = "txt", + ) -> ToolResult: + """Create a file via Shell connector so it lands in the task workspace.""" + file_type = file_type.lower().strip() + valid_types = ["txt", "md", "csv", "json", "xlsx", "docx", "pdf"] + if file_type not in valid_types: + return ToolResult( + status=ToolStatus.ERROR, + content=f"Invalid file type: {file_type}. Valid: {valid_types}", + ) + if not filename or not content: + return ToolResult(status=ToolStatus.ERROR, content="filename and content are required") + + import os + safe_name = os.path.basename(filename).replace("/", "_").replace("\\", "_") + # Strip extension from filename if it matches file_type to avoid .docx.docx + name_root, name_ext = os.path.splitext(safe_name) + if name_ext.lstrip(".").lower() == file_type: + safe_name = name_root + final_name = f"{safe_name}.{file_type}" + + escaped_content = json.dumps(content) + escaped_name = json.dumps(final_name) + + if file_type in ("txt", "md", "csv"): + code = ( + "import os\n" + f"name = {escaped_name}\n" + f"content = {escaped_content}\n" + "with open(name, 'w', encoding='utf-8') as f:\n" + " f.write(content)\n" + "sz = os.path.getsize(name)\n" + "print(f'Created {name} ({sz} bytes)')\n" + ) + elif file_type == "json": + code = ( + "import os, json\n" + f"name = {escaped_name}\n" + f"content = {escaped_content}\n" + "data = json.loads(content)\n" + "with open(name, 'w', encoding='utf-8') as f:\n" + " json.dump(data, f, indent=2, ensure_ascii=False)\n" + "sz = os.path.getsize(name)\n" + "print(f'Created {name} ({sz} bytes)')\n" + ) + elif file_type == "xlsx": + code = ( + "import os, json, io\n" + "import pandas as pd\n" + f"name = {escaped_name}\n" + f"content = {escaped_content}\n" + "try:\n" + " data = json.loads(content)\n" + " df = pd.DataFrame(data)\n" + "except:\n" + " df = pd.read_csv(io.StringIO(content))\n" + "df.to_excel(name, index=False, engine='openpyxl')\n" + "sz = os.path.getsize(name)\n" + "print(f'Created {name} ({sz} bytes)')\n" + ) + elif file_type == "docx": + code = ( + "import os\n" + "from docx import Document\n" + f"name = {escaped_name}\n" + f"content = {escaped_content}\n" + "doc = Document()\n" + "for para in content.split('\\n\\n'):\n" + " if para.strip():\n" + " doc.add_paragraph(para.strip())\n" + "doc.save(name)\n" + "sz = os.path.getsize(name)\n" + "print(f'Created {name} ({sz} bytes)')\n" + ) + elif file_type == "pdf": + code = ( + "import os\n" + "from reportlab.lib.pagesizes import letter\n" + "from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer\n" + "from reportlab.lib.styles import getSampleStyleSheet\n" + f"name = {escaped_name}\n" + f"content = {escaped_content}\n" + "doc = SimpleDocTemplate(name, pagesize=letter)\n" + "styles = getSampleStyleSheet()\n" + "story = []\n" + "for para in content.split('\\n\\n'):\n" + " if para.strip():\n" + " story.append(Paragraph(para.strip(), styles['Normal']))\n" + " story.append(Spacer(1, 12))\n" + "doc.build(story)\n" + "sz = os.path.getsize(name)\n" + "print(f'Created {name} ({sz} bytes)')\n" + ) + else: + return ToolResult(status=ToolStatus.ERROR, content=f"Unsupported: {file_type}") + + try: + from openspace.grounding.backends.shell.session import _parse_shell_result + working_dir = getattr(self._session, "default_working_dir", None) + result = await self._session.connector.run_python_script( + code, timeout=30, working_dir=working_dir, + ) + stdout, stderr, rc = _parse_shell_result(result) + if rc != 0: + return ToolResult(status=ToolStatus.ERROR, content=stderr or f"Failed to create {final_name}") + return ToolResult( + status=ToolStatus.SUCCESS, + content=f"Created {final_name} in workspace. {stdout.strip()}", + ) + except Exception as e: + return ToolResult(status=ToolStatus.ERROR, content=f"create_file failed: {e}") + + +class ReadFileProductivityTool(_ProductivityToolBase): + _name = "read_file" + _description = ( + "Read a file in various formats: pdf, docx, xlsx, pptx, png, jpg, jpeg, txt, json, md, csv, html, xml, yaml. " + "Returns content suitable for LLM consumption (text or images). " + "Relative paths are resolved against the task workspace directory." + ) + + def _resolve_path(self, file_path: str) -> Path: + """Resolve relative paths against the task workspace (data_path).""" + data_path = getattr(self._session, "default_working_dir", None) or self._data_path + p = Path(file_path) + if not p.is_absolute(): + resolved = Path(data_path) / p + if resolved.exists(): + return resolved + workspace = Path(data_path) + if workspace.is_dir(): + name = p.name + for candidate in workspace.rglob(name): + return candidate + return p + + async def _arun(self, filetype: str, file_path: str) -> ToolResult: + resolved = self._resolve_path(file_path) + return await self._run_sync_tool( + _productivity.read_file, + {"filetype": filetype, "file_path": resolved}, + ) + + +class ExecuteCodeSandboxTool(_ProductivityToolBase): + _name = "execute_code_sandbox" + _description = ( + "Execute Python code in a persistent sandbox. Supports artifact download via ARTIFACT_PATH:/path/to/file in output." + ) + + async def _arun(self, code: str, language: str = "python") -> ToolResult: + return await self._run_sync_tool( + _productivity.execute_code_sandbox, + {"code": code, "language": language}, + ) + + +class CreateVideoTool(_ProductivityToolBase): + _name = "create_video" + _description = ( + "Create a video from text slides and/or images. Input is a JSON string describing slides; output is MP4. " + "The video is created in the current working directory." + ) + + async def _arun( + self, + slides_json: str, + output_filename: str, + width: int = 1280, + height: int = 720, + fps: int = 24, + ) -> ToolResult: + """Create video via Shell connector so it lands in the task workspace.""" + import os + safe_name = os.path.basename(output_filename).replace("/", "_").replace("\\", "_") + if not safe_name.endswith(".mp4"): + safe_name = safe_name.rsplit(".", 1)[0] if "." in safe_name else safe_name + safe_name += ".mp4" + + escaped_slides = json.dumps(slides_json) + escaped_name = json.dumps(safe_name) + + code = ( + "import json, os\n" + f"slides_json = {escaped_slides}\n" + f"output_name = {escaped_name}\n" + f"width, height, fps = {width}, {height}, {fps}\n" + "slides = json.loads(slides_json)\n" + "try:\n" + " from PIL import Image, ImageDraw, ImageFont\n" + " import subprocess, tempfile, shutil\n" + " tmpdir = tempfile.mkdtemp()\n" + " frame_paths = []\n" + " for i, slide in enumerate(slides):\n" + " dur = slide.get('duration', 3.0)\n" + " n_frames = int(dur * fps)\n" + " if slide.get('type') == 'image' and slide.get('path'):\n" + " img = Image.open(slide['path']).resize((width, height))\n" + " else:\n" + " bg = slide.get('bg_color', '#000000')\n" + " tc = slide.get('text_color', '#FFFFFF')\n" + " img = Image.new('RGB', (width, height), bg)\n" + " draw = ImageDraw.Draw(img)\n" + " text = slide.get('content', '')\n" + " try:\n" + " font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', 36)\n" + " except:\n" + " font = ImageFont.load_default()\n" + " bbox = draw.textbbox((0, 0), text, font=font)\n" + " tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]\n" + " draw.text(((width - tw) / 2, (height - th) / 2), text, fill=tc, font=font)\n" + " for j in range(n_frames):\n" + " fp = os.path.join(tmpdir, f'frame_{len(frame_paths):06d}.png')\n" + " img.save(fp)\n" + " frame_paths.append(fp)\n" + " cmd = ['ffmpeg', '-y', '-framerate', str(fps), '-i', os.path.join(tmpdir, 'frame_%06d.png'),\n" + " '-c:v', 'libx264', '-pix_fmt', 'yuv420p', output_name]\n" + " subprocess.run(cmd, capture_output=True, check=True)\n" + " shutil.rmtree(tmpdir)\n" + " sz = os.path.getsize(output_name)\n" + " print(f'Created {output_name} ({sz} bytes, {len(frame_paths)} frames)')\n" + "except Exception as e:\n" + " print(f'ERROR: {e}')\n" + " raise\n" + ) + + try: + from openspace.grounding.backends.shell.session import _parse_shell_result + working_dir = getattr(self._session, "default_working_dir", None) + result = await self._session.connector.run_python_script( + code, timeout=120, working_dir=working_dir, + ) + stdout, stderr, rc = _parse_shell_result(result) + if rc != 0: + return ToolResult(status=ToolStatus.ERROR, content=stderr or f"Failed to create video {safe_name}") + return ToolResult( + status=ToolStatus.SUCCESS, + content=f"Created video {safe_name} in workspace. {stdout.strip()}", + ) + except Exception as e: + return ToolResult(status=ToolStatus.ERROR, content=f"create_video failed: {e}") + + +def get_productivity_tools( + session: Any, + data_path: Optional[str] = None, + current_date: Optional[str] = None, +) -> List[BaseTool]: + """ + Return ClawWork-compatible productivity tools if livebench is installed. + + Args: + session: ShellSession (for compatibility; not used beyond data_path/date). + data_path: Sandbox root (default: session.default_working_dir or "."). + current_date: Date segment for sandbox paths (default: "default"). + + Returns: + List of tools to add to the session, or empty list if livebench unavailable. + """ + if not _ensure_livebench(): + return [] + path = data_path if data_path is not None else getattr(session, "default_working_dir", None) or "." + date = current_date if current_date is not None else "default" + return [ + SearchWebTool(session, data_path=path, current_date=date), + ReadWebpageTool(session, data_path=path, current_date=date), + CreateFileProductivityTool(session, data_path=path, current_date=date), + ReadFileProductivityTool(session, data_path=path, current_date=date), + ExecuteCodeSandboxTool(session, data_path=path, current_date=date), + CreateVideoTool(session, data_path=path, current_date=date), + ] + + +def is_productivity_available() -> bool: + """Return True if ClawWork productivity tools can be loaded.""" + return _ensure_livebench() diff --git a/openspace/grounding/backends/shell/provider.py b/openspace/grounding/backends/shell/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..2884fbffb7c288ba40d47d3ddce51e41a3f92ce4 --- /dev/null +++ b/openspace/grounding/backends/shell/provider.py @@ -0,0 +1,105 @@ +from openspace.grounding.core.provider import Provider +from openspace.grounding.core.types import BackendType, SessionConfig +from .session import ShellSession +from .transport.connector import ShellConnector +from .transport.local_connector import LocalShellConnector +from openspace.config import get_config +from openspace.config.utils import get_config_value +from openspace.platforms.config import get_local_server_config +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +class ShellProvider(Provider[ShellSession]): + + DEFAULT_SID = BackendType.SHELL.value + + def __init__(self, config: dict | None = None): + super().__init__(BackendType.SHELL, config) + # Note: _setup_security_policy() is already called by parent class __init__ + + def _setup_security_policy(self, config: dict | None = None): + security_policy = get_config().get_security_policy(self.backend_type.value) + + if config: + security_config = get_config_value(config, "security", None) + if security_config: + for key, value in security_config.items(): + if hasattr(security_policy, key): + setattr(security_policy, key, value) + + sandbox_enabled = get_config_value(config, "sandbox_enabled", None) + if sandbox_enabled is not None: + security_policy.sandbox_enabled = sandbox_enabled + + logger.info(f"Shell security policy: allow_shell_commands={security_policy.allow_shell_commands}, " + f"blocked_commands={security_policy.blocked_commands}") + + self.security_manager.set_backend_policy(BackendType.SHELL, security_policy) + + async def initialize(self) -> None: + if not self.is_initialized: + await self.create_session(SessionConfig( + session_name=self.DEFAULT_SID, + backend_type=BackendType.SHELL, + connection_params={} + )) + self.is_initialized = True + + async def create_session(self, session_config: SessionConfig) -> ShellSession: + sid = self.DEFAULT_SID + if sid in self._sessions: + return self._sessions[sid] + + # Use the config passed to ShellProvider (from GroundingClient), + # falling back to global config only if not available. + shell_config = self.config if self.config else get_config().get_backend_config("shell") + + # Determine execution mode: "local" or "server" + mode = getattr(shell_config, "mode", "local") + + if mode == "local": + # ---------- LOCAL MODE ---------- + # Execute scripts directly via subprocess, no server required. + logger.info("Shell backend using LOCAL mode (no server required)") + connector = LocalShellConnector( + retry_times=shell_config.max_retries, + retry_interval=shell_config.retry_interval, + security_manager=self.security_manager, + ) + else: + # ---------- SERVER MODE ---------- + # Connect to a running local_server via HTTP. + logger.info("Shell backend using SERVER mode (connecting to local_server)") + local_server_config = get_local_server_config() + default_port = local_server_config.get('port', shell_config.default_port) + + connector = ShellConnector( + vm_ip=get_config_value(session_config.connection_params, "vm_ip", local_server_config['host']), + port=get_config_value(session_config.connection_params, "port", default_port), + retry_times=shell_config.max_retries, + retry_interval=shell_config.retry_interval, + security_manager=self.security_manager, + ) + + # Create session with config parameters + session = ShellSession( + connector=connector, + session_id=sid, + security_manager=self.security_manager, + default_working_dir=shell_config.working_dir, + default_env=shell_config.env, + default_conda_env=shell_config.conda_env, + use_clawwork_productivity=getattr(shell_config, "use_clawwork_productivity", False), + productivity_date=getattr(shell_config, "productivity_date", "default"), + ) + + await session.initialize() + self._sessions[sid] = session + return session + + async def close_session(self, session_id: str) -> None: + sess = self._sessions.pop(session_id, None) + if sess: + await sess.disconnect() \ No newline at end of file diff --git a/openspace/grounding/backends/shell/session.py b/openspace/grounding/backends/shell/session.py new file mode 100644 index 0000000000000000000000000000000000000000..7dad2e2cbbe600643f2bcb92eb7da97fa05108f1 --- /dev/null +++ b/openspace/grounding/backends/shell/session.py @@ -0,0 +1,710 @@ +import json +import re +from typing import Any, Tuple, Union + +from openspace.grounding.core.types import BackendType, ToolResult, ToolStatus +from openspace.grounding.core.session import BaseSession +from openspace.grounding.backends.shell.transport.connector import ShellConnector +from openspace.grounding.backends.shell.transport.local_connector import LocalShellConnector +from openspace.grounding.core.tool import BaseTool +from openspace.grounding.core.security.policies import SecurityPolicyManager +from openspace.llm import LLMClient +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +def _parse_shell_result(result: Any) -> Tuple[str, str, int]: + """Parse a connector result dict into ``(stdout, stderr, returncode)``.""" + if isinstance(result, dict): + stdout = ( + result.get("content") + or result.get("output") + or result.get("stdout") + or "" + ) + stderr = result.get("error") or result.get("stderr") or "" + rc = result.get("returncode", 0) + return stdout, stderr, rc + return str(result), "", 0 + + +class ShellSession(BaseSession): + backend_type = BackendType.SHELL + + def __init__( + self, + connector: Union[ShellConnector, LocalShellConnector], + *, + session_id: str, + security_manager: SecurityPolicyManager = None, + default_working_dir: str = None, + default_env: dict = None, + default_conda_env: str = None, + model: str = None, + use_clawwork_productivity: bool = False, + productivity_date: str = "default", + ): + super().__init__(connector=connector, session_id=session_id, + backend_type=BackendType.SHELL) + self.security_manager = security_manager + self.default_working_dir = default_working_dir + self.default_env = default_env or {} + self.default_conda_env = default_conda_env + self.model = model + self.use_clawwork_productivity = use_clawwork_productivity + self.productivity_date = productivity_date or "default" + + async def initialize(self): + self.tools = [ + ShellAgentTool( + self, + security_manager=self.security_manager, + default_working_dir=self.default_working_dir, + default_env=self.default_env, + default_conda_env=self.default_conda_env, + model=self.model, + ), + # ReadFileTool(self), # Disabled: replaced by productivity read_file when use_clawwork_productivity is True + WriteFileTool(self), + ListDirTool(self), + RunShellTool(self), + ] + if not self.use_clawwork_productivity: + self.tools.insert(1, ReadFileTool(self)) + if self.use_clawwork_productivity: + from openspace.grounding.backends.shell.productivity_tools import get_productivity_tools + extra = get_productivity_tools( + self, + data_path=self.default_working_dir, + current_date=self.productivity_date, + ) + if extra: + self.tools.extend(extra) + logger.info("ClawWork productivity tools enabled: %s", [t.name for t in extra]) + else: + logger.warning("use_clawwork_productivity is True but livebench not available; productivity tools not added.") + return {"tools": [t.name for t in self.tools]} + +class PythonScriptTool(BaseTool): + _name = "_python_exec" + _description = "Internal helper: run python code." + + def __init__(self, session: "ShellSession", default_working_dir: str = None, default_env: dict = None, default_conda_env: str = None): + self._session = session + self._default_working_dir = default_working_dir + self._default_env = default_env or {} + self._default_conda_env = default_conda_env + super().__init__() + + async def _arun(self, code: str, timeout: int = 90, working_dir: str | None = None, env: dict | None = None, conda_env: str | None = None): + # Use provided params, or fall back to session defaults + effective_working_dir = working_dir or self._default_working_dir + effective_env = {**self._default_env, **(env or {})} # Merge default and provided env + effective_conda_env = conda_env or self._default_conda_env + return await self._session.connector.run_python_script( + code, + timeout=timeout, + working_dir=effective_working_dir, + env=effective_env if effective_env else None, + conda_env=effective_conda_env + ) + +class BashScriptTool(BaseTool): + _name = "_bash_exec" + _description = "Internal helper: run bash script." + + def __init__(self, session: "ShellSession", default_working_dir: str = None, default_env: dict = None, default_conda_env: str = None): + self._session = session + self._default_working_dir = default_working_dir + self._default_env = default_env or {} + self._default_conda_env = default_conda_env + super().__init__() + + async def _arun(self, script: str, timeout: int = 30, working_dir: str | None = None, env: dict | None = None, conda_env: str | None = None): + # Use provided params, or fall back to session defaults + effective_working_dir = working_dir or self._default_working_dir + effective_env = {**self._default_env, **(env or {})} # Merge default and provided env + effective_conda_env = conda_env or self._default_conda_env + return await self._session.connector.run_bash_script( + script, + timeout=timeout, + working_dir=effective_working_dir, + env=effective_env if effective_env else None, + conda_env=effective_conda_env + ) + + +class ReadFileTool(BaseTool): + """Read file contents via the Shell connector. + + Works with both local (subprocess) and remote (HTTP) connectors. + Lightweight alternative to ``shell_agent`` for simple file reads. + """ + + _name = "read_file" + _description = ( + "Read the full text content of a file at the given path. " + "Use this to inspect skill resources, configuration files, scripts, etc." + ) + backend_type = BackendType.SHELL + + def __init__(self, session: "ShellSession"): + self._session = session + super().__init__() + + async def _arun(self, path: str) -> ToolResult: + try: + result = await self._session.connector.run_bash_script( + f'cat -- "{path}"', + timeout=15, + ) + stdout, stderr, rc = _parse_shell_result(result) + if rc != 0: + return ToolResult( + status=ToolStatus.ERROR, + content=stderr or f"Cannot read file: {path}", + ) + return ToolResult(status=ToolStatus.SUCCESS, content=stdout) + except Exception as e: + return ToolResult( + status=ToolStatus.ERROR, + content=f"read_file failed: {e}", + ) + + +class WriteFileTool(BaseTool): + """Write text content to a file via the Shell connector. + + Creates parent directories automatically. Overwrites existing files. + Uses Python internally to avoid shell-escaping issues. + """ + + _name = "write_file" + _description = ( + "Write text content to a file at the given path. " + "Creates the file and parent directories if they do not exist; " + "overwrites the file if it already exists." + ) + backend_type = BackendType.SHELL + + def __init__(self, session: "ShellSession"): + self._session = session + super().__init__() + + async def _arun(self, path: str, content: str) -> ToolResult: + # Use Python to avoid shell-escaping pitfalls. + escaped_path = json.dumps(path) + escaped_content = json.dumps(content) + code = ( + "import os\n" + f"path = {escaped_path}\n" + f"content = {escaped_content}\n" + "parent = os.path.dirname(os.path.abspath(path))\n" + "if parent:\n" + " os.makedirs(parent, exist_ok=True)\n" + "with open(path, 'w') as f:\n" + " f.write(content)\n" + "print(f'Written {len(content)} chars to {path}')\n" + ) + try: + result = await self._session.connector.run_python_script( + code, timeout=15, + ) + stdout, stderr, rc = _parse_shell_result(result) + if rc != 0: + return ToolResult( + status=ToolStatus.ERROR, + content=stderr or f"Cannot write file: {path}", + ) + return ToolResult(status=ToolStatus.SUCCESS, content=stdout) + except Exception as e: + return ToolResult( + status=ToolStatus.ERROR, + content=f"write_file failed: {e}", + ) + + +class ListDirTool(BaseTool): + """List directory contents via the Shell connector.""" + + _name = "list_dir" + _description = ( + "List the contents of a directory. " + "Returns file names, sizes, and modification dates. " + "Defaults to the current directory if no path is given." + ) + backend_type = BackendType.SHELL + + def __init__(self, session: "ShellSession"): + self._session = session + super().__init__() + + async def _arun(self, path: str = ".") -> ToolResult: + try: + result = await self._session.connector.run_bash_script( + f'ls -la "{path}"', + timeout=15, + ) + stdout, stderr, rc = _parse_shell_result(result) + if rc != 0: + return ToolResult( + status=ToolStatus.ERROR, + content=stderr or f"Cannot list directory: {path}", + ) + return ToolResult(status=ToolStatus.SUCCESS, content=stdout) + except Exception as e: + return ToolResult( + status=ToolStatus.ERROR, + content=f"list_dir failed: {e}", + ) + + +class RunShellTool(BaseTool): + """Run a shell command directly and return stdout/stderr. + + Lightweight alternative to ``shell_agent`` for one-off commands like + ``grep``, ``cat``, ``ls``, ``curl``, etc. Unlike ``shell_agent``, this + does NOT involve an inner LLM agent — the command is executed as-is via + the connector and the raw output is returned. + + Works with both local (subprocess) and remote (HTTP) connectors. + """ + + _name = "run_shell" + _description = ( + "Execute a shell command as-is and return raw stdout/stderr. " + "You provide the exact command (or script); it is run without " + "interpretation, modification, or automatic retry. " + "If the task requires the tool itself to reason, write code, " + "or recover from errors autonomously, use shell_agent instead." + ) + backend_type = BackendType.SHELL + + def __init__(self, session: "ShellSession"): + self._session = session + super().__init__() + + async def _arun(self, command: str, timeout: int = 30) -> ToolResult: + timeout = min(timeout, 120) + try: + result = await self._session.connector.run_bash_script( + command, + timeout=timeout, + ) + stdout, stderr, rc = _parse_shell_result(result) + output = stdout + if stderr: + output += f"\n[STDERR]\n{stderr}" if output else stderr + if rc != 0 and not output: + output = f"Command exited with code {rc}" + return ToolResult( + status=ToolStatus.SUCCESS if rc == 0 else ToolStatus.ERROR, + content=output or "(no output)", + ) + except Exception as e: + return ToolResult( + status=ToolStatus.ERROR, + content=f"run_shell failed: {e}", + ) + + +class ShellAgentTool(BaseTool): + _name = "shell_agent" + _description = """Delegate a task to an intelligent shell agent that autonomously writes and executes code, and will automatically retry and fix errors when possible. +Give it a natural-language task description. The internal agent will: +- Decide whether to use Python or Bash +- Write and execute code, inspect output, and iterate +- Automatically retry and fix errors (up to several rounds) + +Use this when you want the tool itself to figure out how to accomplish a goal. +If you already have the exact command/script to run, use run_shell instead.""" + + backend_type = BackendType.SHELL + _CODE_RGX = re.compile( + r"```(?Ppython|py|bash|shell|sh)[^\n]*\n(?P.*?)```", + re.S | re.I, + ) + + def __init__( + self, + session: "ShellSession", + client_password: str = "", + max_steps: int = 5, + security_manager: SecurityPolicyManager = None, + default_working_dir: str = None, + default_env: dict = None, + default_conda_env: str = None, + model: str = None + ): + import os + self._session = session + # Use explicit model > OPENSPACE_MODEL env var > LLMClient default + resolved_model = model or os.environ.get("OPENSPACE_MODEL") or None + if resolved_model: + self._llm = LLMClient(model=resolved_model) + else: + self._llm = LLMClient() + self.client_password = client_password + self.max_steps = max_steps + self._system_info = None + self.security_manager = security_manager + self._default_working_dir = default_working_dir + self._default_env = default_env or {} + self._default_conda_env = default_conda_env + self._py_tool = PythonScriptTool(session, default_working_dir=default_working_dir, default_env=default_env, default_conda_env=default_conda_env) + self._bash_tool = BashScriptTool(session, default_working_dir=default_working_dir, default_env=default_env, default_conda_env=default_conda_env) + super().__init__() + + async def _get_system_info(self): + """ + Get system information for shell agent. + + First tries to get comprehensive info from local server's /platform endpoint. + Falls back to simple bash commands if that fails. + + Returns: + Dict with at least 'platform' and 'username' keys + """ + if self._system_info is None: + try: + # Try to get system info from server via HTTP API + # Only attempt HTTP when connector provides a valid base_url + # (LocalShellConnector sets base_url=None to signal local mode) + base_url = getattr(self._session.connector, "base_url", None) + + if base_url is not None: + try: + from openspace.platforms import SystemInfoClient + + async with SystemInfoClient(base_url=base_url, timeout=5) as client: + info = await client.get_system_info(use_cache=False) + + if info: + # Use comprehensive info from server + self._system_info = { + "platform": info.get("system", "Linux"), + "username": info.get("username", "user"), + "machine": info.get("machine"), + "release": info.get("release"), + "full_info": info # Keep full info for reference + } + logger.debug(f"Got system info from server: {info.get('system')}") + return self._system_info + + except ImportError: + logger.debug("SystemInfoClient not available, using bash commands") + else: + logger.debug("No server base_url (local mode), skipping HTTP system info") + + # Fallback: use simple bash commands (original method) + platform_result = await self._session.connector.run_bash_script("uname -s", timeout=5) + username_result = await self._session.connector.run_bash_script("whoami", timeout=5) + + platform = self._extract_output(platform_result).strip() + username = self._extract_output(username_result).strip() + + self._system_info = { + "platform": platform, + "username": username + } + logger.debug(f"Got system info from bash: {platform}") + + except Exception as e: + logger.warning(f"Failed to get system info: {e}, using defaults") + self._system_info = {"platform": "Linux", "username": "user"} + + return self._system_info + + async def _arun(self, task: str, timeout: int = 300): + from openspace.grounding.core.types import ToolResult, ToolStatus + + sys_info = await self._get_system_info() + conversation_history = [] + iteration = 0 + last_error = None + + # record the code history + code_history = [] + + # Build environment context + env_context = [] + if self._default_working_dir: + env_context.append(f"Working Directory: {self._default_working_dir}") + if self._default_conda_env: + env_context.append(f"Conda Environment: {self._default_conda_env}") + if self._default_env: + env_vars = ", ".join([f"{k}={v}" for k, v in list(self._default_env.items())[:3]]) + if len(self._default_env) > 3: + env_vars += f", ... (+{len(self._default_env)-3} more)" + env_context.append(f"Custom Environment Variables: {env_vars}") + + env_section = "\n".join([f"# {ctx}" for ctx in env_context]) if env_context else "" + + SHELL_AGENT_SYSTEM_PROMPT = f"""You are an expert system administrator and programmer focused on executing tasks efficiently. + +# System: {sys_info["platform"]}, User: {sys_info["username"]} +{env_section} + +# Your task: {task} + +# IMPORTANT: You MUST provide exactly ONE code block in EVERY response +# Either ```bash or ```python - never respond without code + +# Available actions: +1. Execute bash commands: ```bash ``` +2. Write Python code: ```python ``` + +# Rules: +- ALWAYS include a code block in your response +- Write EXACTLY ONE code block per response +- If you need to understand the current environment, start with bash commands like: pwd, ls, ps, df, etc. +- If you get errors, analyze and fix them in the next iteration +- For sudo: use 'echo {self.client_password} | sudo -S ' +- The environment (working directory, conda env) is managed automatically + +# CRITICAL: Avoid quote escaping errors in bash: +- For complex string operations (JSON, multi-line text, special chars): ALWAYS use Python with heredoc +- Good: ```python ``` +- Bad: bash commands with nested quotes like: echo "$(cat 'file' | grep "pattern")" +- When reading/writing files with complex content: prefer Python over bash +- When processing JSON: ALWAYS use Python's json module, never bash string manipulation + +# Before executing, check if task output already exists: +- Use 'ls -la ' to check for existing files +- If files exist, read and verify them first before recreating +- Avoid redundant work - reuse existing valid outputs + +# Task completion marking: +When you believe the task is COMPLETED, end your response with: +[TASK_COMPLETED: brief explanation of what was accomplished] + +When you encounter an UNRECOVERABLE error that you cannot fix, end your response with: +[TASK_FAILED: brief explanation of why it cannot be completed]""" + + conversation_history.append({"role": "system", "content": SHELL_AGENT_SYSTEM_PROMPT}) + + no_code_counter = 0 + final_message = "" + + while iteration < self.max_steps: + iteration += 1 + + logger.info(f"[ShellAgent] Step {iteration}/{self.max_steps}: Processing task") + + try: + messages_text = LLMClient.format_messages_to_text(conversation_history) + response = await self._llm.complete(messages_text) + + assistant_content = response["message"]["content"] + logger.debug(f"[ShellAgent] Step {iteration} LLM response: {assistant_content[:200]}...") + + # extract and execute the code, and track the code block + code_info, execution_result = await self._execute_code_from_response(assistant_content) + if code_info: + code_history.append(code_info) + + logger.info(f"[ShellAgent] Step {iteration} execution result: {execution_result[:100]}...") + if execution_result == "ERROR: No valid code block found": + no_code_counter += 1 + if no_code_counter >= 3: + final_message = f"Task failed after {iteration} steps: LLM failed to provide code blocks repeatedly" + return ToolResult( + status=ToolStatus.ERROR, + content=final_message, + metadata={"tool": self._name, "code_history": code_history} + ) + else: + no_code_counter = 0 + + completion_status = self._check_task_status(assistant_content, execution_result, last_error) + + if completion_status["completed"]: + content_parts = [f"Task completed successfully after {iteration} steps"] + content_parts.append(f"\n{'='*60}") + content_parts.append(f"\nFinal Result:") + content_parts.append(execution_result) + + if len(code_history) > 1: + content_parts.append(f"\n{'='*60}") + content_parts.append(f"\nExecution Summary ({len(code_history)} steps):") + for i, code_info in enumerate(code_history, 1): + lang = code_info.get("language", "unknown") + output = code_info.get("output", "") + output_preview = output[:200].replace('\n', ' ') + if len(output) > 200: + output_preview += "..." + content_parts.append(f"\n Step {i} [{lang}]: {output_preview}") + + content_parts.append(f"\n{'='*60}") + content_parts.append(f"\nSummary: {completion_status['reason']}") + + final_message = "\n".join(content_parts) + return ToolResult( + status=ToolStatus.SUCCESS, + content=final_message, + metadata={"tool": self._name, "code_history": code_history} + ) + elif completion_status["failed"]: + final_message = f"Task failed after {iteration} steps: {completion_status['reason']}\nLast result: {execution_result}" + return ToolResult( + status=ToolStatus.ERROR, + content=final_message, + metadata={"tool": self._name, "code_history": code_history} + ) + + feedback = self._generate_feedback(execution_result, iteration, last_error) + + conversation_history.extend([ + {"role": "assistant", "content": assistant_content}, + {"role": "user", "content": feedback} + ]) + + last_error = execution_result if "ERROR" in execution_result else None + + except Exception as e: + final_message = f"Tool execution failed at step {iteration}: {str(e)}" + return ToolResult( + status=ToolStatus.ERROR, + content=final_message, + metadata={"tool": self._name, "code_history": code_history} + ) + + final_message = f"Reached maximum steps ({self.max_steps}). Task may be too complex or impossible." + return ToolResult( + status=ToolStatus.ERROR, + content=final_message, + metadata={"tool": self._name, "code_history": code_history} + ) + + async def _execute_code_from_response(self, response: str): + """ + execute the code and track the code block + + Returns: + Tuple[Optional[Dict], str]: (code_info, execution_result) + - code_info: {"lang": "python/bash", "code": "...", "status": "success/error"} + - execution_result: the execution result string + """ + matches = list(self._CODE_RGX.finditer(response)) + if not matches: + return None, "ERROR: No valid code block found" + + lang, code = matches[0]["lang"].lower(), matches[0]["code"].strip() + + # standardize the language name + lang_normalized = "python" if lang in ["python", "py"] else "bash" + + code_info = { + "lang": lang_normalized, + "code": code, + } + + # Security check is only done at the Connector layer to avoid duplicate prompts + + try: + if lang in ["python", "py"]: + helper = self._py_tool + result = await helper._arun(code) + elif lang in ["bash", "shell", "sh"]: + helper = self._bash_tool + result = await helper._arun(code) + else: + execution_result = f"ERROR: Unsupported language: {lang}" + code_info["status"] = "error" + return code_info, execution_result + + execution_result = self._extract_output(result) + code_info["status"] = "success" if "ERROR" not in execution_result else "error" + return code_info, execution_result + + except Exception as e: + execution_result = f"EXECUTION ERROR: {str(e)}" + code_info["status"] = "error" + return code_info, execution_result + + def _generate_feedback(self, result: str, iteration: int, last_error: str) -> str: + feedback = f"Step {iteration} result:\n{result}\n\n" + + if "ERROR" in result: + if last_error and last_error == result: + feedback += "Same error as previous step. Try a different approach.\n" + else: + feedback += "Error occurred. Analyze the error and fix it.\n" + else: + feedback += "Execution successful. Continue to next step if needed.\n" + + feedback += "\nWhat's your next action? (Remember: provide exactly ONE code block)" + return feedback + + def _extract_output(self, result): + if isinstance(result, dict): + # Check for execution errors + stderr = result.get("error") or result.get("stderr") or "" + returncode = result.get("returncode", 0) + stdout = result.get("content") or result.get("output") or result.get("stdout") or "" + + # If there's a non-zero return code or stderr with actual errors, report it + if returncode != 0 or (stderr and len(stderr.strip()) > 0): + error_msg = f"EXECUTION ERROR (exit code {returncode}):\n" + if stderr: + error_msg += f"stderr: {stderr}\n" + if stdout: + error_msg += f"stdout: {stdout}" + return error_msg + + return stdout or str(result) + return str(result) + + # Patterns that indicate actual execution failure in the result string. + _EXEC_ERROR_PATTERNS = [ + "EXECUTION ERROR", + "ERROR:", + "timed out", + "CommandNotFoundError", + "Traceback (most recent call last)", + "Exception:", + "PermissionError", + "FileNotFoundError", + "SyntaxError:", + "ImportError:", + "ModuleNotFoundError", + "No such file or directory", + "command not found", + ] + + def _has_execution_error(self, execution_result: str) -> bool: + """Return True if *execution_result* contains any known error indicator.""" + return any(p in execution_result for p in self._EXEC_ERROR_PATTERNS) + + def _check_task_status(self, response: str, execution_result: str, last_error: str) -> dict: + # 1. Check for explicit LLM failure marker + if "[TASK_FAILED:" in response: + reason = response.split("[TASK_FAILED:")[1].split("]")[0].strip() + return {"completed": False, "failed": True, "reason": reason} + + # 2. Check execution result for errors + has_error = self._has_execution_error(execution_result) + + # 3. LLM says completed — but cross-check with actual result + if "[TASK_COMPLETED:" in response: + reason = response.split("[TASK_COMPLETED:")[1].split("]")[0].strip() + if has_error: + # LLM optimistically marked complete, but execution actually failed. + # Don't trust it — let the agent retry. + logger.warning( + f"[ShellAgent] LLM marked TASK_COMPLETED but execution has errors, " + f"ignoring completion marker. Reason: {reason}" + ) + if last_error and last_error == execution_result: + return {"completed": False, "failed": True, "reason": "Same error repeated - unable to resolve"} + return {"completed": False, "failed": False, "reason": "Execution error occurred (LLM completion marker ignored)"} + return {"completed": True, "failed": False, "reason": reason} + + # 4. No explicit markers — check execution errors + if has_error: + if last_error and last_error == execution_result: + return {"completed": False, "failed": True, "reason": "Same error repeated - unable to resolve"} + return {"completed": False, "failed": False, "reason": "Execution error occurred"} + + return {"completed": False, "failed": False, "reason": "Task in progress"} \ No newline at end of file diff --git a/openspace/grounding/backends/shell/transport/connector.py b/openspace/grounding/backends/shell/transport/connector.py new file mode 100644 index 0000000000000000000000000000000000000000..30a32b0df96f0f369b9642115e5055e0532e24ec --- /dev/null +++ b/openspace/grounding/backends/shell/transport/connector.py @@ -0,0 +1,199 @@ +import asyncio +from typing import Any, Optional, Dict + +from openspace.grounding.core.transport.connectors import AioHttpConnector +from openspace.grounding.core.security import SecurityPolicyManager +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +class ShellConnector(AioHttpConnector): + """ + Shell backend HTTP connector + Basic routes: + POST /run_python {"code": str} + POST /run_bash_script {"script": str, "timeout": int, "working_dir": str | None} + """ + + def __init__( + self, + vm_ip: str, + port: int = 5000, + *, + retry_times: int = 3, + retry_interval: float = 5, + security_manager: "SecurityPolicyManager | None" = None, + ) -> None: + base_url = f"http://{vm_ip}:{port}" + super().__init__(base_url) + self.retry_times = retry_times + self.retry_interval = retry_interval + self._security_manager = security_manager + + async def _retry_invoke( + self, + name: str, + payload: Dict[str, Any], + script_timeout: int, + *, + break_on_timeout: bool = False + ): + """ + Execute HTTP request and retry + + Args: + name: RPC method name + payload: Request payload + script_timeout: Script execution timeout + break_on_timeout: Whether to exit immediately on timeout (default False) + + Returns: + Server response result + + Raises: + Exception: Last exception thrown after all retries fail + """ + last_exc: Exception | None = None + # HTTP request timeout should be longer than script execution timeout, leaving buffer time + http_timeout = script_timeout + 60 + + for attempt in range(1, self.retry_times + 1): + try: + # Pass timeout parameter to server + result = await self.invoke(name, payload | {"timeout": script_timeout}) + logger.info("%s executed successfully (attempt %d/%d)", name, attempt, self.retry_times) + return result + except asyncio.TimeoutError as exc: + # Timeout exception usually does not need to be retried (script execution time too long) + if break_on_timeout: + logger.error("%s timed out after %d seconds, aborting retry", name, script_timeout) + raise RuntimeError( + f"Script execution timed out after {script_timeout} seconds" + ) from exc + last_exc = exc + if attempt == self.retry_times: + break + logger.warning( + "%s timed out (attempt %d/%d), retrying in %.1f seconds...", + name, attempt, self.retry_times, self.retry_interval + ) + await asyncio.sleep(self.retry_interval) + except Exception as exc: + last_exc = exc + if attempt == self.retry_times: + break + logger.warning( + "%s failed (attempt %d/%d): %s, retrying in %.1f seconds...", + name, attempt, self.retry_times, exc, self.retry_interval + ) + await asyncio.sleep(self.retry_interval) + + error_msg = f"{name} failed after {self.retry_times} retries" + logger.error(error_msg) + raise last_exc or RuntimeError(error_msg) + + async def run_python_script( + self, + code: str, + *, + timeout: int = 90, + working_dir: Optional[str] = None, + env: Optional[Dict[str, str]] = None, + conda_env: Optional[str] = None + ) -> Any: + """ + Execute Python script on remote server + + Args: + code: Python code string + timeout: Execution timeout in seconds (default 90 seconds) + working_dir: Working directory for script execution (optional) + env: Environment variables for script execution (optional) + conda_env: Conda environment name to activate (optional) + + Returns: + Server response result + + Raises: + PermissionError: Security policy blocked execution + RuntimeError: Execution failed or timed out + """ + if self._security_manager: + from openspace.grounding.core.types import BackendType + allowed = await self._security_manager.check_command_allowed(BackendType.SHELL, code) + if not allowed: + logger.error("SecurityPolicy blocked python code execution") + raise PermissionError("SecurityPolicy: python code execution blocked") + + payload = {"code": code, "working_dir": working_dir, "env": env, "conda_env": conda_env} + logger.info( + "Executing python script with timeout=%d seconds%s%s%s", + timeout, + f", working_dir={working_dir}" if working_dir else "", + f", env={list(env.keys())}" if env else "", + f", conda_env={conda_env}" if conda_env else "" + ) + # Python script timed out, exit immediately without retry (timeout usually means script logic problem) + return await self._retry_invoke( + "POST /run_python", + payload, + timeout, + break_on_timeout=True + ) + + async def run_bash_script( + self, + script: str, + *, + timeout: int = 90, + working_dir: Optional[str] = None, + env: Optional[Dict[str, str]] = None, + conda_env: Optional[str] = None + ) -> Any: + """ + Execute Bash script on remote server + + Args: + script: Bash script content (can be multi-line) + timeout: Execution timeout in seconds (default 90 seconds) + working_dir: Working directory for script execution (optional) + env: Environment variables for script execution (optional) + conda_env: Conda environment name to activate (optional) + + Returns: + Server response result, containing status, output, error, returncode, etc. + + Raises: + PermissionError: Security policy blocked execution + RuntimeError: Execution failed or timed out + """ + if self._security_manager: + from openspace.grounding.core.types import BackendType + allowed = await self._security_manager.check_command_allowed(BackendType.SHELL, script) + if not allowed: + logger.error("SecurityPolicy blocked bash script execution") + raise PermissionError("SecurityPolicy: bash script execution blocked") + + payload = {"script": script, "working_dir": working_dir, "env": env, "conda_env": conda_env} + logger.info( + "Executing bash script with timeout=%d seconds%s%s%s", + timeout, + f", working_dir={working_dir}" if working_dir else "", + f", env={list(env.keys())}" if env else "", + f", conda_env={conda_env}" if conda_env else "" + ) + + # Bash script timed out, exit immediately without retry (timeout usually means script logic problem) + result = await self._retry_invoke( + "POST /run_bash_script", + payload, + timeout, + break_on_timeout=True + ) + + # Record execution result + if isinstance(result, dict) and "returncode" in result: + logger.info("Bash script executed with return code: %d", result.get("returncode", -1)) + + return result \ No newline at end of file diff --git a/openspace/grounding/backends/shell/transport/local_connector.py b/openspace/grounding/backends/shell/transport/local_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..2db21b052c3d1afc7470779ea8bf8bfd85b90c35 --- /dev/null +++ b/openspace/grounding/backends/shell/transport/local_connector.py @@ -0,0 +1,411 @@ +""" +Local Shell Connector — execute Python / Bash scripts directly via subprocess. + +This connector has the **same public API** as ShellConnector (HTTP version) +but runs everything in-process, removing the need for a local_server. + +Return format is kept identical so that ShellSession / ShellAgentTool +work without any changes. +""" + +import asyncio +import os +import platform +import tempfile +import uuid +from typing import Any, Optional, Dict + +from openspace.grounding.core.transport.connectors.base import BaseConnector +from openspace.grounding.core.transport.task_managers.noop import NoOpConnectionManager +from openspace.grounding.core.security import SecurityPolicyManager +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +platform_name = platform.system() + + +# --------------------------------------------------------------------------- +# Conda helpers (mirrored from local_server/main.py) +# --------------------------------------------------------------------------- + +def _get_conda_activation_prefix(conda_env: str | None) -> str: + """Generate platform-specific conda activation prefix.""" + if not conda_env: + return "" + if platform_name == "Windows": + conda_paths = [ + os.path.expandvars(r"%USERPROFILE%\miniconda3\Scripts\activate.bat"), + os.path.expandvars(r"%USERPROFILE%\anaconda3\Scripts\activate.bat"), + r"C:\ProgramData\Miniconda3\Scripts\activate.bat", + r"C:\ProgramData\Anaconda3\Scripts\activate.bat", + ] + for p in conda_paths: + if os.path.exists(p): + return f'call "{p}" {conda_env} && ' + return f"conda activate {conda_env} && " + else: + conda_paths = [ + os.path.expanduser("~/miniconda3/etc/profile.d/conda.sh"), + os.path.expanduser("~/anaconda3/etc/profile.d/conda.sh"), + "/opt/conda/etc/profile.d/conda.sh", + "/usr/local/miniconda3/etc/profile.d/conda.sh", + "/usr/local/anaconda3/etc/profile.d/conda.sh", + ] + for p in conda_paths: + if os.path.exists(p): + return f'source "{p}" && conda activate {conda_env} && ' + return f"conda activate {conda_env} && " + + +def _wrap_script_with_conda(script: str, conda_env: str | None) -> str: + """Wrap bash script with conda activation if needed.""" + if not conda_env: + return script + if platform_name == "Windows": + prefix = _get_conda_activation_prefix(conda_env) + return f"{prefix}{script}" + else: + conda_paths = [ + os.path.expanduser("~/miniconda3/etc/profile.d/conda.sh"), + os.path.expanduser("~/anaconda3/etc/profile.d/conda.sh"), + os.path.expanduser("~/opt/anaconda3/etc/profile.d/conda.sh"), + "/opt/conda/etc/profile.d/conda.sh", + ] + conda_sh = None + for p in conda_paths: + if os.path.exists(p): + conda_sh = p + break + if conda_sh: + return ( + f'#!/bin/bash\n' + f'if [ -f "{conda_sh}" ]; then\n' + f' . "{conda_sh}"\n' + f' conda activate {conda_env} 2>/dev/null || true\n' + f'fi\n\n' + f'{script}\n' + ) + else: + logger.warning( + "Conda environment '%s' requested but conda not found. " + "Executing with system Python.", conda_env + ) + return script + + +class LocalShellConnector(BaseConnector[Any]): + """ + Shell connector that runs scripts **locally** using asyncio subprocesses, + bypassing the Flask local_server entirely. + + Public API is compatible with ``ShellConnector`` so that ``ShellSession`` + works without modification. + """ + + def __init__( + self, + *, + retry_times: int = 3, + retry_interval: float = 5, + security_manager: "SecurityPolicyManager | None" = None, + ) -> None: + super().__init__(NoOpConnectionManager()) + self.retry_times = retry_times + self.retry_interval = retry_interval + self._security_manager = security_manager + # Provide base_url = None so ShellSession._get_system_info falls back + # to bash-based detection instead of HTTP. + self.base_url: str | None = None + + # ------------------------------------------------------------------ + # connect / disconnect (mostly no-ops for local execution) + # ------------------------------------------------------------------ + + async def connect(self) -> None: + """No real connection to establish for local mode.""" + if self._connected: + return + await super().connect() + logger.info("LocalShellConnector: ready (local mode, no server required)") + + # ------------------------------------------------------------------ + # Core execution helpers + # ------------------------------------------------------------------ + + async def _run_subprocess( + self, + cmd: list[str], + *, + timeout: int = 90, + working_dir: str | None = None, + env: dict[str, str] | None = None, + ) -> Dict[str, Any]: + """Run a command via asyncio subprocess and return a result dict + matching the format returned by the local_server endpoints.""" + exec_env = os.environ.copy() + if env: + exec_env.update(env) + + cwd = working_dir or os.getcwd() + + try: + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=exec_env, + ) + stdout_b, stderr_b = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + stdout = stdout_b.decode("utf-8", errors="replace") if stdout_b else "" + stderr = stderr_b.decode("utf-8", errors="replace") if stderr_b else "" + returncode = proc.returncode or 0 + + return { + "status": "success" if returncode == 0 else "error", + "output": stdout, + "content": stdout or "Code executed successfully (no output)", + "error": stderr, + "returncode": returncode, + } + except asyncio.TimeoutError: + return { + "status": "error", + "output": f"Execution timed out after {timeout} seconds", + "content": f"Execution timed out after {timeout} seconds", + "error": "", + "returncode": -1, + } + except Exception as e: + return { + "status": "error", + "output": "", + "content": "", + "error": str(e), + "returncode": -1, + } + + async def _run_shell_command( + self, + shell_cmd: str, + *, + timeout: int = 90, + working_dir: str | None = None, + env: dict[str, str] | None = None, + ) -> Dict[str, Any]: + """Run a shell command string (used for conda-wrapped scripts).""" + exec_env = os.environ.copy() + if env: + exec_env.update(env) + + cwd = working_dir or os.getcwd() + + try: + proc = await asyncio.create_subprocess_shell( + shell_cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + cwd=cwd, + env=exec_env, + ) + stdout_b, _ = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + stdout = stdout_b.decode("utf-8", errors="replace") if stdout_b else "" + returncode = proc.returncode or 0 + + return { + "status": "success" if returncode == 0 else "error", + "output": stdout, + "content": stdout or "Code executed successfully (no output)", + "error": "", + "returncode": returncode, + } + except asyncio.TimeoutError: + return { + "status": "error", + "output": f"Script execution timed out after {timeout} seconds", + "content": f"Script execution timed out after {timeout} seconds", + "error": "", + "returncode": -1, + } + except Exception as e: + return { + "status": "error", + "output": "", + "content": "", + "error": str(e), + "returncode": -1, + } + + # ------------------------------------------------------------------ + # Public API (same signatures as ShellConnector) + # ------------------------------------------------------------------ + + async def run_python_script( + self, + code: str, + *, + timeout: int = 90, + working_dir: Optional[str] = None, + env: Optional[Dict[str, str]] = None, + conda_env: Optional[str] = None, + ) -> Any: + """Execute a Python script locally. + + Return format matches the server's ``/run_python`` endpoint. + """ + # Security check + if self._security_manager: + from openspace.grounding.core.types import BackendType + allowed = await self._security_manager.check_command_allowed( + BackendType.SHELL, code + ) + if not allowed: + logger.error("SecurityPolicy blocked python code execution") + raise PermissionError("SecurityPolicy: python code execution blocked") + + # Write code to temp file (same as local_server) + suffix = uuid.uuid4().hex + if platform_name == "Windows": + temp_filename = os.path.join(tempfile.gettempdir(), f"python_exec_{suffix}.py") + else: + temp_filename = f"/tmp/python_exec_{suffix}.py" + + try: + with open(temp_filename, "w") as f: + f.write(code) + + logger.info( + "Executing python script locally with timeout=%d seconds%s%s%s", + timeout, + f", working_dir={working_dir}" if working_dir else "", + f", env={list(env.keys())}" if env else "", + f", conda_env={conda_env}" if conda_env else "", + ) + + if conda_env: + activation = _get_conda_activation_prefix(conda_env) + if activation: + python_cmd = "python" if platform_name == "Windows" else "python3" + full_cmd = f'{activation}{python_cmd} "{temp_filename}"' + result = await self._run_shell_command( + full_cmd, timeout=timeout, working_dir=working_dir, env=env + ) + else: + python_cmd = "python" if platform_name == "Windows" else "python3" + result = await self._run_subprocess( + [python_cmd, temp_filename], + timeout=timeout, + working_dir=working_dir, + env=env, + ) + else: + python_cmd = "python" if platform_name == "Windows" else "python3" + result = await self._run_subprocess( + [python_cmd, temp_filename], + timeout=timeout, + working_dir=working_dir, + env=env, + ) + + return result + + finally: + if os.path.exists(temp_filename): + os.remove(temp_filename) + + async def run_bash_script( + self, + script: str, + *, + timeout: int = 90, + working_dir: Optional[str] = None, + env: Optional[Dict[str, str]] = None, + conda_env: Optional[str] = None, + ) -> Any: + """Execute a Bash script locally. + + Return format matches the server's ``/run_bash_script`` endpoint. + """ + # Security check + if self._security_manager: + from openspace.grounding.core.types import BackendType + allowed = await self._security_manager.check_command_allowed( + BackendType.SHELL, script + ) + if not allowed: + logger.error("SecurityPolicy blocked bash script execution") + raise PermissionError("SecurityPolicy: bash script execution blocked") + + # Wrap with conda if needed + final_script = _wrap_script_with_conda(script, conda_env) + + # Write to temp file (same as local_server) + suffix = uuid.uuid4().hex + if platform_name == "Windows": + temp_filename = os.path.join(tempfile.gettempdir(), f"bash_exec_{suffix}.sh") + else: + temp_filename = f"/tmp/bash_exec_{suffix}.sh" + + try: + with open(temp_filename, "w") as f: + f.write(final_script) + os.chmod(temp_filename, 0o755) + + logger.info( + "Executing bash script locally with timeout=%d seconds%s%s%s", + timeout, + f", working_dir={working_dir}" if working_dir else "", + f", env={list(env.keys())}" if env else "", + f", conda_env={conda_env}" if conda_env else "", + ) + + shell_cmd = ["bash", temp_filename] if platform_name == "Windows" else ["/bin/bash", temp_filename] + result = await self._run_subprocess( + shell_cmd, + timeout=timeout, + working_dir=working_dir, + env=env, + ) + return result + + finally: + if os.path.exists(temp_filename): + os.unlink(temp_filename) + + # ------------------------------------------------------------------ + # BaseConnector abstract methods + # ------------------------------------------------------------------ + + async def invoke(self, name: str, params: dict[str, Any]) -> Any: + """Dispatch by name — same routing as ShellConnector via AioHttpConnector.""" + name_upper = name.strip().upper() + if "/RUN_PYTHON" in name_upper: + return await self.run_python_script( + params.get("code", ""), + timeout=params.get("timeout", 90), + working_dir=params.get("working_dir"), + env=params.get("env"), + conda_env=params.get("conda_env"), + ) + elif "/RUN_BASH_SCRIPT" in name_upper: + return await self.run_bash_script( + params.get("script", ""), + timeout=params.get("timeout", 90), + working_dir=params.get("working_dir"), + env=params.get("env"), + conda_env=params.get("conda_env"), + ) + else: + raise NotImplementedError(f"LocalShellConnector does not support: {name}") + + async def request(self, *args: Any, **kwargs: Any) -> Any: + """Not used in local mode.""" + raise NotImplementedError( + "LocalShellConnector does not support raw HTTP requests" + ) + diff --git a/openspace/grounding/backends/web/__init__.py b/openspace/grounding/backends/web/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4d2be876ae1c3b335806a8beb38414d4be6574 --- /dev/null +++ b/openspace/grounding/backends/web/__init__.py @@ -0,0 +1,7 @@ +from .provider import WebProvider +from .session import WebSession + +__all__ = [ + "WebProvider", + "WebSession" +] \ No newline at end of file diff --git a/openspace/grounding/backends/web/provider.py b/openspace/grounding/backends/web/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4ca54ce3d34db5f6d98f8d3a269c93ed2bc647 --- /dev/null +++ b/openspace/grounding/backends/web/provider.py @@ -0,0 +1,55 @@ +from typing import Dict, Any +from openspace.grounding.core.types import BackendType, SessionConfig +from openspace.grounding.core.provider import Provider +from .session import WebSession +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +class WebProvider(Provider[WebSession]): + + DEFAULT_SID = BackendType.WEB.value + + def __init__(self, config: Dict[str, Any] = None): + super().__init__(BackendType.WEB, config) + + async def initialize(self) -> None: + """Initialize Web Provider and create default session""" + if not self.is_initialized: + logger.info("Initializing Web provider (Knowledge Research)") + # Auto-create default session + await self.create_session(SessionConfig( + session_name=self.DEFAULT_SID, + backend_type=BackendType.WEB, + connection_params={} + )) + self.is_initialized = True + + async def create_session(self, session_config: SessionConfig) -> WebSession: + """Create Web session""" + session_name = session_config.session_name + + if session_name in self._sessions: + logger.warning(f"Session {session_name} already exists, returning existing session") + return self._sessions[session_name] + + # Create WebSession with auto-connect and auto-initialize enabled + session = WebSession( + session_id=session_name, + config=session_config, + auto_connect=True, + auto_initialize=True + ) + + self._sessions[session_name] = session + + logger.info(f"Created Web session (Knowledge Research): {session_name}") + return session + + async def close_session(self, session_name: str) -> None: + """Close Web session""" + session = self._sessions.pop(session_name, None) + if session: + await session.disconnect() + logger.info(f"Closed Web session: {session_name}") \ No newline at end of file diff --git a/openspace/grounding/backends/web/session.py b/openspace/grounding/backends/web/session.py new file mode 100644 index 0000000000000000000000000000000000000000..9a450c5970f30ea68edb298fb8e2203a0817b1f2 --- /dev/null +++ b/openspace/grounding/backends/web/session.py @@ -0,0 +1,243 @@ +import os +from pathlib import Path +from typing import Dict, Any, Optional +from openspace.grounding.core.session import BaseSession +from openspace.grounding.core.types import BackendType, SessionConfig +from openspace.grounding.core.tool import BaseTool +from openspace.grounding.core.transport.connectors import BaseConnector +from openspace.llm import LLMClient +from openspace.utils.logging import Logger +from dotenv import load_dotenv + +# Load .env from openspace package root (4 levels up), then CWD fallback. +_PKG_ENV = Path(__file__).resolve().parent.parent.parent.parent / ".env" # openspace/.env +if _PKG_ENV.is_file(): + load_dotenv(_PKG_ENV) +load_dotenv() +logger = Logger.get_logger(__name__) + + +try: + from openai import AsyncOpenAI + OPENAI_AVAILABLE = True +except ImportError: + OPENAI_AVAILABLE = False + + +class WebConnector(BaseConnector): + def __init__(self, api_key: str, base_url: str): + self.api_key = api_key + self.base_url = base_url + self.client: Optional[AsyncOpenAI] = None + self._connected = False + + async def connect(self) -> None: + if self._connected: + return + + if not OPENAI_AVAILABLE: + raise RuntimeError( + "OpenAI library not available. Install with: pip install openai" + ) + + if not self.api_key: + raise RuntimeError( + "API key not provided. Set OPENROUTER_API_KEY environment variable " + "or provide deep_research_api_key in config." + ) + + self.client = AsyncOpenAI( + base_url=self.base_url, + api_key=self.api_key + ) + self._connected = True + logger.info(f"Web connector connected to {self.base_url}") + + async def disconnect(self) -> None: + if not self._connected: + return + + self.client = None + self._connected = False + logger.info("Web connector disconnected") + + @property + def is_connected(self) -> bool: + return self._connected + + async def invoke(self, name: str, params: dict) -> Any: + if name == "chat_completion": + if not self.client: + raise RuntimeError("Client not connected") + return await self.client.chat.completions.create(**params) + raise NotImplementedError(f"Unknown method: {name}") + + async def request(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError("Web backend uses invoke() instead of request()") + + +class WebSession(BaseSession): + + backend_type = BackendType.WEB + + def __init__( + self, + *, + session_id: str, + config: SessionConfig, + deep_research_api_key: Optional[str] = None, + deep_research_base_url: str = "https://openrouter.ai/api/v1", + auto_connect: bool = True, + auto_initialize: bool = True + ): + api_key = deep_research_api_key or os.getenv("OPENROUTER_API_KEY") + connector = WebConnector( + api_key=api_key or "", # Empty string will raise an error when connect + base_url=deep_research_base_url + ) + + super().__init__( + connector=connector, + session_id=session_id, + backend_type=BackendType.WEB, + auto_connect=auto_connect, + auto_initialize=auto_initialize + ) + self.config = config + + @property + def web_connector(self) -> WebConnector: + return self.connector + + async def initialize(self) -> Dict[str, Any]: + """Connect to WebConnector and register tools. + + BaseSession in __aenter__ will call connect() according to auto_connect, + but in provider.create_session directly instantiating Session will not trigger this logic. + Therefore, we need to explicitly ensure that the connection is established, avoiding AttributeError + when DeepResearchTool is called and `self.web_connector.client` is still None. + """ + + # If the connection is not established, connect explicitly + if not self.is_connected: + try: + await self.connect() + except Exception as e: + logger.error(f"Failed to connect WebSession {self.session_id}: {e}") + raise + + if self.tools: + logger.debug(f"Web session {self.session_id} already initialized, skipping") + return { + "tools": [t.name for t in self.tools], + "backend": BackendType.WEB.value + } + + self.tools = [DeepResearchTool(session=self)] + + logger.info(f"Initialized Web session {self.session_id} with AI Deep Research tool") + + return { + "tools": [t.name for t in self.tools], + "backend": BackendType.WEB.value + } + + +class DeepResearchTool(BaseTool): + + backend_type = BackendType.WEB + _name = "deep_research_agent" + _description = """Knowledge Research Tool - Primary tool for acquiring external knowledge + +PURPOSE: +Acquires comprehensive knowledge from the web through deep research and analysis. +Powered by Perplexity AI's sonar-deep-research model, then post-processed to extract +actionable insights and concise summaries. The main tool for gathering information +beyond existing knowledge base. + +WHEN TO USE: +- Information needed on professional/technical topics +- Research on technical problems, concepts, or implementations +- Understanding of latest developments, trends, or news +- Comparison of different approaches, tools, or solutions +- Factual information, definitions, or explanations required +- Synthesis from multiple authoritative sources needed + +HOW IT WORKS: +1. Conducts deep web search using Perplexity's sonar-deep-research +2. Analyzes and synthesizes information from multiple sources +3. Post-processes to distill knowledge-dense summary retaining critical details +4. Returns comprehensive summary ready for immediate use + +RETURNS: +Knowledge-dense comprehensive summary (400-600 words) that: +- Retains important details and technical specifics +- Focuses on substantive knowledge without losing critical information +- Organized and structured for clarity +- Directly usable by agents for decision-making and task execution + +NOT DESIGNED FOR: +- Tasks requiring browser interaction or UI manipulation +- Direct file downloads or web scraping operations +- Real-time system operations or executions + +USAGE GUIDELINES: +- Frame clear, specific questions (e.g., "Explain the architecture of Transformer models") +- Specify context when needed (e.g., "Compare PostgreSQL vs MySQL for high-concurrency scenarios") +- Suitable for any knowledge or information acquisition needs +""" + + def __init__( + self, + session: WebSession + ): + super().__init__() + self._session = session + self._llm = LLMClient() + + async def _arun(self, query: str) -> str: + if not query: + return "ERROR: Missing required parameter: query" + + try: + # Step 1: Deep research + logger.info(f"Start deep research: {query}") + + completion = await self._session.web_connector.client.chat.completions.create( + model="perplexity/sonar-deep-research", + messages=[{"role": "user", "content": query}] + ) + + full_answer = completion.choices[0].message.content + logger.info(f"Research completed, length: {len(full_answer)} characters") + + # Step 2: Use LLMClient to generate summary and distill key points + logger.info(f"Begin to distill key points...") + + SUMMARY_AGENT_PROMPT = f"""Please distill the following deep research results into a knowledge-dense summary. Requirements: + +Provide a comprehensive yet concise summary (400-600 words): +- Focus on SUBSTANTIVE knowledge and key information +- Retain important details, technical specifics, and concrete facts +- Do NOT sacrifice critical information for brevity +- Organize information clearly and logically with proper structure +- Remove only redundancy and verbose explanations +- Include actionable insights and decision-relevant information +- Make it directly usable for task execution and decision-making + +Output ONLY the summary text, no additional formatting or JSON structure needed. + +Deep Research Results: +{full_answer} +""" + + summary_response = await self._llm.complete(SUMMARY_AGENT_PROMPT) + summary = summary_response["message"]["content"].strip() + + logger.info(f"Summary generation completed") + + return summary + + except Exception as e: + logger.error(f"Deep research failed: {e}") + return f"ERROR: AI research failed: {e}" \ No newline at end of file diff --git a/openspace/grounding/core/exceptions.py b/openspace/grounding/core/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..3458c7dc360408ace4fb3c543ce34f00ac4f0c9c --- /dev/null +++ b/openspace/grounding/core/exceptions.py @@ -0,0 +1,67 @@ +""" +Unified exception & error-code definitions for the grounding framework +""" +from enum import Enum, auto +from typing import Any, Dict + + +class ErrorCode(str, Enum): + # generic + UNKNOWN = auto() + CONFIG_INVALID = auto() + + # provider / session / connector + PROVIDER_ERROR = auto() + SESSION_NOT_FOUND = auto() + + # connection + CONNECTION_FAILED = auto() + CONNECTION_TIMEOUT = auto() + + # tool + TOOL_NOT_FOUND = auto() + TOOL_EXECUTION_FAIL = auto() + AMBIGUOUS_TOOL = auto() + + +class GroundingError(Exception): + """ + Framework-wide base exception. + + Args: + message: Human readable error message. + code: One of the error codes defined above. + retryable: Whether the caller may retry the operation automatically. + context: Extra key-value pairs (e.g. tool_name, session_id) for logging / metrics. + """ + + __slots__ = ("message", "code", "retryable", "context") + + def __init__( + self, + message: str, + *, + code: ErrorCode = ErrorCode.UNKNOWN, + retryable: bool = False, + **context: Any, + ): + super().__init__(f"[{code}] {message}") + self.message: str = message + self.code: ErrorCode = code + self.retryable: bool = retryable + self.context: Dict[str, Any] = context + + def to_dict(self) -> Dict[str, Any]: + """Serialize error for structured logging / JSON response.""" + return { + "code": self.code.value, + "message": self.message, + "retryable": self.retryable, + "context": self.context, + } + + def __str__(self) -> str: + return f"[{self.code}] {self.message}" + + def __repr__(self) -> str: + return f"GroundingError(code={self.code}, msg={self.message!r})" \ No newline at end of file diff --git a/openspace/grounding/core/grounding_client.py b/openspace/grounding/core/grounding_client.py new file mode 100644 index 0000000000000000000000000000000000000000..4f97a7f3582e8790288709300c9210fad874eff7 --- /dev/null +++ b/openspace/grounding/core/grounding_client.py @@ -0,0 +1,899 @@ +import asyncio +import time +from collections import OrderedDict +from datetime import datetime +from typing import Any, Dict, List, Optional + +from .types import BackendType, SessionConfig, SessionInfo, SessionStatus, ToolResult +from .exceptions import ErrorCode, GroundingError +from .tool import BaseTool +from .provider import Provider, ProviderRegistry +from .session import BaseSession +from .search_tools import SearchCoordinator +from openspace.config import GroundingConfig, get_config +from openspace.config.utils import get_config_value +from openspace.utils.logging import Logger +import importlib + + +class GroundingClient: + """ + Global Entry, Facing Agent/Application, only concerned with Provider & Session + """ + def __init__(self, config: Optional[GroundingConfig] = None, recording_manager=None) -> None: + # Initialize logger first (needed by other initialization steps) + self._logger = Logger.get_logger(__name__) + + self._config: GroundingConfig = config or get_config() + self._registry: ProviderRegistry = ProviderRegistry() + + # Register providers from config + self._register_providers_from_config() + + # Session + self._sessions: Dict[str, BaseSession] = {} + self._session_info: Dict[str, SessionInfo] = {} + self._server_session_map: dict[tuple[BackendType, str], str] = {} # (backend, server) -> session_name + + # Tool cache + self._tool_cache: "OrderedDict[str, tuple[List[BaseTool], float]]" = OrderedDict() + self._tool_cache_ttl: int = get_config_value(self._config, "tool_cache_ttl", 300) + self._tool_cache_maxsize: int = get_config_value(self._config, "tool_cache_maxsize", 300) + + # Concurrent control + self._lock = asyncio.Lock() + self._cache_lock = asyncio.Lock() + + # Tool search coordinator + self._search_coordinator: Optional[SearchCoordinator] = None + + # Recording manager (optional, for GUI intermediate step recording) + self._recording_manager = recording_manager + + # Tool quality manager + self._quality_manager = self._init_quality_manager() + + # Register SystemProvider (requires GroundingClient instance, so must be done after __init__) + self._register_system_provider() + + def _register_providers_from_config(self) -> None: + """ + Based on GroundingConfig.enabled_backends, register Provider instances to + self._registry. Here only do *instantiation*, not await initialize(), + to avoid blocking the event loop in the import stage; Provider will be lazily initialized when it is first used. + + Note: SystemProvider is skipped here and registered separately in _register_system_provider() + because it requires a GroundingClient instance. + """ + if not self._config.enabled_backends: + self._logger.warning("No enabled_backends defined in config") + return + + for item in self._config.enabled_backends: + be_name: str | None = item.get("name") + cls_path: str | None = item.get("provider_cls") + if not (be_name and cls_path): + self._logger.warning("Invalid backend entry: %s", item) + continue + + backend = BackendType(be_name.lower()) + + # Skip system backend - it will be registered separately + if backend == BackendType.SYSTEM: + self._logger.debug("Skipping system backend in config registration (will be registered separately)") + continue + + if backend in self._registry.list(): + continue # Already registered + + # Dynamically import Provider class + try: + module_path, _, cls_name = cls_path.rpartition(".") + module = importlib.import_module(module_path) + prov_cls = getattr(module, cls_name) + except (ModuleNotFoundError, AttributeError) as e: + self._logger.error("Import provider failed: %s (%s)", cls_path, e) + continue + + backend_cfg = self._config.get_backend_config(be_name) + provider: Provider = prov_cls(backend_cfg) + self._registry.register(provider) + + def _register_system_provider(self) -> None: + """ + Register SystemProvider separately because it requires GroundingClient instance. + SystemProvider provides meta-level tools for querying system state (list providers, tools, etc.) + and is always available regardless of configuration. + """ + try: + from .system import SystemProvider + system_provider = SystemProvider(self) + self._registry.register(system_provider) + self._logger.debug("SystemProvider registered successfully") + except Exception as e: + self._logger.warning(f"Failed to register SystemProvider: {e}") + + def _init_quality_manager(self): + """Initialize tool quality manager based on config.""" + try: + # Check if quality tracking is enabled in config + quality_config = getattr(self._config, 'tool_quality', None) + if not quality_config or not getattr(quality_config, 'enabled', True): + self._logger.debug("Tool quality tracking disabled") + return None + + from .quality import ToolQualityManager, set_quality_manager + from pathlib import Path + from openspace.config.constants import PROJECT_ROOT + + # Shared DB path + db_path = getattr(quality_config, 'db_path', None) + if db_path: + db_path = Path(db_path) + else: + # Default: same location as SkillStore + db_dir = PROJECT_ROOT / ".openspace" + db_dir.mkdir(parents=True, exist_ok=True) + db_path = db_dir / "openspace.db" + + manager = ToolQualityManager( + db_path=db_path, + enable_persistence=getattr(quality_config, 'enable_persistence', True), + auto_save=True, + evolve_interval=getattr(quality_config, 'evolve_interval', 5), + ) + + # Set as global manager for BaseTool access + set_quality_manager(manager) + + self._logger.info( + f"ToolQualityManager initialized " + f"(records={len(manager._records)})" + ) + return manager + + except Exception as e: + self._logger.warning(f"Failed to initialize ToolQualityManager: {e}") + return None + + @property + def quality_manager(self): + """Get the tool quality manager.""" + return self._quality_manager + + # Quality API for Upper Layer + def get_quality_report(self) -> Dict[str, Any]: + """ + Get comprehensive tool quality report. + """ + if not self._quality_manager: + return {"status": "disabled", "message": "Quality tracking not enabled"} + return self._quality_manager.get_quality_report() + + async def evolve_quality(self) -> Dict[str, Any]: + """ + Run quality self-evolution cycle. + + This triggers: + - Tool change detection + - Description re-evaluation for updated tools + - Adaptive quality weight computation + + Call this periodically or after tool set changes. + """ + if not self._quality_manager: + return {"status": "disabled"} + + # Get all tools + all_tools = await self.list_tools() + return await self._quality_manager.evolve(all_tools) + + def get_tool_insights(self, tool: BaseTool) -> Dict[str, Any]: + """ + Get detailed quality insights for a specific tool. + """ + if not self._quality_manager: + return {"status": "disabled"} + return self._quality_manager.get_tool_insights(tool) + + def register_provider(self, provider: Provider) -> None: + self._registry.register(provider) + + def get_provider(self, backend: BackendType) -> Provider: + return self._registry.get(backend) + + def list_providers(self) -> Dict[BackendType, Provider]: + return self._registry.list() + + @property + def recording_manager(self): + """Get the recording manager.""" + return self._recording_manager + + @recording_manager.setter + def recording_manager(self, manager): + """ + Set or update the recording manager. + This allows coordinator to inject recording_manager after GroundingClient creation. + """ + self._recording_manager = manager + self._logger.info("GroundingClient: RecordingManager updated") + + async def initialize_all_providers(self) -> None: + await asyncio.gather(*[provider.initialize() for provider in self._registry.list().values() if not provider.is_initialized]) + + + async def create_session( + self, + *, + backend: BackendType, + name: str | None = None, + connection_params: Dict[str, Any] | None = None, + server: str | None = None, + **options, + ) -> str: + """ + Create and initialize Session, return "session_name" (external visible) + name is auto generated when it's None: - + MCP backend needs to provide server + """ + async with self._lock: + # Check concurrent sessions limit + max_sessions = get_config_value(self._config, "max_concurrent_sessions", 100) + if len(self._sessions) >= max_sessions: + raise GroundingError(f"Reached maximum session limit: {max_sessions}") + + # Session naming strategy + if server: # Only MCP will pass in server + name = name or f"{backend.value}-{server}" + else: + name = name or backend.value # Other backends have a fixed 1 session + + if name in self._sessions: + # Reuse existing session + self._logger.warning("Session '%s' exists, reusing.", name) + return name + + # Get Provider (initialize if first time) + provider = self._registry.get(backend) + if not provider.is_initialized: + await provider.initialize() + + if backend == BackendType.MCP: + if server is None: + raise GroundingError("Must specify 'server' when creating MCP session") + + # Construct SessionConfig, pass to Provider to create + connection_params = connection_params or {} + if server: + connection_params.setdefault("server", server) + + # Inject recording_manager for GUI backend (for intermediate step recording) + if backend == BackendType.GUI and self._recording_manager is not None: + connection_params.setdefault("recording_manager", self._recording_manager) + + sess_cfg = SessionConfig( + session_name=name, # Use external visible name + backend_type=backend, + connection_params=connection_params, + **options, + ) + session_obj = await provider.create_session(sess_cfg) + + # Store session and monitoring info + async with self._lock: + self._sessions[name] = session_obj + now = datetime.utcnow() + self._session_info[name] = SessionInfo( + session_name=name, + backend_type=backend, + status=SessionStatus.CONNECTED, + created_at=now, + last_activity=now, + ) + if server: + self._server_session_map[(backend, server)] = name + + self._logger.info("Session created: %s", name) + return name + + def list_sessions(self) -> List[str]: + return list(self._sessions.keys()) + + async def close_session(self, name: str) -> None: + async with self._lock: + session = self._sessions.pop(name, None) + info = self._session_info.pop(name, None) + self._tool_cache.pop(name, None) + + for k, v in list(self._server_session_map.items()): + if v == name: + self._server_session_map.pop(k) + + if not session: + self._logger.warning("Session '%s' not found", name) + return + + try: + provider = self._registry.get(info.backend_type) if info else None + if provider: + await provider.close_session(name) + else: + # Fallback: if no provider, disconnect directly + await session.disconnect() + finally: + self._logger.info("Session closed: %s", name) + + async def close_all_sessions(self) -> None: + for sid in list(self._sessions.keys()): + await self.close_session(sid) + + async def ensure_session(self, backend: BackendType, server: str | None = None) -> str: + sid = backend.value if server is None else f"{backend.value}-{server}" + if sid not in self._sessions: + await self.create_session(backend=backend, name=sid, server=server) + return sid + + def get_session_info(self, name: str) -> SessionInfo: + """Get session monitoring info""" + if name not in self._session_info: + raise ErrorCode.SESSION_NOT_FOUND(name) + return self._session_info[name] + + def get_session(self, name: str) -> BaseSession: + """Get session""" + if name not in self._sessions: + raise ErrorCode.SESSION_NOT_FOUND(name) + return self._sessions[name] + + + async def _fetch_tools( + self, + backend: BackendType, + *, + session_name: str | None = None, + use_cache: bool = False, + bind_runtime_info: bool = True, + ) -> List[BaseTool]: + """ + Fetch tools from provider. + + Args: + backend: Backend type + session_name: + - None: fetch all tools from all sessions of this backend + - str: fetch tools from specific session + use_cache: Whether to use cache + bind_runtime_info: Whether to bind runtime info to tool instances + """ + now = time.time() + + # Auto-generate cache_scope from parameters + if session_name: + cache_scope = session_name + else: + cache_scope = f"backend-{backend.value}" + + # Check cache + if use_cache: + async with self._cache_lock: + if cache_scope in self._tool_cache: + tools, ts = self._tool_cache[cache_scope] + if now - ts < self._tool_cache_ttl: + self._tool_cache.move_to_end(cache_scope) + return tools + + provider = self._registry.get(backend) + if not provider.is_initialized: + await provider.initialize() + + tools = await provider.list_tools(session_name=session_name) + + if bind_runtime_info: + # If session_name is specified, bind all tools to that session + if session_name: + server_name = None + if backend == BackendType.MCP: + server_name = session_name.replace(f"{backend.value}-", "", 1) + + for tool in tools: + tool.bind_runtime_info( + backend=backend, + session_name=session_name, + server_name=server_name, + grounding_client=self, + ) + else: + # No session_name specified - get tools from all sessions + # For each backend, find the default/primary session + # For Shell/Web/GUI: use the default session (backend.value) + # For MCP: tools should already be bound by the provider + default_session_name = None + + # Try to find an existing session for this backend + for sid, info in self._session_info.items(): + if info.backend_type == backend: + default_session_name = sid + break + + # Fallback: use backend default naming + if not default_session_name: + default_session_name = backend.value + + server_name = None + if backend == BackendType.MCP and default_session_name: + server_name = default_session_name.replace(f"{backend.value}-", "", 1) + + for tool in tools: + # Only bind if tool doesn't have runtime info already + # (some providers like MCP bind runtime info during list_tools) + if not tool.is_bound: + tool.bind_runtime_info( + backend=backend, + session_name=default_session_name, + server_name=server_name, + grounding_client=self, + ) + elif not tool.runtime_info.grounding_client: + # Tool has runtime info but no grounding_client, add it + tool.bind_runtime_info( + backend=tool.runtime_info.backend, + session_name=tool.runtime_info.session_name, + server_name=tool.runtime_info.server_name, + grounding_client=self, + ) + + # Save to cache + if use_cache: + async with self._cache_lock: + self._tool_cache[cache_scope] = (tools, now) + self._tool_cache.move_to_end(cache_scope) + while len(self._tool_cache) > self._tool_cache_maxsize: + self._tool_cache.popitem(last=False) + + return tools + + async def list_tools( + self, + backend: BackendType | list[BackendType] | None = None, + session_name: str | None = None, + *, + use_cache: bool = False, + ) -> List[BaseTool]: + """ + List tools from backend(s) or session. + + 1. session_name is provided → return tools from that session + 2. backend is list → return tools from multiple backends + 3. backend is single → return tools from that backend + 4. backend is None → return tools from all backends + + Args: + backend: Single backend, list of backends, or None for all + session_name: Specific session name (overrides backend parameter) + use_cache: Whether to use cache + + Returns: + List of tools + """ + # Session-level + if session_name: + if session_name not in self._sessions: + raise ErrorCode.SESSION_NOT_FOUND(session_name) + backend_type = self._session_info[session_name].backend_type + return await self._fetch_tools( + backend_type, + session_name=session_name, + use_cache=use_cache, + ) + + # Multiple backends + if isinstance(backend, list): + tools: List[BaseTool] = [] + for be in backend: + backend_tools = await self._fetch_tools( + be, + session_name=None, # Provider aggregates all sessions + use_cache=use_cache, + ) + tools.extend(backend_tools) + return tools + + # Single backend + if backend is not None: + return await self._fetch_tools( + backend, + session_name=None, + use_cache=use_cache, + ) + + # All backends + tools: List[BaseTool] = [] + for backend_type in self._registry.list().keys(): + backend_tools = await self._fetch_tools( + backend_type, + session_name=None, + use_cache=use_cache, + ) + tools.extend(backend_tools) + return tools + + async def list_backend_tools( + self, + backend: BackendType | list[BackendType] | None = None, + use_cache: bool = False + ) -> list[BaseTool]: + return await self.list_tools(backend=backend, session_name=None, use_cache=use_cache) + + async def list_session_tools( + self, + session_name: str, + use_cache: bool = False + ) -> list[BaseTool]: + if session_name not in self._session_info: + raise ErrorCode.SESSION_NOT_FOUND(session_name) + backend = self._session_info[session_name].backend_type + return await self.list_tools(backend, session_name, use_cache) + + async def list_all_backend_tools( + self, + use_cache: bool = False + ) -> Dict[BackendType, list[BaseTool]]: + """List static tools for every registered backend.""" + result = {} + for backend_type in self.list_providers().keys(): + tools = await self.list_backend_tools(backend=backend_type, use_cache=use_cache) + result[backend_type] = tools + return result + + async def search_tools( + self, + task_description: str, + *, + backend: BackendType | list[BackendType] | None = None, + session_name: str | None = None, + max_tools: int | None = None, + search_mode: str | None = None, + use_cache: bool = True, + llm_callable = None, + enable_llm_filter: bool | None = None, + llm_filter_threshold: int | None = None, + enable_cache_persistence: bool | None = None, + cache_dir: str | None = None, + ) -> list[BaseTool]: + """ + Search tools from backend(s) or session. + + Args: + task_description: Task description for searching relevant tools + backend: Backend type(s) to search + session_name: Specific session to search + max_tools: Maximum number of tools to return + search_mode: Search mode ("semantic", "keyword", "hybrid") + use_cache: Whether to use cached tool list + llm_callable: LLM client for intelligent filtering + enable_llm_filter: Whether to use LLM pre-filtering + llm_filter_threshold: Threshold for applying LLM filter + enable_cache_persistence: Whether to persist embeddings to disk. If None, uses config value. + cache_dir: Directory for persistent cache. If None, uses config value or default. + """ + candidate_tools = await self.list_tools( + backend=backend, + session_name=session_name, + use_cache=use_cache, + ) + + if not candidate_tools: + self._logger.warning("No candidate tools found for search") + return [] + + # lazy initialize SearchCoordinator (or recreate if parameters changed) + if self._search_coordinator is None: + # Get quality ranking settings from config + quality_config = getattr(self._config, 'tool_quality', None) + enable_quality_ranking = getattr(quality_config, 'enable_quality_ranking', True) if quality_config else True + + self._search_coordinator = SearchCoordinator( + max_tools=max_tools, + llm=llm_callable, + enable_llm_filter=enable_llm_filter, + llm_filter_threshold=llm_filter_threshold, + enable_cache_persistence=enable_cache_persistence, + cache_dir=cache_dir, + quality_manager=self._quality_manager, + enable_quality_ranking=enable_quality_ranking, + ) + + # execute search and sort + try: + filtered_tools = await self._search_coordinator._arun( + task_prompt=task_description, + candidate_tools=candidate_tools, + max_tools=max_tools, + mode=search_mode, + ) + return filtered_tools + except Exception as exc: + self._logger.error(f"Tool search failed: {exc}") + # fallback: return top N tools + fallback_max = max_tools or self._config.tool_search.max_tools + return candidate_tools[:fallback_max] + + def get_last_search_debug_info(self) -> Optional[Dict[str, Any]]: + """Get debug info from the last tool search operation. + + Returns: + Dict containing search debug info, or None if no search has been performed. + """ + if self._search_coordinator is None: + return None + return self._search_coordinator.get_last_search_debug_info() + + async def get_tools_with_auto_search( + self, + *, + task_description: str | None = None, + backend: BackendType | list[BackendType] | None = None, + session_name: str | None = None, + max_tools: int | None = None, + search_mode: str | None = None, + use_cache: bool = True, + llm_callable = None, + enable_llm_filter: bool | None = None, + llm_filter_threshold: int | None = None, + enable_cache_persistence: bool | None = None, + cache_dir: str | None = None, + ) -> list[BaseTool]: + """ + Intelligent tool retrieval: automatically decides whether to return all tools or trigger search. + + Logic: + - If tool_count <= max_tools: return all tools directly + - If tool_count > max_tools: trigger search and return top max_tools + + Args: + task_description: Task description (required for search if triggered). + If None, search will not be triggered even if tool count exceeds max_tools. + backend: Backend type(s) to query + session_name: Specific session name + max_tools: Maximum number of tools to return. Also acts as the threshold for triggering search. + - None: Use value from config (default: 30) + search_mode: Search mode ("semantic", "keyword", "hybrid") + use_cache: Whether to use cache + llm_callable: LLM client (for intelligent filtering) + enable_llm_filter: Whether to use LLM for backend/server pre-filtering. + - None: Use config default + - False: Disable LLM filter, use tool-level search only + - True: Enable LLM filter + llm_filter_threshold: Only apply LLM filter when tool count > this threshold. + - None: Use default (50) + - N: Only apply LLM filter when > N tools + enable_cache_persistence: Whether to persist embeddings to disk. If None, uses config value. + cache_dir: Directory for persistent cache. If None, uses config value or default. + + Returns: + List of tools (at most max_tools) + + Examples: + # Scenario 1: Auto-detect whether search is needed + tools = await gc.get_tools_with_auto_search( + task_description="Create a flowchart", + backend=BackendType.MCP + ) + + # Scenario 2: Custom max_tools + tools = await gc.get_tools_with_auto_search( + task_description="Edit file", + backend=BackendType.SHELL, + max_tools=30 # Return at most 30 tools + ) + + # Scenario 3: Disable search (return all tools regardless of count) + tools = await gc.get_tools_with_auto_search( + backend=BackendType.MCP # No task_description = no search + ) + """ + # Fetch all candidate tools + all_tools = await self.list_tools( + backend=backend, + session_name=session_name, + use_cache=use_cache, + ) + + if not all_tools: + self._logger.warning("No tools found") + return [] + + # Determine max_tools from config if not provided + if max_tools is None: + max_tools = self._config.tool_search.max_tools + + # Decide whether search is needed + tools_count = len(all_tools) + need_search = tools_count > max_tools and task_description is not None + + if need_search: + self._logger.info( + f"Tool count ({tools_count}) > max_tools ({max_tools}), " + f"triggering search to filter relevant tools..." + ) + return await self.search_tools( + task_description=task_description, + backend=backend, + session_name=session_name, + max_tools=max_tools, + search_mode=search_mode, + use_cache=use_cache, + llm_callable=llm_callable, + enable_llm_filter=enable_llm_filter, + llm_filter_threshold=llm_filter_threshold, + enable_cache_persistence=enable_cache_persistence, + cache_dir=cache_dir, + ) + else: + if task_description is None: + self._logger.debug( + f"No task description provided, returning all {tools_count} tools" + ) + else: + self._logger.debug( + f"Tool count ({tools_count}) ≤ max_tools ({max_tools}), " + f"returning all tools without search" + ) + return all_tools + + async def invoke_tool( + self, + tool: BaseTool | str, + parameters: Dict[str, Any] | None = None, + *, + backend: BackendType | None = None, + session_name: str | None = None, + server: str | None = None, + keep_session: bool = False, + **kwargs + ) -> ToolResult: + """ + Universal tool invocation method. + Supports multiple calling patterns: + + 1. Using BaseTool instance with bound runtime info + 2. Using BaseTool instance with explicit backend/session + 3. Using tool name with automatic lookup + 4. Using tool name with explicit backend/session/server + + Args: + tool: BaseTool instance or tool name string + parameters: Tool parameters as dict + backend: Backend type (optional for BaseTool with runtime_info) + session_name: Session name (optional for BaseTool with runtime_info) + server: Server name (for MCP, optional for BaseTool with runtime_info) + keep_session: Whether to keep session alive after invocation + **kwargs: Alternative parameter passing + + Returns: + ToolResult + + Examples: + # Pattern 1: Tool instance with runtime info (from list_tools) + tools = await gc.list_tools() + tool = next(t for t in tools if t.name == "read_file") + result = await gc.invoke_tool(tool, {"path": "/tmp/a.txt"}) + + # Pattern 2: Tool instance with explicit backend/session + my_tool = MyTool() + result = await gc.invoke_tool( + my_tool, + {"arg": "value"}, + backend=BackendType.SHELL + ) + + # Pattern 3: Tool name with automatic lookup + result = await gc.invoke_tool("read_file", {"path": "/tmp/a.txt"}) + + # Pattern 4: Tool name with explicit backend/server + result = await gc.invoke_tool( + "read_file", + {"path": "/tmp/a.txt"}, + backend=BackendType.MCP, + server="filesystem" + ) + """ + params = parameters or kwargs + + # BaseTool instance + if isinstance(tool, BaseTool): + tool_name = tool.schema.name + + # Try to use bound runtime info first + if tool.is_bound and not (backend or session_name or server): + # Use runtime info + runtime_backend = tool.runtime_info.backend + runtime_session = tool.runtime_info.session_name + runtime_server = tool.runtime_info.server_name + else: + # Use provided or tool's default backend + runtime_backend = backend or tool.backend_type + runtime_session = session_name + runtime_server = server + + if runtime_backend == BackendType.NOT_SET: + raise GroundingError( + f"Cannot invoke tool '{tool_name}': no backend specified. " + f"Either bind runtime info or provide backend parameter.", + code=ErrorCode.TOOL_EXECUTION_FAIL + ) + + # Tool name string + elif isinstance(tool, str): + tool_name = tool + + # If explicit backend/session provided, use them + if backend or session_name: + runtime_session = session_name + runtime_server = server + + # Infer backend: prefer explicit backend; otherwise get from session + if backend is not None: + runtime_backend = backend + else: + if runtime_session not in self._session_info: + raise ErrorCode.SESSION_NOT_FOUND(runtime_session) + runtime_backend = self._session_info[ + runtime_session + ].backend_type + else: + # Auto-lookup: search for the tool + all_tools = await self.list_tools(use_cache=True) + matching = [t for t in all_tools if t.name == tool_name] + + if not matching: + raise GroundingError( + f"Tool '{tool_name}' not found", + code=ErrorCode.TOOL_NOT_FOUND + ) + + if len(matching) > 1: + sources = [ + f"{t.runtime_info.backend.value}/{t.runtime_info.session_name}" + for t in matching if t.is_bound + ] + raise GroundingError( + f"Multiple tools named '{tool_name}' found in: {sources}. " + f"Please specify 'backend' or 'session_name' parameter.", + code=ErrorCode.AMBIGUOUS_TOOL + ) + + # Use the found tool's runtime info + found_tool = matching[0] + runtime_backend = found_tool.runtime_info.backend + runtime_session = found_tool.runtime_info.session_name + runtime_server = found_tool.runtime_info.server_name + + # Execute the tool + # Ensure session exists (except for SYSTEM backend which doesn't use sessions) + # Check if session really exists - cached tools have session_name but session may not be running + if runtime_backend != BackendType.SYSTEM: + if not runtime_session or runtime_session not in self._sessions: + runtime_session = await self.ensure_session(runtime_backend, runtime_server) + + try: + provider = self._registry.get(runtime_backend) + # SystemProvider doesn't use sessions, pass a dummy value + session_param = runtime_session if runtime_session else "system" + result = await provider.call_tool(session_param, tool_name, params) + + # Update last_activity in session_info (skip for SYSTEM backend) + if runtime_backend != BackendType.SYSTEM and runtime_session and runtime_session in self._session_info: + async with self._lock: + old_info = self._session_info[runtime_session] + self._session_info[runtime_session] = old_info.model_copy( + update={"last_activity": datetime.utcnow()} + ) + + return result + finally: + # Auto-close session if requested (skip for SYSTEM backend) + if runtime_backend != BackendType.SYSTEM and not keep_session and runtime_session: + if runtime_server or runtime_session.startswith(runtime_backend.value): + await self.close_session(runtime_session) \ No newline at end of file diff --git a/openspace/grounding/core/provider.py b/openspace/grounding/core/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..88ad6acd6dff0556a195ef1d00b59808ab248da0 --- /dev/null +++ b/openspace/grounding/core/provider.py @@ -0,0 +1,148 @@ +""" +provider is to manage sessions of a backend, if the backend is mcp, then provider will manage sessions through servers +""" +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Generic, TypeVar + +from .tool import BaseTool +from .types import BackendType, SessionConfig, ToolResult, ToolStatus +from .session import BaseSession +from .security.policies import SecurityPolicyManager +from openspace.config import get_config +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) +TSession = TypeVar('TSession', bound=BaseSession) + + +class Provider(ABC, Generic[TSession]): + """Backend provider base class""" + def __init__(self, backend_type: BackendType, config: Dict[str, Any] = None): + self.backend_type = backend_type + self.config = config or {} + self.is_initialized = False + self._sessions: Dict[str, TSession] = {} # session management + self._session_counter: int = 0 + self.security_manager = SecurityPolicyManager() + + self._setup_security_policy(config) + + def _setup_security_policy(self, config: dict | None = None): + security_policy = get_config().get_security_policy(self.backend_type.value) + self.security_manager.set_backend_policy(BackendType.SHELL, security_policy) + + async def ensure_initialized(self) -> None: + """ + Internal helper. Guarantee that `initialize()` has been executed + """ + if not self.is_initialized: + await self.initialize() + + @abstractmethod + async def initialize(self) -> None: + """Initialize provider, call `create_session` to create all sessions if not exist + Subclasses should set `self.is_initialized = True` after successful initialization + """ + pass + + @abstractmethod + async def create_session(self, session_config: SessionConfig) -> TSession: + """Create session, update _sessions""" + pass + + @abstractmethod + async def close_session(self, session_name: str) -> None: + """Close session""" + pass + + def list_sessions(self) -> List[str]: + """Get all session IDs""" + return list(self._sessions.keys()) + + def get_session(self, session_name: str) -> Optional[TSession]: + """Get session object by ID""" + return self._sessions.get(session_name) + + async def close_all_sessions(self) -> None: + """Provider shutdown cleanup""" + for session_name in list(self._sessions.keys()): + try: + await self.close_session(session_name) + except Exception as e: + print(f"Error closing session {session_name}: {e}") + + self._sessions.clear() + self.is_initialized = False + + def __repr__(self) -> str: + return (f"Provider(backend={self.backend_type.value}, " + f"initialized={self.is_initialized}, " + f"sessions={len(self._sessions)}, " + f"config_items={len(self.config)})") + + async def list_tools(self, session_name: Optional[str] = None) -> List[BaseTool]: + """ + Return BaseTool list. + If session_name is specified, only return the tools of the specified session. + If session_name is not specified, return all tools of all sessions. + """ + await self.ensure_initialized() + + if session_name: + session = self._sessions.get(session_name) + return await session.list_tools() if session else [] + + tools: list[BaseTool] = [] + for sess in self._sessions.values(): + tools.extend(await sess.list_tools()) + return tools + + async def call_tool( + self, + session_name: str, + tool_name: str, + parameters: Dict[str, Any] | None = None, + ) -> ToolResult: + + await self.ensure_initialized() + parameters = parameters or {} + + session = self._sessions.get(session_name) + if session is None: + return ToolResult( + status=ToolStatus.ERROR, + content="", + error=f"Session '{session_name}' not found", + metadata={"session_name": session_name, "tool_name": tool_name}, + ) + + try: + return await session.call_tool(tool_name, parameters) + except Exception as e: + logger.error("Execute tool error: %s @%s - %s", tool_name, session_name, e) + return ToolResult( + status=ToolStatus.ERROR, + content="", + error=str(e), + metadata={"session_name": session_name, "tool_name": tool_name}, + ) + + +class ProviderRegistry: + """ + Maintain mapping of BackendType -> Provider, and provide dynamic registration / retrieval capabilities + """ + def __init__(self) -> None: + self._providers: dict[BackendType, Provider] = {} + + def register(self, provider: "Provider") -> None: + self._providers[provider.backend_type] = provider + logger.debug("Provider for %s registered", provider.backend_type) + + def get(self, backend: BackendType) -> "Provider": + if backend not in self._providers: + raise KeyError(f"Provider for '{backend.value}' not registered") + return self._providers[backend] + + def list(self) -> dict[BackendType, "Provider"]: + return dict(self._providers) \ No newline at end of file diff --git a/openspace/grounding/core/quality/__init__.py b/openspace/grounding/core/quality/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7e0e0b57f97668d828943b15463de8af15f39d0 --- /dev/null +++ b/openspace/grounding/core/quality/__init__.py @@ -0,0 +1,28 @@ +from .types import ToolQualityRecord, ExecutionRecord, DescriptionQuality +from .manager import ToolQualityManager +from .store import QualityStore + +# Global manager instance +_global_manager: "ToolQualityManager | None" = None + + +def get_quality_manager() -> "ToolQualityManager | None": + """Get the global quality manager instance.""" + return _global_manager + + +def set_quality_manager(manager: "ToolQualityManager") -> None: + """Set the global quality manager instance.""" + global _global_manager + _global_manager = manager + + +__all__ = [ + "ToolQualityRecord", + "ExecutionRecord", + "DescriptionQuality", + "ToolQualityManager", + "QualityStore", + "get_quality_manager", + "set_quality_manager", +] diff --git a/openspace/grounding/core/quality/manager.py b/openspace/grounding/core/quality/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..197087efb5a5106f81f5a186395cc37aa96d60c6 --- /dev/null +++ b/openspace/grounding/core/quality/manager.py @@ -0,0 +1,884 @@ +""" +Tool Quality Manager + +Core API (called by main flow): +- record_execution(): Called by BaseTool after execution +- adjust_ranking(): Called by SearchCoordinator for quality-aware sorting +- evolve(): Called periodically by ToolLayer for self-evolution + +Query API (for inspection/debugging): +- get_quality_report(), get_tool_insights() +""" + +import hashlib +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from .types import ToolQualityRecord, ExecutionRecord, DescriptionQuality +from .store import QualityStore +from openspace.utils.logging import Logger + +if TYPE_CHECKING: + from openspace.grounding.core.tool import BaseTool + from openspace.grounding.core.types import ToolResult + from openspace.llm import LLMClient + +logger = Logger.get_logger(__name__) + + +class ToolQualityManager: + """ + Manages tool quality tracking and quality-aware ranking. + + Features: + - Track execution success rate and latency + - LLM-based description quality evaluation (optional, requires llm_client) + - Persistent memory across sessions + - Quality-integrated tool ranking + - Incremental update detection + """ + + def __init__( + self, + *, + db_path: Optional[Path] = None, + cache_dir: Optional[Path] = None, # deprecated, ignored + llm_client: Optional["LLMClient"] = None, + enable_persistence: bool = True, + auto_save: bool = True, + evolve_interval: int = 5, + ): + self._llm_client = llm_client + self._enable_persistence = enable_persistence + self._auto_save = auto_save + self._evolve_interval = evolve_interval + + # In-memory cache + self._records: Dict[str, ToolQualityRecord] = {} + self._global_execution_count: int = 0 + self._last_evolve_count: int = 0 + + # Persistent store (SQLite, shares DB file with SkillStore) + self._store = QualityStore(db_path=db_path) if enable_persistence else None + + # Load from DB + if self._store: + self._records, self._global_execution_count = self._store.load_all() + self._last_evolve_count = ( + (self._global_execution_count // self._evolve_interval) + * self._evolve_interval + ) + + logger.info( + f"ToolQualityManager initialized " + f"(persistence={enable_persistence}, records={len(self._records)}, " + f"global_count={self._global_execution_count}, " + f"evolve_interval={self._evolve_interval})" + ) + + def get_tool_key(self, tool: "BaseTool") -> str: + """Generate unique key for a tool.""" + from openspace.grounding.core.types import BackendType + + if tool.is_bound: + backend = tool.runtime_info.backend.value + server = tool.runtime_info.server_name or "default" + else: + backend = tool.backend_type.value if tool.backend_type != BackendType.NOT_SET else "unknown" + server = "default" + + return f"{backend}:{server}:{tool.name}" + + def _compute_description_hash(self, tool: "BaseTool") -> str: + """Compute hash of tool description for change detection.""" + content = f"{tool.name}|{tool.description or ''}|{tool.schema.parameters}" + return hashlib.md5(content.encode()).hexdigest()[:16] + + def get_record(self, tool: "BaseTool") -> ToolQualityRecord: + """Get or create quality record for a tool.""" + key = self.get_tool_key(tool) + + if key not in self._records: + backend, server, name = key.split(":", 2) + self._records[key] = ToolQualityRecord( + tool_key=key, + backend=backend, + server=server, + tool_name=name, + description_hash=self._compute_description_hash(tool), + ) + + return self._records[key] + + def get_quality_score(self, tool: "BaseTool") -> float: + """Get quality score for a tool (0-1).""" + return self.get_record(tool).quality_score + + # Key-based record access (for cross-system integration) + def get_or_create_record_by_key(self, tool_key: str) -> ToolQualityRecord: + """Get or create a ToolQualityRecord by its canonical key. + + Used by ExecutionAnalyzer integration where no BaseTool instance + is available. Parses ``tool_key`` into backend/server/tool_name. + + Key formats: + - ``backend:server:tool_name`` → three-part key (canonical for MCP) + - ``backend:tool_name`` → two-part; tries ``backend:default:tool_name`` + first for matching existing records. + """ + # 1. Direct match + if tool_key in self._records: + return self._records[tool_key] + + parts = tool_key.split(":", 2) + if len(parts) == 3: + backend, server, name = parts + elif len(parts) == 2: + backend, name = parts + server = "default" + # Try normalized 3-part key before creating a new record + canonical = f"{backend}:default:{name}" + if canonical in self._records: + return self._records[canonical] + else: + backend, server, name = "unknown", "default", tool_key + + canonical_key = f"{backend}:{server}:{name}" + if canonical_key in self._records: + return self._records[canonical_key] + + record = ToolQualityRecord( + tool_key=canonical_key, + backend=backend, + server=server, + tool_name=name, + ) + self._records[canonical_key] = record + return record + + def find_record_by_key(self, key: str) -> Optional[ToolQualityRecord]: + """Find a record by exact or partial tool key. + + Tries in order: + 1. Exact match (3-part ``backend:server:tool`` or 2-part) + 2. Normalized 2-part → ``backend:default:tool`` + 3. Linear scan matching backend + tool_name (ignoring server) + """ + # 1. Exact + if key in self._records: + return self._records[key] + + parts = key.split(":", 2) + if len(parts) == 2: + backend, tool_name = parts + # 2. Normalize + canonical = f"{backend}:default:{tool_name}" + if canonical in self._records: + return self._records[canonical] + # 3. Scan + for record in self._records.values(): + if record.backend == backend and record.tool_name == tool_name: + return record + return None + + async def record_llm_tool_issues( + self, tool_issues: List[str], task_id: str = "", + ) -> int: + """Record LLM-identified tool issues into the quality tracking system. + + Each issue is injected as a failed ``ExecutionRecord`` in the tool's + ``recent_executions`` via ``add_llm_issue()``, so it feeds into the + same ``recent_success_rate`` → ``penalty`` pipeline as rule-based + tracking. This means one unified set of quality metrics drives + ranking adjustments and future batch skill updates. + + The LLM may catch semantic failures (HTTP 200 but wrong data, + misleading output, etc.) that rule-based tracking misses. + + Args: + tool_issues: List of ``"key — description"`` strings. + Key formats: ``mcp:server:tool`` or ``backend:tool``. + task_id: Task ID for traceability (optional). + """ + updated = 0 + for issue in tool_issues: + # Parse "key — description" or "key - description" + if "—" in issue: + key_part, _, description = issue.partition("—") + elif " - " in issue: + key_part, _, description = issue.partition(" - ") + else: + key_part, description = issue, "" + key_part = key_part.strip() + description = description.strip() + if not key_part: + continue + + record = self.find_record_by_key(key_part) + if record is None: + record = self.get_or_create_record_by_key(key_part) + + # Inject into the unified quality pipeline + tag = f"(task={task_id}) " if task_id else "" + record.add_llm_issue(f"{tag}{description}" if description else f"{tag}flagged by analysis LLM") + updated += 1 + + # Persist + if updated and self._auto_save and self._store: + await self._store.save_all(self._records, self._global_execution_count) + + if updated: + logger.info( + f"Recorded {updated} LLM tool issue(s) into quality pipeline" + f"{f' (task={task_id})' if task_id else ''}" + ) + return updated + + def get_llm_flagged_tools( + self, min_flags: int = 2, + ) -> List[ToolQualityRecord]: + """Get tools repeatedly flagged by the analysis LLM. + + Useful for identifying tools whose descriptions, reliability, or + behavior may need attention — and for batch-triggering skill + updates on skills that depend on those tools. + + Each returned record carries LLM-identified failures in its + ``recent_executions`` (prefixed ``[LLM]``), providing actionable context. + + Args: + min_flags: Minimum number of LLM flags to include. + """ + return [ + r for r in self._records.values() + if r.llm_flagged_count >= min_flags + ] + + # Execution Tracking + async def record_execution( + self, + tool: "BaseTool", + result: "ToolResult", + execution_time_ms: float, + ) -> None: + """Record tool execution result and increment global counter.""" + record = self.get_record(tool) + + # Extract error message if failed + error_message = None + if result.is_error and result.error: + error_message = str(result.error)[:500] + + # Add execution record + record.add_execution(ExecutionRecord( + timestamp=datetime.now(), + success=result.is_success, + execution_time_ms=execution_time_ms, + error_message=error_message, + )) + + # Increment global execution count + self._global_execution_count += 1 + + # Auto-save + if self._auto_save and self._store: + await self._store.save_record(record, self._records, self._global_execution_count) + + logger.debug( + f"Recorded execution: {record.tool_key} " + f"success={result.is_success} time={execution_time_ms:.0f}ms " + f"(global_count={self._global_execution_count})" + ) + + async def evaluate_description( + self, + tool: "BaseTool", + force: bool = False, + ) -> Optional[DescriptionQuality]: + """ + Evaluate tool description quality using LLM. + """ + try: + from gdpval_bench.token_tracker import set_call_source, reset_call_source + _src_tok = set_call_source("quality") + except ImportError: + _src_tok = None + + if not self._llm_client: + logger.debug("LLM client not available for description evaluation") + if _src_tok is not None: + reset_call_source(_src_tok) + return None + + record = self.get_record(tool) + + # Skip if already evaluated and not forced + if record.description_quality and not force: + # Check if description changed + current_hash = self._compute_description_hash(tool) + if current_hash == record.description_hash: + return record.description_quality + + # Build evaluation prompt + desc = tool.description or "No description provided" + if len(desc) > 4000: + desc = desc[:4000] + "\n... (truncated for length)" + + params = tool.schema.parameters or {} + if params: + param_lines = [] + # Extract parameter names and types from JSON schema + if "properties" in params: + for param_name, param_info in params.get("properties", {}).items(): + param_type = param_info.get("type", "unknown") + param_desc = param_info.get("description", "") + param_lines.append(f"- {param_name} ({param_type}): {param_desc}" if param_desc else f"- {param_name} ({param_type})") + param_text = "\n".join(param_lines) if param_lines else "No parameter descriptions available" + else: + param_text = "No parameters" + + prompt = f"""# Task: Evaluate this tool's documentation quality + +## Tool Information + +Name: {tool.name} + +Description: +{desc} + +Parameters: +{param_text} + +## Evaluation Task + +Rate the documentation on two dimensions (0.0 to 1.0 scale): + +### 1. Clarity +How clear is the tool's purpose and usage? + +- 0.0-0.3: No description or completely unclear +- 0.4-0.6: Basic purpose but vague +- 0.7-0.8: Clear purpose and functionality +- 0.9-1.0: Very clear with usage examples or context + +### 2. Completeness +Are inputs/outputs properly documented? + +- 0.0-0.3: Missing critical information +- 0.4-0.6: Basic info but lacks details +- 0.7-0.8: Well documented with types +- 0.9-1.0: Comprehensive with constraints and examples + +## Scoring Guidelines + +- Short descriptions can score high if clear and accurate +- If parameters exist but aren't explained in description, reduce completeness score +- Missing description means clarity = 0.0 + +## Output + +Respond with JSON only: + +```json +{{ + "reasoning": "Brief 1-2 sentence analysis", + "clarity": 0.8, + "completeness": 0.7 +}} +```""" + + try: + response = await self._llm_client.complete(prompt) + content = response["message"]["content"] + + # Parse JSON response + import json + + # Extract complete JSON object + def extract_json_object(text: str) -> str | None: + """Extract first complete JSON object from text by counting braces.""" + start = text.find('{') + if start == -1: + return None + + count = 0 + in_string = False + escape_next = False + + for i, char in enumerate(text[start:], start): + if escape_next: + escape_next = False + continue + + if char == '\\': + escape_next = True + continue + + if char == '"' and not escape_next: + in_string = not in_string + continue + + if not in_string: + if char == '{': + count += 1 + elif char == '}': + count -= 1 + if count == 0: + return text[start:i+1] + return None + + json_str = extract_json_object(content) + if not json_str: + logger.warning(f"Could not find JSON in LLM response for {tool.name}") + return None + + data = json.loads(json_str) + + # Extract and validate scores with robust error handling + def safe_float(value, default=0.5, min_val=0.0, max_val=1.0): + """Safely convert to float and clamp to valid range.""" + try: + if value is None: + return default + f = float(value) + return max(min_val, min(max_val, f)) + except (ValueError, TypeError): + logger.warning(f"Invalid score value: {value}, using default {default}") + return default + + clarity = safe_float(data.get("clarity"), default=0.5) + completeness = safe_float(data.get("completeness"), default=0.5) + reasoning = str(data.get("reasoning", ""))[:500] # Limit reasoning length + + quality = DescriptionQuality( + clarity=clarity, + completeness=completeness, + evaluated_at=datetime.now(), + reasoning=reasoning, + ) + + # Update record + record.description_quality = quality + record.description_hash = self._compute_description_hash(tool) + record.last_updated = datetime.now() + + # Save + if self._auto_save and self._store: + await self._store.save_record(record, self._records, self._global_execution_count) + + logger.info(f"Evaluated description: {tool.name} score={quality.overall_score:.2f}") + return quality + + except Exception as e: + logger.error(f"Description evaluation failed for {tool.name}: {e}") + return None + finally: + if _src_tok is not None: + reset_call_source(_src_tok) + + # Quality-Aware Ranking + def adjust_ranking( + self, + tools_with_scores: List[Tuple["BaseTool", float]], + ) -> List[Tuple["BaseTool", float]]: + """ + Adjust tool ranking using penalty-based approach. + + Args: + tools_with_scores: List of (tool, semantic_score) tuples + """ + adjusted = [] + for tool, semantic_score in tools_with_scores: + penalty = self.get_penalty(tool) + + adjusted_score = semantic_score * penalty + + adjusted.append((tool, adjusted_score)) + + # Sort by adjusted score (descending) + adjusted.sort(key=lambda x: x[1], reverse=True) + + return adjusted + + def get_penalty(self, tool: "BaseTool") -> float: + """Get penalty factor for a tool (0.2-1.0).""" + return self.get_record(tool).penalty + + # Change Detection + def check_changes(self, tools: List["BaseTool"]) -> Dict[str, str]: + """ + Check for tool changes (new/updated/unchanged). + + Returns dict: {tool_key: "new"|"updated"|"unchanged"} + """ + changes = {} + + for tool in tools: + key = self.get_tool_key(tool) + current_hash = self._compute_description_hash(tool) + + if key not in self._records: + changes[key] = "new" + elif self._records[key].description_hash != current_hash: + changes[key] = "updated" + # Clear old evaluation on description change + self._records[key].description_quality = None + self._records[key].description_hash = current_hash + else: + changes[key] = "unchanged" + + new_count = sum(1 for v in changes.values() if v == "new") + updated_count = sum(1 for v in changes.values() if v == "updated") + + if new_count or updated_count: + logger.info(f"Tool changes: {new_count} new, {updated_count} updated") + + return changes + + async def save(self) -> None: + """ + Manually save all records to disk. + + Note: Usually not needed - auto_save handles persistence in + record_execution(), evaluate_description(), and evolve(). + Provided as public API for explicit save when needed. + """ + if self._store: + await self._store.save_all(self._records) + + def clear_cache(self) -> None: + """Clear all cached data.""" + self._records.clear() + if self._store: + self._store.clear() + + def get_stats(self) -> Dict: + """ + Get quality tracking statistics. + + Note: Query API for inspection, may not be called in main flow. + """ + if not self._records: + return {"total_tools": 0} + + records = list(self._records.values()) + + return { + "total_tools": len(records), + "total_executions": sum(r.total_calls for r in records), + "avg_success_rate": ( + sum(r.success_rate for r in records) / len(records) + if records else 0 + ), + "avg_quality_score": ( + sum(r.quality_score for r in records) / len(records) + if records else 0 + ), + "tools_with_description_eval": sum( + 1 for r in records if r.description_quality + ), + "tools_llm_flagged": sum( + 1 for r in records if r.llm_flagged_count > 0 + ), + } + + def get_top_tools( + self, + n: int = 10, + backend: Optional[str] = None, + min_calls: int = 3, + ) -> List[ToolQualityRecord]: + """ + Get top N tools by quality score. + + Args: + n: Number of tools to return + backend: Filter by backend type (optional) + min_calls: Minimum calls required (to filter untested tools) + """ + records = [ + r for r in self._records.values() + if r.total_calls >= min_calls + and (backend is None or r.backend == backend) + ] + + records.sort(key=lambda r: r.quality_score, reverse=True) + return records[:n] + + def get_problematic_tools( + self, + success_rate_threshold: float = 0.5, + min_calls: int = 5, + ) -> List[ToolQualityRecord]: + """ + Get tools with low success rate (candidates for review/removal). + + Args: + success_rate_threshold: Tools below this rate are flagged + min_calls: Minimum calls required (avoid flagging new tools) + """ + return [ + r for r in self._records.values() + if r.total_calls >= min_calls + and r.recent_success_rate < success_rate_threshold + ] + + def get_quality_report(self) -> Dict: + """ + Generate comprehensive quality report for upper layer. + + Returns structured report with: + - Overall stats + - Per-backend breakdown + - Top/problematic tools + - Improvement suggestions + """ + if not self._records: + return {"status": "no_data", "message": "No quality data collected yet"} + + records = list(self._records.values()) + tested_records = [r for r in records if r.total_calls >= 3] + + # Per-backend stats + backends = {} + for r in records: + if r.backend not in backends: + backends[r.backend] = { + "tools": 0, + "total_calls": 0, + "success_count": 0, + "servers": set() + } + backends[r.backend]["tools"] += 1 + backends[r.backend]["total_calls"] += r.total_calls + backends[r.backend]["success_count"] += r.success_count + backends[r.backend]["servers"].add(r.server) + + # Convert sets to counts + for b in backends: + backends[b]["servers"] = len(backends[b]["servers"]) + backends[b]["success_rate"] = ( + backends[b]["success_count"] / backends[b]["total_calls"] + if backends[b]["total_calls"] > 0 else 0 + ) + + # Top and problematic tools + top_tools = self.get_top_tools(5) + problematic = self.get_problematic_tools() + + return { + "summary": { + "total_tools": len(records), + "tested_tools": len(tested_records), + "total_executions": sum(r.total_calls for r in records), + "overall_success_rate": ( + sum(r.success_count for r in records) / + max(1, sum(r.total_calls for r in records)) + ), + "avg_quality_score": ( + sum(r.quality_score for r in tested_records) / len(tested_records) + if tested_records else 0 + ), + }, + "by_backend": backends, + "top_tools": [ + {"key": r.tool_key, "score": r.quality_score, "success_rate": r.success_rate} + for r in top_tools + ], + "problematic_tools": [ + {"key": r.tool_key, "success_rate": r.success_rate, "calls": r.total_calls} + for r in problematic + ], + "recommendations": self._generate_recommendations(records, problematic), + } + + def _generate_recommendations( + self, + records: List[ToolQualityRecord], + problematic: List[ToolQualityRecord], + ) -> List[str]: + """Generate actionable recommendations based on quality data.""" + recommendations = [] + + # Check for problematic tools + if problematic: + tool_names = [r.tool_name for r in problematic[:3]] + recommendations.append( + f"Review low-success tools: {', '.join(tool_names)}" + ) + + # Check for tools needing description evaluation + unevaluated = [r for r in records if not r.description_quality and r.total_calls >= 3] + if unevaluated: + recommendations.append( + f"{len(unevaluated)} tools need description quality evaluation" + ) + + # Check for low description quality + poor_docs = [ + r for r in records + if r.description_quality and r.description_quality.overall_score < 0.5 + ] + if poor_docs: + recommendations.append( + f"{len(poor_docs)} tools have poor documentation quality" + ) + + return recommendations + + def compute_adaptive_quality_weight(self) -> float: + """ + Compute adaptive quality weight based on data confidence. + + Returns higher weight when we have more reliable quality data, + lower weight when data is sparse. + """ + if not self._records: + return 0.1 # Low weight when no data + + records = list(self._records.values()) + tested_count = sum(1 for r in records if r.total_calls >= 3) + + if tested_count == 0: + return 0.1 + + # More tested tools -> higher confidence -> higher weight + coverage = tested_count / len(records) + + # Average calls per tested tool -> data richness + avg_calls = sum(r.total_calls for r in records) / len(records) + richness = min(1.0, avg_calls / 20) # Cap at 20 calls average + + # Combine coverage and richness + confidence = (coverage * 0.5 + richness * 0.5) + + # Map to weight range [0.1, 0.5] + weight = 0.1 + confidence * 0.4 + + return round(weight, 2) + + def should_reevaluate_description(self, tool: "BaseTool") -> bool: + """ + Check if a tool's description should be re-evaluated. + + Triggers re-evaluation when: + - Description hash changed + - Success rate dropped significantly + - No evaluation yet but enough calls + """ + record = self._records.get(self.get_tool_key(tool)) + if not record: + return True + + # Check hash change + current_hash = self._compute_description_hash(tool) + if current_hash != record.description_hash: + return True + + # No evaluation yet but enough data + if not record.description_quality and record.total_calls >= 5: + return True + + # Success rate dropped significantly (maybe description is misleading) + if record.description_quality and record.total_calls >= 10: + if record.recent_success_rate < 0.5 and record.description_quality.overall_score > 0.7: + # High doc quality but low success -> mismatch + return True + + return False + + async def evolve(self, tools: List["BaseTool"]) -> Dict: + """ + Run self-evolution cycle on given tools. + + This method: + 1. Detects tool changes + 2. Re-evaluates descriptions where needed + 3. Updates quality weights + 4. Returns evolution report + """ + report = { + "changes_detected": {}, + "descriptions_evaluated": 0, + "adaptive_weight": 0.0, + "recommendations": [], + } + + # 1. Detect changes + report["changes_detected"] = self.check_changes(tools) + + # 2. Find tools needing re-evaluation + needs_eval = [t for t in tools if self.should_reevaluate_description(t)] + + # 3. Evaluate descriptions (limit to avoid too many LLM calls) + if needs_eval and self._llm_client: + for tool in needs_eval[:5]: # Max 5 per cycle + result = await self.evaluate_description(tool, force=True) + if result: + report["descriptions_evaluated"] += 1 + + # 4. Compute adaptive weight + report["adaptive_weight"] = self.compute_adaptive_quality_weight() + + # 5. Generate recommendations + problematic = self.get_problematic_tools() + report["recommendations"] = self._generate_recommendations( + list(self._records.values()), problematic + ) + + # 6. Update last evolve count + self._last_evolve_count = self._global_execution_count + + # Save + if self._store: + await self._store.save_all(self._records, self._global_execution_count) + + logger.info( + f"Evolution cycle complete: " + f"changes={len([v for v in report['changes_detected'].values() if v != 'unchanged'])}, " + f"evaluated={report['descriptions_evaluated']}, " + f"weight={report['adaptive_weight']}, " + f"global_count={self._global_execution_count}" + ) + + return report + + def should_evolve(self) -> bool: + """Check if evolution should be triggered based on global execution count.""" + return self._global_execution_count >= self._last_evolve_count + self._evolve_interval + + def get_tool_insights(self, tool: "BaseTool") -> Dict: + """ + Get detailed insights for a specific tool (for debugging/analysis). + + Returns comprehensive info about tool's quality history. + """ + record = self._records.get(self.get_tool_key(tool)) + if not record: + return {"status": "not_tracked", "tool": tool.name} + + # Count recent failures + recent_failures_count = sum( + 1 for e in record.recent_executions[-20:] + if not e.success + ) + + return { + "tool_key": record.tool_key, + "total_calls": record.total_calls, + "success_rate": record.success_rate, + "recent_success_rate": record.recent_success_rate, + "avg_execution_time_ms": record.avg_execution_time_ms, + "quality_score": record.quality_score, + "description_quality": { + "overall_score": record.description_quality.overall_score, + "clarity": record.description_quality.clarity, + "completeness": record.description_quality.completeness, + "reasoning": record.description_quality.reasoning, + } if record.description_quality else None, + "llm_flagged_count": record.llm_flagged_count, + "recent_failures_count": recent_failures_count, + "first_seen": record.first_seen.isoformat(), + "last_updated": record.last_updated.isoformat(), + } diff --git a/openspace/grounding/core/quality/store.py b/openspace/grounding/core/quality/store.py new file mode 100644 index 0000000000000000000000000000000000000000..f77da9c3b9872e0b8de46718c4ce0605a700bca0 --- /dev/null +++ b/openspace/grounding/core/quality/store.py @@ -0,0 +1,287 @@ +""" +SQLite-backed persistence for tool quality data. Shares the same database file as SkillStore. + +Storage location (default): + /.openspace/openspace.db + +Tables managed by this module: + tool_quality_records — one row per tool (aggregate stats) + tool_execution_history — rolling window of per-call records + tool_quality_meta — key-value metadata (global_execution_count) +""" + +import sqlite3 +import threading +from datetime import datetime +from pathlib import Path +from typing import Dict, Optional, Tuple + +from .types import ToolQualityRecord, ExecutionRecord, DescriptionQuality +from openspace.utils.logging import Logger +from openspace.config.constants import PROJECT_ROOT + +logger = Logger.get_logger(__name__) + + +_DDL = """ +CREATE TABLE IF NOT EXISTS tool_quality_records ( + tool_key TEXT PRIMARY KEY, + backend TEXT NOT NULL, + server TEXT NOT NULL DEFAULT 'default', + tool_name TEXT NOT NULL, + total_calls INTEGER NOT NULL DEFAULT 0, + success_count INTEGER NOT NULL DEFAULT 0, + total_execution_time_ms REAL NOT NULL DEFAULT 0.0, + llm_flagged_count INTEGER NOT NULL DEFAULT 0, + description_hash TEXT, + desc_clarity REAL, + desc_completeness REAL, + desc_evaluated_at TEXT, + desc_reasoning TEXT, + first_seen TEXT NOT NULL, + last_updated TEXT NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_tqr_backend ON tool_quality_records(backend); +CREATE INDEX IF NOT EXISTS idx_tqr_flagged ON tool_quality_records(llm_flagged_count); + +CREATE TABLE IF NOT EXISTS tool_execution_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + tool_key TEXT NOT NULL + REFERENCES tool_quality_records(tool_key) ON DELETE CASCADE, + timestamp TEXT NOT NULL, + success INTEGER NOT NULL, + execution_time_ms REAL NOT NULL DEFAULT 0.0, + error_message TEXT +); +CREATE INDEX IF NOT EXISTS idx_teh_key ON tool_execution_history(tool_key); +CREATE INDEX IF NOT EXISTS idx_teh_ts ON tool_execution_history(timestamp); + +CREATE TABLE IF NOT EXISTS tool_quality_meta ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); +""" + + +class QualityStore: + """SQLite-backed persistence for tool quality data. + + By default uses the same ``.db`` file as ``SkillStore`` + (``/.openspace/openspace.db``). + Each subsystem creates its own tables independently. + """ + + def __init__(self, db_path: Optional[Path] = None): + if db_path is None: + db_dir = PROJECT_ROOT / ".openspace" + db_dir.mkdir(parents=True, exist_ok=True) + db_path = db_dir / "openspace.db" + + self._db_path = Path(db_path) + self._mu = threading.Lock() + + self._conn = sqlite3.connect( + str(self._db_path), + timeout=30.0, + check_same_thread=False, + ) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA busy_timeout=30000") + self._conn.execute("PRAGMA foreign_keys=ON") + self._conn.row_factory = sqlite3.Row + + self._init_tables() + + logger.debug(f"QualityStore ready (SQLite) at {self._db_path}") + + def _init_tables(self) -> None: + with self._mu: + self._conn.executescript(_DDL) + self._conn.commit() + + def load_all(self) -> Tuple[Dict[str, ToolQualityRecord], int]: + """Load all quality records and global execution count.""" + with self._mu: + rows = self._conn.execute( + "SELECT * FROM tool_quality_records" + ).fetchall() + + records: Dict[str, ToolQualityRecord] = {} + for row in rows: + tool_key = row["tool_key"] + record = ToolQualityRecord( + tool_key=tool_key, + backend=row["backend"], + server=row["server"], + tool_name=row["tool_name"], + total_calls=row["total_calls"], + success_count=row["success_count"], + total_execution_time_ms=row["total_execution_time_ms"], + llm_flagged_count=row["llm_flagged_count"], + description_hash=row["description_hash"], + first_seen=datetime.fromisoformat(row["first_seen"]), + last_updated=datetime.fromisoformat(row["last_updated"]), + ) + + # Description quality (all-or-nothing: clarity present → all present) + if row["desc_clarity"] is not None: + record.description_quality = DescriptionQuality( + clarity=row["desc_clarity"], + completeness=row["desc_completeness"], + evaluated_at=datetime.fromisoformat(row["desc_evaluated_at"]), + reasoning=row["desc_reasoning"] or "", + ) + + # Recent execution history (most recent N, restored chronologically) + exec_rows = self._conn.execute( + "SELECT timestamp, success, execution_time_ms, error_message " + "FROM tool_execution_history " + "WHERE tool_key = ? ORDER BY id DESC LIMIT ?", + (tool_key, ToolQualityRecord.MAX_RECENT_EXECUTIONS), + ).fetchall() + record.recent_executions = [ + ExecutionRecord( + timestamp=datetime.fromisoformat(er["timestamp"]), + success=bool(er["success"]), + execution_time_ms=er["execution_time_ms"], + error_message=er["error_message"], + ) + for er in reversed(exec_rows) + ] + + records[tool_key] = record + + # Global metadata + meta_row = self._conn.execute( + "SELECT value FROM tool_quality_meta " + "WHERE key = 'global_execution_count'" + ).fetchone() + global_count = int(meta_row["value"]) if meta_row else 0 + + logger.info( + f"Loaded {len(records)} quality records from SQLite " + f"(global_count={global_count})" + ) + return records, global_count + + async def save_all( + self, + records: Dict[str, ToolQualityRecord], + global_execution_count: int = 0, + ) -> None: + """Persist all records (bulk).""" + self._save_all_sync(records, global_execution_count) + + async def save_record( + self, + record: ToolQualityRecord, + all_records: Dict[str, ToolQualityRecord], + global_execution_count: int = 0, + ) -> None: + """Persist a single record (incremental — much cheaper than save_all).""" + with self._mu: + try: + self._upsert_record(record) + self._conn.execute( + "INSERT OR REPLACE INTO tool_quality_meta " + "(key, value) VALUES (?, ?)", + ("global_execution_count", str(global_execution_count)), + ) + self._conn.commit() + except Exception as e: + self._conn.rollback() + logger.error(f"Failed to save record {record.tool_key}: {e}") + + def clear(self) -> None: + """Delete all quality data.""" + with self._mu: + self._conn.execute("DELETE FROM tool_execution_history") + self._conn.execute("DELETE FROM tool_quality_records") + self._conn.execute("DELETE FROM tool_quality_meta") + self._conn.commit() + logger.info("Quality data cleared") + + def close(self) -> None: + """Close the database connection.""" + try: + self._conn.close() + except Exception: + pass + + def _save_all_sync( + self, + records: Dict[str, ToolQualityRecord], + global_execution_count: int = 0, + ) -> None: + """Synchronous full save (used by async wrapper and migration).""" + with self._mu: + try: + for record in records.values(): + self._upsert_record(record) + self._conn.execute( + "INSERT OR REPLACE INTO tool_quality_meta " + "(key, value) VALUES (?, ?)", + ("global_execution_count", str(global_execution_count)), + ) + self._conn.commit() + except Exception as e: + self._conn.rollback() + logger.error(f"Failed to bulk-save quality records: {e}") + + def _upsert_record(self, record: ToolQualityRecord) -> None: + """Upsert one tool_quality_records row + its execution history. + + Caller MUST hold ``self._mu``. Does NOT commit — caller manages + the transaction boundary. + """ + dq = record.description_quality + self._conn.execute( + """INSERT OR REPLACE INTO tool_quality_records + (tool_key, backend, server, tool_name, + total_calls, success_count, total_execution_time_ms, + llm_flagged_count, description_hash, + desc_clarity, desc_completeness, desc_evaluated_at, desc_reasoning, + first_seen, last_updated) + VALUES (?,?,?,?, ?,?,?, ?,?, ?,?,?,?, ?,?)""", + ( + record.tool_key, + record.backend, + record.server, + record.tool_name, + record.total_calls, + record.success_count, + record.total_execution_time_ms, + record.llm_flagged_count, + record.description_hash, + dq.clarity if dq else None, + dq.completeness if dq else None, + dq.evaluated_at.isoformat() if dq else None, + dq.reasoning if dq else None, + record.first_seen.isoformat(), + record.last_updated.isoformat(), + ), + ) + + # Sync execution history: delete + re-insert. + # For ≤ MAX_RECENT_EXECUTIONS rows this is fast and avoids + # complex diff logic between in-memory and DB state. + self._conn.execute( + "DELETE FROM tool_execution_history WHERE tool_key = ?", + (record.tool_key,), + ) + if record.recent_executions: + self._conn.executemany( + "INSERT INTO tool_execution_history " + "(tool_key, timestamp, success, execution_time_ms, error_message) " + "VALUES (?,?,?,?,?)", + [ + ( + record.tool_key, + e.timestamp.isoformat(), + int(e.success), + e.execution_time_ms, + e.error_message, + ) + for e in record.recent_executions + ], + ) diff --git a/openspace/grounding/core/quality/types.py b/openspace/grounding/core/quality/types.py new file mode 100644 index 0000000000000000000000000000000000000000..84756255991a66f6919531f403640663a06c9eb8 --- /dev/null +++ b/openspace/grounding/core/quality/types.py @@ -0,0 +1,187 @@ +""" +Data types for tool quality tracking. +""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import ClassVar, List, Optional + + +@dataclass +class ExecutionRecord: + """Single execution record.""" + timestamp: datetime + success: bool + execution_time_ms: float + error_message: Optional[str] = None + + +@dataclass +class DescriptionQuality: + """LLM-evaluated description quality.""" + clarity: float # 0-1: Is the purpose and usage clear? + completeness: float # 0-1: Are inputs/outputs documented? + evaluated_at: datetime + reasoning: str = "" # LLM's reasoning for the scores + + @property + def overall_score(self) -> float: + """Computed overall score (average of all dimensions).""" + return (self.clarity + self.completeness) / 2 + + +@dataclass +class ToolQualityRecord: + """ + Complete quality record for a tool. + + Key: "{backend}:{server}:{tool_name}" + """ + tool_key: str + backend: str + server: str + tool_name: str + + # Execution stats + total_calls: int = 0 + success_count: int = 0 + total_execution_time_ms: float = 0.0 + + # Recent execution history (rolling window) + recent_executions: List[ExecutionRecord] = field(default_factory=list) + + # Description quality (LLM-evaluated) + description_quality: Optional[DescriptionQuality] = None + + # LLM analysis feedback — how many times the analysis LLM flagged this tool. + # LLM-identified issues are also injected into recent_executions as + # ExecutionRecord(success=False, error_message="[LLM] ...") so they feed + # into the same recent_success_rate → penalty pipeline as rule-based tracking. + llm_flagged_count: int = 0 + + # Metadata + description_hash: Optional[str] = None + first_seen: datetime = field(default_factory=datetime.now) + last_updated: datetime = field(default_factory=datetime.now) + + # Keep only recent N executions + MAX_RECENT_EXECUTIONS: ClassVar[int] = 100 + + # Penalty threshold: only penalize tools with success rate below this value + # Tools with success rate >= this threshold get penalty = 1.0 (no penalty) + PENALTY_THRESHOLD: ClassVar[float] = 0.4 + + @property + def success_rate(self) -> float: + """Overall success rate.""" + if self.total_calls == 0: + return 0.0 + return self.success_count / self.total_calls + + @property + def avg_execution_time_ms(self) -> float: + """Average execution time.""" + if self.total_calls == 0: + return 0.0 + return self.total_execution_time_ms / self.total_calls + + @property + def recent_success_rate(self) -> float: + """Success rate from recent executions.""" + if not self.recent_executions: + return self.success_rate + successes = sum(1 for e in self.recent_executions if e.success) + return successes / len(self.recent_executions) + + @property + def consecutive_failures(self) -> int: + """Count consecutive failures from the most recent execution.""" + count = 0 + for exec_record in reversed(self.recent_executions): + if not exec_record.success: + count += 1 + else: + break + return count + + @property + def penalty(self) -> float: + """ + Compute penalty factor based on failure rate. + + Design principles: + - Only penalize tools with success rate < PENALTY_THRESHOLD (default 40%) + - New tools (< 3 calls) get no penalty to allow fair evaluation + + Returns value between 0.2-1.0: + - 1.0: No penalty (success rate >= threshold or insufficient data) + - 0.2: Maximum penalty (consistently failing tool) + """ + if self.total_calls < 3: + return 1.0 + + success_rate = self.recent_success_rate + threshold = self.PENALTY_THRESHOLD + + if success_rate >= threshold: + return 1.0 + + # Linear mapping: penalty = 0.3 + (success_rate / threshold) * 0.7 + base_penalty = 0.3 + (success_rate / threshold) * 0.7 + + # Extra penalty for consecutive failures (indicates systematic issues) + consec = self.consecutive_failures + if consec >= 3: + # 3 consecutive → extra 0.1, 5 consecutive → extra 0.3 + extra_penalty = min(0.3, (consec - 2) * 0.1) + base_penalty -= extra_penalty + + # Clamp to [0.2, 1.0] + return max(0.2, min(1.0, base_penalty)) + + @property + def quality_score(self) -> float: + """ + Legacy quality score for backward compatibility. + Now delegates to penalty property. + """ + return self.penalty + + def add_llm_issue(self, description: str) -> None: + """Record an LLM-identified issue as a failure in recent_executions. + + Unlike ``add_execution()``, this does NOT increment ``total_calls`` + or ``total_execution_time_ms`` — the real execution was already + counted by the rule-based system. The LLM's qualitative judgment + supplements it by catching semantic failures (e.g. HTTP 200 but + wrong data) that rule-based tracking missed. + + The injected record feeds into ``recent_success_rate`` → ``penalty``, + so one unified quality metric drives ranking and future batch updates. + """ + self.llm_flagged_count += 1 + self.recent_executions.append(ExecutionRecord( + timestamp=datetime.now(), + success=False, + execution_time_ms=0.0, + error_message=f"[LLM] {description}", + )) + if len(self.recent_executions) > self.MAX_RECENT_EXECUTIONS: + self.recent_executions = self.recent_executions[-self.MAX_RECENT_EXECUTIONS:] + self.last_updated = datetime.now() + + def add_execution(self, record: ExecutionRecord) -> None: + """Add execution record and update stats.""" + self.total_calls += 1 + self.total_execution_time_ms += record.execution_time_ms + + if record.success: + self.success_count += 1 + + self.recent_executions.append(record) + + # Trim to max size + if len(self.recent_executions) > self.MAX_RECENT_EXECUTIONS: + self.recent_executions = self.recent_executions[-self.MAX_RECENT_EXECUTIONS:] + + self.last_updated = datetime.now() diff --git a/openspace/grounding/core/search_tools.py b/openspace/grounding/core/search_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..4f701fb84807056298a9c4f8323fbcf686ca21fa --- /dev/null +++ b/openspace/grounding/core/search_tools.py @@ -0,0 +1,1118 @@ +from openspace.grounding.core.tool.base import BaseTool +import re +import os +import numpy as np +import httpx +from typing import Iterable, List, Tuple, Dict, Optional, Any, TYPE_CHECKING +from enum import Enum +import json +import pickle +from pathlib import Path +from datetime import datetime + +from .tool import BaseTool +from .types import BackendType +from openspace.llm import LLMClient +from openspace.utils.logging import Logger +from openspace.config.constants import PROJECT_ROOT + +if TYPE_CHECKING: + from .quality import ToolQualityManager + +logger = Logger.get_logger(__name__) + + +class SearchMode(str, Enum): + SEMANTIC = "semantic" + KEYWORD = "keyword" + HYBRID = "hybrid" + + +class ToolRanker: + """ + ToolRanker: rank tools by keyword, semantic or hybrid + """ + # Cache version for persistent storage - increment when cache format changes + CACHE_VERSION = 1 + + def __init__( + self, + model_name: Optional[str] = None, + cache_dir: Optional[str | Path] = None, + enable_cache_persistence: bool = False + ): + """Initialize ToolRanker. + + Args: + model_name: Embedding model name. If None, will use env or config value. + cache_dir: Directory to store persistent embedding cache. + enable_cache_persistence: Whether to persist embeddings to disk. + """ + # Check for remote API config from environment + self._api_base_url = os.getenv("EMBEDDING_BASE_URL") + self._api_key = os.getenv("EMBEDDING_API_KEY") + self._use_remote_api = bool(self._api_key and self._api_base_url) + + # Get model name: env > param > config > default + if model_name is None: + model_name = os.getenv("EMBEDDING_MODEL") + + if model_name is None: + try: + from openspace.config import get_config + config = get_config() + model_name = config.tool_search.embedding_model + except Exception as exc: + logger.warning(f"Failed to load config, using default model: {exc}") + model_name = "BAAI/bge-small-en-v1.5" + + self._model_name = model_name + self._embed_model = None # lazy load + self._embedding_fn = None + + if self._use_remote_api: + logger.info(f"Using remote embedding API: {self._api_base_url}, model: {model_name}") + + # Persistent cache settings + self._enable_cache_persistence = enable_cache_persistence + if cache_dir is None: + cache_dir = PROJECT_ROOT / ".openspace" / "embedding_cache" + self._cache_dir = Path(cache_dir) + + # Log cache settings + logger.info( + f"ToolRanker initialized: enable_cache_persistence={enable_cache_persistence}, " + f"cache_dir={self._cache_dir}" + ) + + # Structured in-memory cache + # Structure: {backend: {server: {tool_name: {"embedding": np.ndarray, "description": str, "cached_at": str}}}} + self._structured_cache: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]] = {} + + # For backward compatibility and quick lookup: {text -> (backend, server, tool_name)} + self._text_to_key: Dict[str, Tuple[str, str, str]] = {} + + # Load persistent cache if enabled + if self._enable_cache_persistence: + logger.info(f"Loading persistent cache from {self._cache_dir}") + self._load_persistent_cache() + + def _get_cache_key(self, tool: BaseTool) -> Tuple[str, str, str]: + """Get structured cache key (backend, server, tool_name) from tool.""" + if tool.is_bound: + backend = tool.runtime_info.backend.value + server = tool.runtime_info.server_name or "default" + else: + if not tool.backend_type or tool.backend_type == BackendType.NOT_SET: + backend = "UNKNOWN" + else: + backend = tool.backend_type.value + server = "default" + + return (backend, server, tool.name) + + def _get_cache_file_path(self) -> Path: + """Get the cache file path for the current model.""" + # Use model name in filename to support multiple models + safe_model_name = self._model_name.replace("/", "_").replace("\\", "_") + return self._cache_dir / f"embeddings_{safe_model_name}_v{self.CACHE_VERSION}.pkl" + + def _load_persistent_cache(self) -> None: + """Load embeddings from disk cache.""" + cache_file = self._get_cache_file_path() + + if not cache_file.exists(): + logger.debug(f"No persistent cache found at {cache_file}") + return + + try: + with open(cache_file, 'rb') as f: + data = pickle.load(f) + + # Validate cache version + if isinstance(data, dict) and data.get("version") == self.CACHE_VERSION: + self._structured_cache = data.get("embeddings", {}) + self._rebuild_text_index() + + # Count total embeddings + total = sum( + len(tools) + for backend in self._structured_cache.values() + for tools in backend.values() + ) + logger.info(f"Loaded {total} embeddings from cache: {cache_file}") + else: + logger.warning(f"Cache version mismatch or invalid format, starting fresh") + self._structured_cache = {} + except Exception as exc: + logger.warning(f"Failed to load persistent cache: {exc}") + self._structured_cache = {} + + def _rebuild_text_index(self) -> None: + """Rebuild text-to-key mapping for quick lookup.""" + self._text_to_key.clear() + for backend, servers in self._structured_cache.items(): + for server, tools in servers.items(): + for tool_name, tool_data in tools.items(): + desc = tool_data.get("description", "") + text = f"{tool_name}: {desc}" + self._text_to_key[text] = (backend, server, tool_name) + + def _save_persistent_cache(self) -> None: + """Save embeddings to disk cache.""" + if not self._enable_cache_persistence or not self._structured_cache: + return + + cache_file = self._get_cache_file_path() + + try: + # Create directory if it doesn't exist + cache_file.parent.mkdir(parents=True, exist_ok=True) + + # Build cache data with metadata + cache_data = { + "version": self.CACHE_VERSION, + "model_name": self._model_name, + "last_updated": datetime.now().isoformat(), + "embeddings": self._structured_cache + } + + # Save cache + with open(cache_file, 'wb') as f: + pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL) + + # Count total embeddings + total = sum( + len(tools) + for backend in self._structured_cache.values() + for tools in backend.values() + ) + logger.debug(f"Saved {total} embeddings to cache: {cache_file}") + except Exception as exc: + logger.warning(f"Failed to save persistent cache: {exc}") + + def rank( + self, + query: str, + tools: List[BaseTool], + *, + top_k: int = 50, + mode: SearchMode = SearchMode.SEMANTIC, + ) -> List[Tuple[BaseTool, float]]: + if mode == SearchMode.KEYWORD: + return self._keyword_search(query, tools, top_k) + if mode == SearchMode.SEMANTIC: + return self._semantic_search(query, tools, top_k) + # hybrid + return self._hybrid_search(query, tools, top_k) + + @staticmethod + def _tokenize(text: str) -> list[str]: + tokens = re.split(r"[^\w]+", text.lower()) + tokens = [tok for tok in tokens if tok] + return tokens + + def _keyword_search( + self, query: str, tools: Iterable[BaseTool], top_k: int + ) -> List[Tuple[BaseTool, float]]: + try: + from rank_bm25 import BM25Okapi # type: ignore + except ImportError: + BM25Okapi = None # fallback below + + tool_list = list(tools) + if not tool_list: + return [] + + corpus_tokens: list[list[str]] = [self._tokenize(f"{t.name} {t.description}") for t in tool_list] + query_tokens = self._tokenize(query) + + if BM25Okapi and corpus_tokens: + bm25 = BM25Okapi(corpus_tokens) + scores = bm25.get_scores(query_tokens) + scored = [(t, float(s)) for t, s in zip(tool_list, scores, strict=True)] + else: + # fallback: simple term overlap ratio + q_set = set(query_tokens) + scored = [] + for t, toks in zip(tool_list, corpus_tokens, strict=True): + if not toks: + scored.append((t, 0.0)) # Include tool with 0 score + continue + overlap = q_set.intersection(toks) + score = len(overlap) / len(q_set) if len(q_set) > 0 else 0.0 + scored.append((t, score)) + + scored.sort(key=lambda x: x[1], reverse=True) + result = scored[:top_k] + + # If no matches found (all scores are 0), return all tools + if not result or all(score == 0.0 for _, score in result): + logger.debug(f"Keyword search found no matches, returning all {len(tool_list)} tools") + return [(t, 0.0) for t in tool_list] + + return result + + def _ensure_model(self) -> bool: + """Ensure embedding model is ready (local or remote).""" + if self._embedding_fn is not None: + return True + + if self._use_remote_api: + return self._init_remote_embedding() + return self._init_local_embedding() + + def _init_remote_embedding(self) -> bool: + """Initialize remote embedding API (OpenRouter/OpenAI compatible).""" + try: + def embed_texts(texts: List[str]) -> List[np.ndarray]: + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{self._api_base_url}/embeddings", + headers={ + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json" + }, + json={"model": self._model_name, "input": texts} + ) + response.raise_for_status() + data = response.json() + return [np.array(item["embedding"]) for item in data["data"]] + + self._embedding_fn = embed_texts + logger.info(f"Remote embedding API initialized: {self._model_name}") + return True + except Exception as exc: + logger.error(f"Failed to initialize remote embedding API: {exc}") + return False + + def _init_local_embedding(self) -> bool: + """Initialize local fastembed model.""" + try: + from fastembed import TextEmbedding + logger.debug(f"fastembed imported successfully, loading model: {self._model_name}") + except ImportError as e: + logger.warning( + f"fastembed not installed (ImportError: {e}), semantic search unavailable. " + f"Install with: pip install fastembed" + ) + return False + + try: + logger.info(f"Loading embedding model: {self._model_name}...") + self._embed_model = TextEmbedding(model_name=self._model_name) + self._embedding_fn = lambda txts: list(self._embed_model.embed(txts)) + logger.info(f"Embedding model '{self._model_name}' loaded successfully") + return True + except Exception as exc: + logger.error(f"Embedding model '{self._model_name}' loading failed: {exc}") + return False + + def _get_embedding(self, tool: BaseTool) -> Optional[np.ndarray]: + """Get embedding from structured cache.""" + backend, server, tool_name = self._get_cache_key(tool) + + if backend not in self._structured_cache: + return None + if server not in self._structured_cache[backend]: + return None + if tool_name not in self._structured_cache[backend][server]: + return None + + return self._structured_cache[backend][server][tool_name].get("embedding") + + def _set_embedding(self, tool: BaseTool, embedding: np.ndarray) -> None: + """Store embedding in structured cache.""" + backend, server, tool_name = self._get_cache_key(tool) + + # Initialize nested structure if needed + if backend not in self._structured_cache: + self._structured_cache[backend] = {} + if server not in self._structured_cache[backend]: + self._structured_cache[backend][server] = {} + + # Store embedding with metadata + self._structured_cache[backend][server][tool_name] = { + "embedding": embedding, + "description": tool.description or "", + "cached_at": datetime.now().isoformat() + } + + # Update text index for backward compatibility + text = f"{tool.name}: {tool.description}" + self._text_to_key[text] = (backend, server, tool_name) + + def _semantic_search( + self, query: str, tools: Iterable[BaseTool], top_k: int + ) -> List[Tuple[BaseTool, float]]: + if not self._ensure_model(): + logger.debug("Semantic search unavailable, returning empty list") + return [] + + tools_list = list(tools) + + # Collect embeddings with cache reuse + missing_tools = [t for t in tools_list if self._get_embedding(t) is None] + cache_updated = False + + if missing_tools: + try: + # Generate embeddings for missing tools + missing_texts = [f"{t.name}: {t.description}" for t in missing_tools] + new_embs = self._embedding_fn(missing_texts) + + for tool, emb in zip(missing_tools, new_embs, strict=True): + self._set_embedding(tool, emb) + + cache_updated = True + logger.debug(f"Computed embeddings for {len(missing_tools)} new tools") + except Exception as exc: + logger.error("Failed to generate embeddings: %s", exc) + return [] + + # Save to persistent cache if updated + if cache_updated: + self._save_persistent_cache() + + try: + q_emb = self._embedding_fn([query])[0] + except Exception as exc: + logger.error("Failed to embed query: %s", exc) + return [] + + scored: list[tuple[BaseTool, float]] = [] + for t in tools_list: + emb = self._get_embedding(t) + if emb is None: + # Should not happen, but handle gracefully + logger.warning(f"No embedding found for tool: {t.name}") + scored.append((t, 0.0)) + continue + + # Calculate cosine similarity with zero-division protection + q_norm = np.linalg.norm(q_emb) + emb_norm = np.linalg.norm(emb) + if q_norm == 0 or emb_norm == 0: + sim = 0.0 + else: + sim = float(np.dot(q_emb, emb) / (q_norm * emb_norm)) + scored.append((t, sim)) + + scored.sort(key=lambda x: x[1], reverse=True) + return scored[:top_k] + + def _hybrid_search( + self, query: str, tools: Iterable[BaseTool], top_k: int + ) -> List[Tuple[BaseTool, float]]: + # keyword filter + kw_top = self._keyword_search(query, tools, top_k * 3) + if not kw_top: + # No keyword matches, try semantic search + semantic_results = self._semantic_search(query, tools, top_k) + if semantic_results: + return semantic_results + # Both failed, return top N tools + logger.warning("Both keyword and semantic search failed, returning top N tools") + return [(t, 0.0) for t in list(tools)[:top_k]] + + # semantic ranking on keyword results + semantic_results = self._semantic_search(query, [t for t, _ in kw_top], top_k) + if semantic_results: + return semantic_results + + # Semantic unavailable, return keyword results + logger.debug("Semantic search unavailable, using keyword results only") + return kw_top[:top_k] + + def get_cache_stats(self) -> Dict[str, Any]: + """Get statistics about the embedding cache. + + Returns: + Dict with structure: { + "total_embeddings": int, + "backends": { + "backend_name": { + "total": int, + "servers": { + "server_name": int # count of tools + } + } + } + } + """ + stats = { + "total_embeddings": 0, + "backends": {} + } + + for backend, servers in self._structured_cache.items(): + backend_total = 0 + server_stats = {} + + for server, tools in servers.items(): + tool_count = len(tools) + backend_total += tool_count + server_stats[server] = tool_count + + stats["backends"][backend] = { + "total": backend_total, + "servers": server_stats + } + stats["total_embeddings"] += backend_total + + return stats + + def clear_cache(self, backend: Optional[str] = None, server: Optional[str] = None) -> int: + """Clear embeddings from cache. + + Args: + backend: If provided, only clear this backend. If None, clear all. + server: If provided (and backend is provided), only clear this server. + + Returns: + Number of embeddings cleared. + """ + cleared_count = 0 + + if backend is None: + # Clear everything + for b in self._structured_cache.values(): + for s in b.values(): + cleared_count += len(s) + self._structured_cache.clear() + self._text_to_key.clear() + elif server is None: + # Clear specific backend + if backend in self._structured_cache: + for s in self._structured_cache[backend].values(): + cleared_count += len(s) + del self._structured_cache[backend] + # Rebuild text index + self._rebuild_text_index() + else: + # Clear specific backend+server + if backend in self._structured_cache and server in self._structured_cache[backend]: + cleared_count = len(self._structured_cache[backend][server]) + del self._structured_cache[backend][server] + # Clean up empty backend + if not self._structured_cache[backend]: + del self._structured_cache[backend] + # Rebuild text index + self._rebuild_text_index() + + # Save after clearing + if cleared_count > 0 and self._enable_cache_persistence: + self._save_persistent_cache() + logger.info(f"Cleared {cleared_count} embeddings from cache") + + return cleared_count + + +class SearchDebugInfo: + """Debug information from tool search process.""" + + def __init__(self): + self.search_mode: str = "" + self.total_candidates: int = 0 + self.mcp_count: int = 0 + self.non_mcp_count: int = 0 + + # LLM filter info + self.llm_filter_used: bool = False + self.llm_brief_plan: str = "" + self.llm_utility_tools: Dict[str, List[str]] = {} # server -> tool names + self.llm_domain_servers: List[str] = [] + self.llm_utility_count: int = 0 + self.llm_domain_count: int = 0 + + # Semantic search scores + self.tool_scores: List[Dict[str, Any]] = [] # [{name, server, score, selected}] + + # Final selected tools + self.selected_tools: List[Dict[str, Any]] = [] # [{name, server, backend}] + + def to_dict(self) -> Dict[str, Any]: + return { + "search_mode": self.search_mode, + "total_candidates": self.total_candidates, + "mcp_count": self.mcp_count, + "non_mcp_count": self.non_mcp_count, + "llm_filter": { + "used": self.llm_filter_used, + "brief_plan": self.llm_brief_plan, + "utility_tools": self.llm_utility_tools, + "domain_servers": self.llm_domain_servers, + "utility_count": self.llm_utility_count, + "domain_count": self.llm_domain_count, + }, + "tool_scores": self.tool_scores, + "selected_tools": self.selected_tools, + } + + +class SearchCoordinator(BaseTool): + _name = "_filter_tools" + _description = "Internal helper: filter & rank tools from a given list." + + # Fallback defaults when config loading fails + DEFAULT_MAX_TOOLS: int = 20 + DEFAULT_LLM_FILTER: bool = True + DEFAULT_LLM_THRESHOLD: int = 50 + DEFAULT_CACHE_PERSISTENCE: bool = False + DEFAULT_SEARCH_MODE: str = "hybrid" + + @classmethod + def get_parameters_schema(cls) -> Dict[str, Any]: + """Override to avoid JSON schema generation for list[BaseTool] parameter. + + The _arun method uses `candidate_tools: list[BaseTool]` which cannot be + converted to JSON Schema because BaseTool is an ABC class, not a Pydantic model. + Since this is an internal tool, we return an empty schema. + """ + return {} + + def __init__( + self, + *, + max_tools: Optional[int] = None, + llm: LLMClient = LLMClient(), + enable_llm_filter: Optional[bool] = None, + llm_filter_threshold: Optional[int] = None, + enable_cache_persistence: Optional[bool] = None, + cache_dir: Optional[str | Path] = None, + quality_manager: Optional["ToolQualityManager"] = None, + enable_quality_ranking: bool = True, + ): + """Create a SearchCoordinator. + + Args: + max_tools: max number of tools to return. If None, will use the value from config. + llm: optional async LLM, used to filter backend/server first + enable_llm_filter: whether to use LLM to pre-filter by backend/server. + If None, uses config value. + llm_filter_threshold: only apply LLM filter when tool count > this threshold. + If None, always apply (when enabled). + enable_cache_persistence: whether to persist embeddings to disk. If None, uses config value. + cache_dir: directory to store persistent embedding cache. If None, uses config value or default. + """ + super().__init__() + + # Load config (may be None if loading fails) + tool_search_config = None + try: + from openspace.config import get_config + tool_search_config = getattr(get_config(), 'tool_search', None) + except Exception as exc: + logger.warning(f"Failed to load config: {exc}") + + def resolve(user_value, config_attr: str, default): + """Priority: user_value → config → default""" + if user_value is not None: + return user_value + if tool_search_config is not None: + config_value = getattr(tool_search_config, config_attr, None) + if config_value is not None: + return config_value + return default + + # Resolve each setting with priority: user → config → default + self.max_tools = resolve(max_tools, 'max_tools', self.DEFAULT_MAX_TOOLS) + enable_llm_filter = resolve(enable_llm_filter, 'enable_llm_filter', self.DEFAULT_LLM_FILTER) + llm_filter_threshold = resolve(llm_filter_threshold, 'llm_filter_threshold', self.DEFAULT_LLM_THRESHOLD) + enable_cache_persistence = resolve(enable_cache_persistence, 'enable_cache_persistence', self.DEFAULT_CACHE_PERSISTENCE) + cache_dir = resolve(cache_dir, 'cache_dir', None) + self._default_mode = resolve(None, 'search_mode', self.DEFAULT_SEARCH_MODE) + + # Log cache settings for debugging + logger.info( + f"SearchCoordinator initialized with cache settings: " + f"enable_cache_persistence={enable_cache_persistence}, cache_dir={cache_dir}" + ) + + self._ranker = ToolRanker( + enable_cache_persistence=enable_cache_persistence, + cache_dir=cache_dir + ) + self._llm: LLMClient | None = llm if llm is not None else LLMClient() + + # LLM filter settings + self._enable_llm_filter = enable_llm_filter + self._llm_filter_threshold = llm_filter_threshold + + # Quality-aware ranking settings + self._quality_manager = quality_manager + self._enable_quality_ranking = enable_quality_ranking + + # Debug info from last search + self._last_search_debug_info: Optional[SearchDebugInfo] = None + + async def _arun( + self, + task_prompt: str, + candidate_tools: list[BaseTool], + *, + max_tools: int | None = None, + mode: str | None = None, # "semantic" | "keyword" | "hybrid" + ) -> list[BaseTool]: + max_tools = self.max_tools if max_tools is None else max_tools + mode = self._default_mode if mode is None else mode + + # Initialize debug info + debug_info = SearchDebugInfo() + debug_info.search_mode = mode + debug_info.total_candidates = len(candidate_tools) + self._last_search_debug_info = debug_info + + # Cache check + cache_key = (id(candidate_tools), task_prompt, mode, max_tools) + if not hasattr(self, "_query_cache"): + self._query_cache: Dict[tuple, list[BaseTool]] = {} + if cache_key in self._query_cache: + return self._query_cache[cache_key] + + # Split MCP tools and non-MCP tools + # Non-MCP tools (shell, gui, web, etc.) are always included, skip all filtering + mcp_tools = [] + non_mcp_tools = [] + + for t in candidate_tools: + if t.is_bound: + backend = t.runtime_info.backend.value + else: + backend = t.backend_type.value if t.backend_type else "UNKNOWN" + + if backend.lower() == "mcp": + mcp_tools.append(t) + else: + non_mcp_tools.append(t) + + debug_info.mcp_count = len(mcp_tools) + debug_info.non_mcp_count = len(non_mcp_tools) + logger.info(f"Tool split: {len(mcp_tools)} MCP, {len(non_mcp_tools)} non-MCP (always included)") + + # If MCP tools within limit, return all + if len(mcp_tools) <= max_tools: + result = mcp_tools + non_mcp_tools + self._query_cache[cache_key] = result + self._populate_selected_tools(debug_info, result) + return result + + mcp_count = len(mcp_tools) + should_use_llm_filter = ( + self._llm and + self._enable_llm_filter and + mcp_count > self._llm_filter_threshold + ) + + # Path 1: LLM pre-filter (large MCP tool set) + if should_use_llm_filter: + logger.info(f"Path 1: MCP count ({mcp_count}) > threshold, using LLM filter...") + debug_info.llm_filter_used = True + + try: + utility_tools, domain_tools, llm_filter_info = await self._llm_filter_with_planning( + task_prompt, mcp_tools + ) + + # Record LLM filter results + debug_info.llm_brief_plan = llm_filter_info.get("brief_plan", "") + debug_info.llm_utility_tools = llm_filter_info.get("utility_tools", {}) + debug_info.llm_domain_servers = llm_filter_info.get("domain_servers", []) + + utility_count = len(utility_tools) + domain_count = len(domain_tools) + debug_info.llm_utility_count = utility_count + debug_info.llm_domain_count = domain_count + total_count = utility_count + domain_count + + if total_count <= max_tools: + mcp_result = utility_tools + domain_tools + else: + # Exceeds limit: keep utility, search domain + domain_quota = max(max_tools - utility_count, 5) + logger.info( + f"Total ({total_count}) > max_tools ({max_tools}), " + f"keeping {utility_count} utility, searching {domain_count} domain (quota: {domain_quota})" + ) + + # Compute scores for utility tools (marked as LLM-selected) + if utility_tools: + utility_ranked = self._ranker.rank( + task_prompt, utility_tools, + top_k=len(utility_tools), mode=SearchMode(mode) + ) + self._record_tool_scores(debug_info, utility_ranked, is_selected=True) + + if domain_tools: + # Rank all domain tools to see all scores for debugging + all_domain_ranked = self._ranker.rank( + task_prompt, domain_tools, + top_k=len(domain_tools), mode=SearchMode(mode) + ) + # Save scores for all domain tools (mark which ones are selected) + for i, (tool, score) in enumerate(all_domain_ranked): + server_name = None + if tool.is_bound and tool.runtime_info: + server_name = tool.runtime_info.server_name + debug_info.tool_scores.append({ + "name": tool.name, + "server": server_name, + "score": round(score, 4), + "selected": i < domain_quota, + }) + searched_domain = [t for t, _ in all_domain_ranked[:domain_quota]] + else: + searched_domain = [] + + mcp_result = utility_tools + searched_domain + + except Exception as exc: + logger.warning(f"LLM filter failed ({exc}), fallback to direct ranking") + ranked = self._ranker.rank(task_prompt, mcp_tools, top_k=max_tools, mode=SearchMode(mode)) + self._record_tool_scores(debug_info, ranked, is_selected=True) + mcp_result = [t for t, _ in ranked] + + # Path 2: Plan-enhanced search (small MCP tool set) + else: + logger.info(f"Path 2: MCP count ({mcp_count}) <= threshold, using enhanced search...") + debug_info.llm_filter_used = False + + if self._llm: + try: + enhanced_query = await self._generate_search_query(task_prompt) + except Exception: + enhanced_query = task_prompt + else: + enhanced_query = task_prompt + + try: + ranked = self._ranker.rank( + enhanced_query, mcp_tools, + top_k=max_tools, mode=SearchMode(mode) + ) + # Record all scores from semantic search + self._record_tool_scores(debug_info, ranked, is_selected=True) + mcp_result = [t for t, _ in ranked] + except Exception: + ranked = self._ranker._keyword_search( + enhanced_query, mcp_tools, max_tools + ) + self._record_tool_scores(debug_info, ranked, is_selected=True) + mcp_result = [t for t, _ in ranked] + + # Apply quality ranking on MCP results + if self._enable_quality_ranking and self._quality_manager and mcp_result: + try: + ranked_with_scores = [(t, 1.0) for t in mcp_result] + ranked_with_scores = self._quality_manager.adjust_ranking(ranked_with_scores) + mcp_result = [t for t, _ in ranked_with_scores] + except Exception: + pass + + # Limit MCP tools, then combine with non-MCP tools + mcp_result = mcp_result[:max_tools] + result = mcp_result + non_mcp_tools + + # Populate final selected tools in debug info + self._populate_selected_tools(debug_info, result) + + self._log_search_results(candidate_tools, result, mode) + self._query_cache[cache_key] = result + return result + + def _record_tool_scores( + self, + debug_info: SearchDebugInfo, + ranked: List[Tuple[BaseTool, float]], + is_selected: bool = False + ) -> None: + """Record tool scores from ranking results.""" + for tool, score in ranked: + server_name = None + if tool.is_bound and tool.runtime_info: + server_name = tool.runtime_info.server_name + + debug_info.tool_scores.append({ + "name": tool.name, + "server": server_name, + "score": round(score, 4), + "selected": is_selected, + }) + + def _populate_selected_tools( + self, + debug_info: SearchDebugInfo, + tools: List[BaseTool] + ) -> None: + """Populate selected tools in debug info.""" + for tool in tools: + backend = "UNKNOWN" + server_name = None + + if tool.is_bound and tool.runtime_info: + backend = tool.runtime_info.backend.value + server_name = tool.runtime_info.server_name + elif tool.backend_type: + backend = tool.backend_type.value + + debug_info.selected_tools.append({ + "name": tool.name, + "server": server_name, + "backend": backend, + }) + + async def _llm_filter_with_planning( + self, + task_prompt: str, + tools: list[BaseTool] + ) -> tuple[list[BaseTool], list[BaseTool], Dict[str, Any]]: + """ + LLM pre-filter for MCP servers. + Returns (utility_tools, domain_tools, llm_filter_info). + """ + from collections import defaultdict + + # Group tools by server name + server_tools: Dict[str, list[BaseTool]] = defaultdict(list) + for t in tools: + if t.is_bound and t.runtime_info: + server = t.runtime_info.server_name or "default" + else: + server = "unknown" + server_tools[server].append(t) + + # Build tool name -> tool object mapping + tool_name_map: Dict[str, BaseTool] = {t.name: t for t in tools} + + # Build server description with tool names + lines: list[str] = ["Available MCP servers:"] + lines.append("") + + for server, tool_list in server_tools.items(): + lines.append(f"### Server: {server} ({len(tool_list)} tools)") + tool_names = [t.name for t in tool_list] + lines.append(f" All tools: {', '.join(tool_names)}") + if tool_list: + lines.append(f" Example capabilities:") + for tool in tool_list[:5]: + tool_desc = tool.description or "No description" + if len(tool_desc) > 100: + tool_desc = tool_desc[:97] + "..." + lines.append(f" - {tool.name}: {tool_desc}") + lines.append("") + + servers_block = "\n".join(lines) + + TOOL_FILTER_SYSTEM_PROMPT = f"""You are an expert tool selection assistant. + +# Your task +Analyze the given task and determine which MCP servers and tools are needed. +Think about how you would accomplish this task step by step, then classify needed servers and tools. + +# Important guidelines +- **Focus on tool names and capabilities**: Carefully examine the tool names to understand what each server can do +- **Be inclusive for domain servers**: If a server has tools that might be relevant to the core task, include it +- **Be precise for utility tools**: Only select the specific auxiliary tools needed (e.g., file save, time query) +- **When in doubt, include in domain_servers**: It's better to include a server than miss relevant tools + +{servers_block} + +# Output format +Return ONLY a JSON object (no markdown, no explanation): +{{ + "brief_plan": "1-2 sentence execution plan", + "utility_tools": {{ + "server1": ["tool1", "tool2"] + }}, + "domain_servers": ["server2", "server3"] +}} + +- **utility_tools**: Dict mapping server name to list of specific tool names. + These are auxiliary tools for supporting operations (e.g., filesystem: ["write_file"], time-server: ["get_time"]). + Only include the specific tools needed, NOT the entire server. +- **domain_servers**: Server names that directly provide the main capabilities for the task. + All tools from these servers will be considered. Be inclusive here.""" + + user_query = f"Task: {task_prompt}\n\nClassify the needed servers and tools." + + messages_text = LLMClient.format_messages_to_text([ + {"role": "system", "content": TOOL_FILTER_SYSTEM_PROMPT}, + {"role": "user", "content": user_query} + ]) + resp = await self._llm.complete(messages_text) + content = resp["message"]["content"].strip() + + # Extract JSON + code_block_pattern = r'```(?:json)?\s*\n?(.*?)\n?```' + match = re.search(code_block_pattern, content, re.DOTALL) + if match: + content = match.group(1).strip() + else: + json_match = re.search(r'\{.*\}', content, re.DOTALL) + if json_match: + content = json_match.group() + + try: + result = json.loads(content) + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse LLM response: {e}") + return [], tools + + # Parse utility_tools: {server: [tool_names]} + utility_tools_config = result.get("utility_tools", {}) + domain_servers = set(result.get("domain_servers", [])) + brief_plan = result.get("brief_plan", "N/A") + + logger.info(f"LLM Planning: {brief_plan}") + logger.info(f"Utility tools: {utility_tools_config}") + logger.info(f"Domain servers: {domain_servers}") + + # Collect utility tools (specific tools only) + utility_tools = [] + for server_name, tool_names in utility_tools_config.items(): + if server_name in server_tools: + server_tool_names = {t.name for t in server_tools[server_name]} + for tool_name in tool_names: + if tool_name in server_tool_names and tool_name in tool_name_map: + utility_tools.append(tool_name_map[tool_name]) + + # Collect domain tools (entire servers) + domain_tools = [] + for server, tool_list in server_tools.items(): + if server in domain_servers: + domain_tools.extend(tool_list) + + logger.info(f"LLM filter result: {len(utility_tools)} utility tools, {len(domain_tools)} domain tools") + + # Build LLM filter info for debugging + llm_filter_info = { + "brief_plan": brief_plan, + "utility_tools": utility_tools_config, + "domain_servers": list(domain_servers), + } + + # Fallback if no match + if not utility_tools and not domain_tools: + logger.warning(f"LLM filter matched 0 tools, returning all as domain") + return [], tools, llm_filter_info + + return utility_tools, domain_tools, llm_filter_info + + async def _generate_search_query(self, task_prompt: str) -> str: + prompt = f"""Task: {task_prompt} + +List keywords for the capabilities needed (comma-separated, brief):""" + + resp = await self._llm.complete(prompt) + capabilities = resp["message"]["content"].strip().replace("\n", " ") + + enhanced_query = f"{task_prompt} {capabilities}" + logger.debug(f"Enhanced search query: {enhanced_query[:150]}...") + + return enhanced_query + + def _log_search_results(self, all_tools: list[BaseTool], filtered_tools: list[BaseTool], mode: str) -> None: + """ + Log search results in a concise, grouped format. + Shows backend/server breakdown and tool names (truncated if too many). + """ + from collections import defaultdict + + # Group filtered tools by backend and server + grouped: Dict[str, Dict[str | None, list[str]]] = defaultdict(lambda: defaultdict(list)) + + for t in filtered_tools: + # Get backend and server info + if t.is_bound: + backend = t.runtime_info.backend.value + server = t.runtime_info.server_name if backend.lower() == "mcp" else None + else: + if not t.backend_type or t.backend_type == BackendType.NOT_SET: + backend = "UNKNOWN" + server = None + else: + backend = t.backend_type.value + server = None + + grouped[backend][server].append(t.name) + + # Build concise summary + lines = [f"\n{'='*60}"] + lines.append(f"🔍 Tool Search Results (mode: {mode})") + lines.append(f" {len(all_tools)} candidates → {len(filtered_tools)} selected tools") + lines.append(f"{'='*60}") + + for backend, srv_map in sorted(grouped.items()): + backend_total = sum(len(tools) for tools in srv_map.values()) + lines.append(f"\n📦 {backend} ({backend_total} tools)") + + for server, tool_names in sorted(srv_map.items()): + if backend.lower() == "mcp" and server: + prefix = f" └─ {server}: " + else: + prefix = f" └─ " + + # Limit display to avoid overwhelming output + if len(tool_names) <= 8: + tools_display = ", ".join(tool_names) + else: + tools_display = ", ".join(tool_names[:8]) + f" ... (+{len(tool_names)-8} more)" + + lines.append(f"{prefix}{tools_display}") + + lines.append(f"{'='*60}\n") + + # Use info level so users can see it + logger.info("\n".join(lines)) + + @staticmethod + def _format_tool_list(tools: list[BaseTool]) -> str: + rows = [f"{i}. **{t.name}**: {t.description}" for i, t in enumerate(tools, 1)] + return f"Total {len(tools)} tools, list out directly:\n\n" + "\n".join(rows) + + @staticmethod + def _format_ranked(results: list[tuple[BaseTool, float]], mode: SearchMode) -> str: + lines = [f"Search results (mode={mode}) total {len(results)}:\n"] + for i, (tool, score) in enumerate(results, 1): + lines.append(f"{i}. {tool.name} (score: {score:.3f})\n {tool.description}") + return "\n".join(lines) + + def _run(self, *args, **kwargs): + raise NotImplementedError("SearchCoordinator only supports asynchronous calls. Use _arun instead.") + + def get_embedding_cache_stats(self) -> Dict[str, Any]: + """Get statistics about the embedding cache. + + Returns: + Dict with cache statistics including total embeddings and breakdown by backend/server. + """ + return self._ranker.get_cache_stats() + + def clear_embedding_cache(self, backend: Optional[str] = None, server: Optional[str] = None) -> int: + """Clear embeddings from cache. + + Args: + backend: If provided, only clear this backend. If None, clear all. + server: If provided (and backend is provided), only clear this server. + + Returns: + Number of embeddings cleared. + """ + return self._ranker.clear_cache(backend=backend, server=server) + + def get_last_search_debug_info(self) -> Optional[Dict[str, Any]]: + """Get debug info from the last search operation. + + Returns: + Dict containing search debug info, or None if no search has been performed. + Includes: + - search_mode: The search mode used + - total_candidates: Total number of candidate tools + - mcp_count/non_mcp_count: Tool counts by type + - llm_filter: LLM filter information if used + - tool_scores: Similarity scores for each tool + - selected_tools: Final selected tools + """ + if self._last_search_debug_info is None: + return None + return self._last_search_debug_info.to_dict() \ No newline at end of file diff --git a/openspace/grounding/core/security/__init__.py b/openspace/grounding/core/security/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f261c22ca3cfa143eeafed9043f59595e7eea4 --- /dev/null +++ b/openspace/grounding/core/security/__init__.py @@ -0,0 +1,20 @@ +from .sandbox import BaseSandbox, SandboxManager +from .policies import SecurityPolicyManager, SecurityPolicy + +# Try to import E2BSandbox (optional dependency) +try: + from .e2b_sandbox import E2BSandbox + E2B_AVAILABLE = True +except ImportError: + E2BSandbox = None + E2B_AVAILABLE = False + +__all__ = [ + "BaseSandbox", + "SandboxManager", + "SecurityPolicyManager", + "SecurityPolicy" +] + +if E2B_AVAILABLE: + __all__.append("E2BSandbox") \ No newline at end of file diff --git a/openspace/grounding/core/security/e2b_sandbox.py b/openspace/grounding/core/security/e2b_sandbox.py new file mode 100644 index 0000000000000000000000000000000000000000..a3a3ffed4fdb925b454c3e819b7a746d155a6ee0 --- /dev/null +++ b/openspace/grounding/core/security/e2b_sandbox.py @@ -0,0 +1,213 @@ +""" +E2B Sandbox implementation. + +This module provides a concrete implementation of BaseSandbox using E2B. +""" + +import os +from typing import Any, Dict, Optional, TYPE_CHECKING + +from openspace.utils.logging import Logger +from .sandbox import BaseSandbox +from ..types import SandboxOptions + +logger = Logger.get_logger(__name__) + +# Import E2B SDK components (optional dependency) +if TYPE_CHECKING: + # For type checking purposes only + try: + from e2b_code_interpreter import CommandHandle, Sandbox + except ImportError: + CommandHandle = None # type: ignore + Sandbox = None # type: ignore + +try: + logger.debug("Attempting to import e2b_code_interpreter...") + from e2b_code_interpreter import ( # type: ignore + CommandHandle, + Sandbox, + ) + logger.debug("Successfully imported e2b_code_interpreter") + E2B_AVAILABLE = True +except ImportError as e: + logger.debug(f"Failed to import e2b_code_interpreter: {e}") + CommandHandle = None # type: ignore + Sandbox = None # type: ignore + E2B_AVAILABLE = False + + +class E2BSandbox(BaseSandbox): + """E2B sandbox implementation for secure code execution.""" + + def __init__(self, options: SandboxOptions): + """Initialize E2B sandbox. + + Args: + options: Sandbox configuration options including: + - api_key: E2B API key (or use E2B_API_KEY env var) + - sandbox_template_id: Template ID for the sandbox (default: "base") + - timeout: Command execution timeout in seconds + """ + super().__init__(options) + + if not E2B_AVAILABLE: + raise ImportError( + "E2B SDK (e2b-code-interpreter) not found. Please install it with " + "'pip install e2b-code-interpreter'." + ) + + # Get API key from options or environment + self.api_key = options.get("api_key") or os.environ.get("E2B_API_KEY") + if not self.api_key: + raise ValueError( + "E2B API key is required. Provide it via 'options.api_key'" + " or the E2B_API_KEY environment variable." + ) + + # Get sandbox configuration + self.sandbox_template_id = options.get("sandbox_template_id", "base") + self.timeout = options.get("timeout", 600) # Default 10 minutes + + # Sandbox instance (using Any to avoid import issues with optional dependency) + self._sandbox: Any = None + self._process: Any = None + + async def start(self) -> bool: + """Start the E2B sandbox instance. + + Returns: + True if sandbox started successfully, False otherwise. + """ + if self._active: + logger.debug("E2B sandbox already active") + return True + + try: + logger.debug(f"Creating E2B sandbox with template: {self.sandbox_template_id}") + self._sandbox = Sandbox( + template=self.sandbox_template_id, + api_key=self.api_key, + ) + self._active = True + logger.info(f"E2B sandbox started successfully (template: {self.sandbox_template_id})") + return True + + except Exception as e: + logger.error(f"Failed to start E2B sandbox: {e}") + self._active = False + return False + + async def stop(self) -> None: + """Stop the E2B sandbox instance.""" + if not self._active: + logger.debug("E2B sandbox not active") + return + + try: + # Terminate any running process + if self._process: + try: + logger.debug("Terminating sandbox process") + self._process.kill() + except Exception as e: + logger.warning(f"Error terminating sandbox process: {e}") + finally: + self._process = None + + # Close the sandbox + if self._sandbox: + try: + logger.debug("Closing E2B sandbox instance") + self._sandbox.kill() + logger.info("E2B sandbox stopped successfully") + except Exception as e: + logger.warning(f"Error closing E2B sandbox: {e}") + finally: + self._sandbox = None + + self._active = False + + except Exception as e: + logger.error(f"Error stopping E2B sandbox: {e}") + raise + + async def execute_safe(self, command: str, **kwargs) -> Any: + """Execute a command safely in the E2B sandbox. + + Args: + command: The command to execute + **kwargs: Additional options: + - envs: Environment variables (dict) + - timeout: Command timeout in milliseconds + - background: Run in background (bool) + - on_stdout: Stdout callback function + - on_stderr: Stderr callback function + + Returns: + CommandHandle object representing the running process + """ + if not self._active or not self._sandbox: + raise RuntimeError("E2B sandbox is not active. Call start() first.") + + try: + # Extract execution options + envs = kwargs.get("envs", {}) + timeout = kwargs.get("timeout", self.timeout * 1000) # Convert to ms + background = kwargs.get("background", False) + on_stdout = kwargs.get("on_stdout") + on_stderr = kwargs.get("on_stderr") + + logger.debug(f"Executing command in E2B sandbox: {command}") + + # Execute the command + self._process = self._sandbox.commands.run( + command, + envs=envs, + timeout=timeout, + background=background, + on_stdout=on_stdout, + on_stderr=on_stderr, + ) + + return self._process + + except Exception as e: + logger.error(f"Failed to execute command in E2B sandbox: {e}") + raise + + def get_connector(self) -> Any: + """Get the underlying E2B sandbox connector. + + Returns: + The E2B Sandbox instance, or None if not active. + """ + return self._sandbox + + def get_host(self, port: int) -> str: + """Get the host URL for a specific port. + + Args: + port: The port number to get the host for + + Returns: + The host URL string + + Raises: + RuntimeError: If sandbox is not active + """ + if not self._active or not self._sandbox: + raise RuntimeError("E2B sandbox is not active. Call start() first.") + + return self._sandbox.get_host(port) + + @property + def sandbox(self) -> Any: + """Get the underlying E2B Sandbox instance.""" + return self._sandbox + + @property + def process(self) -> Any: + """Get the current running process handle.""" + return self._process + diff --git a/openspace/grounding/core/security/policies.py b/openspace/grounding/core/security/policies.py new file mode 100644 index 0000000000000000000000000000000000000000..def81271a8799135c4369cd60a05f8dbe11f8b6c --- /dev/null +++ b/openspace/grounding/core/security/policies.py @@ -0,0 +1,162 @@ +import asyncio +import sys +from typing import Callable, Awaitable, Dict, Optional +from ..types import SecurityPolicy, BackendType + +PromptFunc = Callable[[str], Awaitable[bool]] + + +# ANSI color codes +class Colors: + RESET = "\033[0m" + BOLD = "\033[1m" + RED = "\033[91m" + YELLOW = "\033[93m" + GREEN = "\033[92m" + CYAN = "\033[96m" + GRAY = "\033[90m" + WHITE = "\033[97m" + + +class SecurityPolicyManager: + def __init__(self, prompt: PromptFunc | None = None): + self._policies: Dict[BackendType, SecurityPolicy] = {} + self._global_policy: Optional[SecurityPolicy] = None + self._prompt: PromptFunc | None = prompt or self._default_cli_prompt + + async def _default_cli_prompt(self, message: str) -> bool: + # Clean and professional prompt using unified display + from openspace.utils.display import Box, BoxStyle, colorize, print_separator + + print() + print_separator(70, 'y', 2) + print(f" {colorize('⚠️ Security Policy Warning', color=Colors.RED, bold=True)}") + print_separator(70, 'y', 2) + print(f" {message}") + print_separator(70, 'gr', 2) + print(f" {colorize('[y/yes]', color=Colors.GREEN)} Allow | {colorize('[n/no]', color=Colors.RED)} Deny") + print_separator(70, 'gr', 2) + print(f" {colorize('Your choice:', bold=True)} ", end="", flush=True) + + answer = await asyncio.get_running_loop().run_in_executor(None, sys.stdin.readline) + response = answer.strip().lower() in {"y", "yes"} + + if response: + print(f" {colorize('✓ Allowed', color=Colors.GREEN)}\n") + else: + print(f" {colorize('✗ Denied', color=Colors.RED)}\n") + + return response + + def set_global_policy(self, policy: SecurityPolicy) -> None: + self._global_policy = policy + + def set_backend_policy(self, backend_type: BackendType, policy: SecurityPolicy) -> None: + self._policies[backend_type] = policy + + def get_policy(self, backend_type: BackendType) -> SecurityPolicy: + policy = self._policies.get(backend_type) + if policy: + return policy + + if self._global_policy: + return self._global_policy + + return SecurityPolicy() + + async def _ask_user(self, message: str) -> bool: + """If prompt is provided, ask user for confirmation, otherwise default to deny""" + if self._prompt: + try: + return await self._prompt(message) + except Exception: + return False + return False + + async def check_command_allowed(self, backend_type: BackendType, command: str) -> bool: + policy = self.get_policy(backend_type) + + if policy.check(command=command): + return True + + # Find dangerous tokens + dangerous_tokens = policy.find_dangerous_tokens(command) + + # Extract only lines containing dangerous commands + lines = command.split('\n') + dangerous_lines = [] + for i, line in enumerate(lines): + line_lower = line.lower() + if any(token in line_lower for token in dangerous_tokens): + # Add line number and the line itself + dangerous_lines.append((i + 1, line.strip())) + + # If no specific dangerous lines found but policy failed, show first few lines + if not dangerous_lines: + dangerous_lines = [(i + 1, line.strip()) for i, line in enumerate(lines[:5])] + + # Format dangerous lines for display (limit to 10 lines) + max_display_lines = 10 + if len(dangerous_lines) > max_display_lines: + display_lines = dangerous_lines[:max_display_lines] + truncated = True + else: + display_lines = dangerous_lines + truncated = False + + # Build formatted command display + formatted_cmd_lines = [] + for line_num, line in display_lines: + # Truncate very long lines + if len(line) > 80: + line = line[:77] + "..." + formatted_cmd_lines.append(f" L{line_num}: {line}") + + if truncated: + formatted_cmd_lines.append(" ... (more lines)") + + formatted_command = '\n'.join(formatted_cmd_lines) + + # Show which dangerous commands were detected + dangerous_list = ', '.join([f"{Colors.RED}{tok}{Colors.RESET}" for tok in dangerous_tokens[:5]]) + + from openspace.utils.display import Box, BoxStyle, colorize + + # Build command box + box = Box(width=66, style=BoxStyle.SQUARE, color='gr') + cmd_box = [ + box.top_line(2), + box.empty_line(2), + ] + for line in formatted_cmd_lines: + cmd_box.append(box.text_line(line, indent=2)) + cmd_box.extend([ + box.empty_line(2), + box.bottom_line(2) + ]) + + message = ( + f"\n{colorize('Potentially dangerous command detected', color=Colors.WHITE)}\n\n" + f"Backend: {colorize(backend_type.value, color=Colors.CYAN)}\n" + f"Dangerous commands: {dangerous_list}\n\n" + f"Affected lines:\n" + + "\n".join(cmd_box) + "\n\n" + f"{colorize('This command may contain risky operations. Continue?', color=Colors.YELLOW)}" + ) + + return await self._ask_user(message) + + async def check_domain_allowed(self, backend_type: BackendType, domain: str) -> bool: + policy = self.get_policy(backend_type) + + if policy.check(domain=domain): + return True + + message = ( + f"\n{Colors.WHITE}Unauthorized domain access detected{Colors.RESET}\n\n" + f"Backend: {Colors.CYAN}{backend_type.value}{Colors.RESET}\n" + f"Domain: {Colors.YELLOW}{domain}{Colors.RESET}\n\n" + f"{Colors.YELLOW}This domain is not in the allowed list. Continue?{Colors.RESET}" + ) + + return await self._ask_user(message) \ No newline at end of file diff --git a/openspace/grounding/core/security/sandbox.py b/openspace/grounding/core/security/sandbox.py new file mode 100644 index 0000000000000000000000000000000000000000..3cbad814d8e361d2efdbc8990d8598de26db6024 --- /dev/null +++ b/openspace/grounding/core/security/sandbox.py @@ -0,0 +1,51 @@ +from typing import Any, Dict, Optional +from abc import ABC, abstractmethod + +from ..types import SandboxOptions, BackendType + + +class BaseSandbox(ABC): + def __init__(self, options: SandboxOptions): + self.options = options + self._active = False + + @abstractmethod + async def start(self) -> bool: + """Set self._active to True""" + pass + + @abstractmethod + async def stop(self) -> None: + """Set self._active to False""" + pass + + @abstractmethod + async def execute_safe(self, command: str, **kwargs) -> Any: + pass + + @abstractmethod + def get_connector(self) -> Any: + pass + + @property + def is_active(self) -> bool: + return self._active + + +class SandboxManager: + def __init__(self): + self._sandboxes: Dict[BackendType, BaseSandbox] = {} + + def register_sandbox(self, backend_type: BackendType, sandbox: BaseSandbox) -> None: + self._sandboxes[backend_type] = sandbox + + def get_sandbox(self, backend_type: BackendType) -> Optional[BaseSandbox]: + return self._sandboxes.get(backend_type) + + async def start_all(self) -> None: + for sandbox in self._sandboxes.values(): + await sandbox.start() + + async def stop_all(self) -> None: + for sandbox in self._sandboxes.values(): + await sandbox.stop() \ No newline at end of file diff --git a/openspace/grounding/core/session.py b/openspace/grounding/core/session.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3bb8ce82e9db42a03fa97cb75bb13783aef92b --- /dev/null +++ b/openspace/grounding/core/session.py @@ -0,0 +1,118 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List +from datetime import datetime + +from .tool import BaseTool +from .transport.connectors import BaseConnector +from .types import SessionInfo, SessionStatus, BackendType, ToolResult +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +class BaseSession(ABC): + """ + Session manager for all backends. + """ + def __init__( + self, + connector: BaseConnector, + *, + session_id: str, + backend_type: BackendType | None = None, + auto_connect: bool = True, + auto_initialize: bool = True, + ) -> None: + self.connector = connector + self.session_id = session_id + self.backend_type = backend_type or BackendType.NOT_SET + self.auto_connect = auto_connect + self.auto_initialize = auto_initialize + + self.status: SessionStatus = SessionStatus.DISCONNECTED + self.session_info: Dict[str, Any] | None = None + self._created_at = datetime.utcnow() + self._last_activity = self._created_at + self.tools: List[BaseTool] = [] + + async def __aenter__(self) -> "BaseSession": + if self.auto_connect: + await self.connect() + if self.auto_initialize: + self.session_info = await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit the async context manager. + + Args: + exc_type: The exception type, if an exception was raised. + exc_val: The exception value, if an exception was raised. + exc_tb: The exception traceback, if an exception was raised. + """ + await self.disconnect() + + async def connect(self) -> None: + if self.connector.is_connected: + return + self.status = SessionStatus.CONNECTING + await self.connector.connect() + self.status = SessionStatus.CONNECTED + + async def disconnect(self) -> None: + if not self.connector.is_connected: + return + await self.connector.disconnect() + self.status = SessionStatus.DISCONNECTED + + @property + def is_connected(self) -> bool: + return self.connector.is_connected + + @abstractmethod + async def initialize(self) -> Dict[str, Any]: + """ + Negotiate with the backend, discover tools, etc. + Return session information (can be an empty dict). + + `self.tools` need to be set in this method. + """ + raise NotImplementedError("Sub-class must implement this method") + + async def list_tools(self) -> List[BaseTool]: + """ + Return tools discovered during `initialize()`. + """ + if not self.tools: + self.session_info = await self.initialize() + return self.tools + + async def call_tool(self, tool_name: str, parameters=None) -> ToolResult: + parameters = parameters or {} + + # Ensure tools are initialized before calling + if not self.tools: + logger.debug(f"Tools not initialized for session {self.session_id}, initializing now...") + self.session_info = await self.initialize() + + tool_map = {t.schema.name: t for t in self.tools} + if tool_name not in tool_map: + raise ValueError(f"Unknown tool: {tool_name}") + result = await tool_map[tool_name].arun(**parameters) + self._touch() + return result + + # Update when a successful call is made + def _touch(self): + self._last_activity = datetime.utcnow() + + @property + def info(self) -> SessionInfo: + return SessionInfo( + session_id=self.session_id, + backend_type=getattr(self, "backend_type", BackendType.NOT_SET), + status=self.status, + created_at=self._created_at, + last_activity=self._last_activity, + metadata=self.session_info or {}, + ) \ No newline at end of file diff --git a/openspace/grounding/core/system/__init__.py b/openspace/grounding/core/system/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8172328f5caac43cbf5a1bf9604a1dc61c33c964 --- /dev/null +++ b/openspace/grounding/core/system/__init__.py @@ -0,0 +1,7 @@ +from .provider import SystemProvider +from .tool import SYSTEM_TOOLS + +__all__ = [ + "SystemProvider", + "SYSTEM_TOOLS", +] \ No newline at end of file diff --git a/openspace/grounding/core/system/provider.py b/openspace/grounding/core/system/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..b974e009fcd15999a020884be9b36cec4e8e9c00 --- /dev/null +++ b/openspace/grounding/core/system/provider.py @@ -0,0 +1,45 @@ +from typing import List, Dict, Any +from ..provider import Provider +from ..types import BackendType, SessionConfig +from ..grounding_client import GroundingClient +from .tool import SYSTEM_TOOLS, _BaseSystemTool +from ..exceptions import GroundingError, ErrorCode + + +class SystemProvider(Provider): + """ + Provider for system-level query tools + """ + def __init__(self, client: GroundingClient): + super().__init__(BackendType.SYSTEM, {}) + # Instantiates all system tools + self._tools: List[_BaseSystemTool] = [tool_cls(client) for tool_cls in SYSTEM_TOOLS] + + async def initialize(self): + self.is_initialized = True + + async def create_session(self, session_config: SessionConfig): + raise GroundingError( + "SystemProvider does not support sessions", + code=ErrorCode.CONFIG_INVALID, + ) + + async def list_tools(self, session_name: str | None = None): + return self._tools + + async def call_tool( + self, + session_name: str, + tool_name: str, + parameters: Dict[str, Any] | None = None, + ): + tool_map = {t.schema.name: t for t in self._tools} + if tool_name not in tool_map: + raise GroundingError( + f"System tool '{tool_name}' not found", + code=ErrorCode.TOOL_NOT_FOUND, + ) + return await tool_map[tool_name].arun(**(parameters or {})) + + async def close_session(self, session_name: str) -> None: + return \ No newline at end of file diff --git a/openspace/grounding/core/system/tool.py b/openspace/grounding/core/system/tool.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f9da532ea0796d31590267856e3b041d5dbcd1 --- /dev/null +++ b/openspace/grounding/core/system/tool.py @@ -0,0 +1,82 @@ +from ..tool.local_tool import LocalTool +from ..types import BackendType, ToolResult, ToolStatus +from ..grounding_client import GroundingClient + + +class _BaseSystemTool(LocalTool): + backend_type = BackendType.SYSTEM + + def __init__(self, client: GroundingClient): + super().__init__(verbose=False, handle_errors=True) + self._client = client + + @property + def client(self) -> GroundingClient: + return self._client + + +class ListProvidersTool(_BaseSystemTool): + _name = "list_providers" + _description = "List all registered backend providers" + + async def _arun(self) -> ToolResult: + prov = list(self.client.list_providers().keys()) + return ToolResult( + status=ToolStatus.SUCCESS, + content=", ".join(prov), + ) + + +class ListBackendToolsTool(_BaseSystemTool): + _name = "list_backend_tools" + _description = "List static tools for a backend" + + async def _arun(self, backend: str) -> ToolResult: + try: + be = BackendType(backend.lower()) + except ValueError: + return ToolResult(ToolStatus.ERROR, error=f"Unknown backend '{backend}'") + + tools = await self.client.list_backend_tools(be) + names = [t.schema.name for t in tools] + return ToolResult( + status=ToolStatus.SUCCESS, + content=", ".join(names), + ) + + +class ListSessionToolsTool(_BaseSystemTool): + _name = "list_session_tools" + _description = "List tools (incl. dynamic) for a session" + + async def _arun(self, session_id: str) -> ToolResult: + tools = await self.client.list_session_tools(session_id) + names = [t.schema.name for t in tools] + return ToolResult( + status=ToolStatus.SUCCESS, + content=", ".join(names), + ) + + +class ListAllBackendToolsTool(_BaseSystemTool): + _name = "list_all_backend_tools" + _description = "List static tools for every registered backend" + + async def _arun(self, use_cache: bool = False) -> ToolResult: + all_tools = await self.client.list_all_backend_tools(use_cache=use_cache) + lines = [ + f"{backend.value}: {', '.join(t.schema.name for t in tools)}" + for backend, tools in all_tools.items() + ] + return ToolResult( + status=ToolStatus.SUCCESS, + content="\n".join(lines), + ) + + +SYSTEM_TOOLS: list[type[_BaseSystemTool]] = [ + ListProvidersTool, + ListBackendToolsTool, + ListSessionToolsTool, + ListAllBackendToolsTool, +] \ No newline at end of file diff --git a/openspace/grounding/core/tool/__init__.py b/openspace/grounding/core/tool/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b35d271ec9b0348e2536b020c52f3c107af5e1e --- /dev/null +++ b/openspace/grounding/core/tool/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseTool +from .local_tool import LocalTool +from .remote_tool import RemoteTool + +__all__ = ["BaseTool", "LocalTool", "RemoteTool"] \ No newline at end of file diff --git a/openspace/grounding/core/tool/base.py b/openspace/grounding/core/tool/base.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4b635314e0d46c847dff12290096dbcb993187 --- /dev/null +++ b/openspace/grounding/core/tool/base.py @@ -0,0 +1,326 @@ +""" +BaseTool. +All pre-defined grounding atomic operations will inherit this tool class. +RemoteTool needs to pass in connector. +""" +import asyncio, time, inspect +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import Any, ClassVar, Dict, Optional, TYPE_CHECKING +from pydantic import BaseModel, ConfigDict, Field, create_model + +from ..types import BackendType, ToolResult, ToolSchema, ToolStatus +from ..exceptions import GroundingError, ErrorCode +from openspace.utils.logging import Logger +import jsonschema + +if TYPE_CHECKING: + from ..grounding_client import GroundingClient + +logger = Logger.get_logger(__name__) + + +class ToolRuntimeInfo: + """Runtime information for a tool instance""" + def __init__( + self, + backend: BackendType, + session_name: str, + server_name: Optional[str] = None, + grounding_client: Optional['GroundingClient'] = None, + ): + self.backend = backend + self.session_name = session_name + self.server_name = server_name + self.grounding_client = grounding_client + + def __repr__(self): + return f"" + + +class BaseTool(ABC): + _name: ClassVar[str] = "" + _description: ClassVar[str] = "" + backend_type: ClassVar[BackendType] = BackendType.NOT_SET + + def __init__(self, + schema: Optional[ToolSchema] = None, + *, + verbose: bool = False, + handle_errors: bool = True) -> None: + self.verbose = verbose + self.handle_errors = handle_errors + self.schema: ToolSchema = schema or ToolSchema( + name=self._name or self.__class__.__name__.lower(), + description=self._description, + parameters=self.get_parameters_schema(), + backend_type=self.backend_type, + ) + + self._runtime_info: Optional[ToolRuntimeInfo] = None + self._disable_outer_recording = True + + @property + def name(self) -> str: + """Get tool name from schema (supports both class-defined and runtime-injected names)""" + return self.schema.name if hasattr(self, 'schema') and self.schema else self._name + + @property + def description(self) -> str: + """Get tool description from schema (supports both class-defined and runtime-injected descriptions)""" + return self.schema.description if hasattr(self, 'schema') and self.schema else self._description + + @classmethod + @lru_cache + def get_parameters_schema(cls) -> Dict[str, Any]: + """Auto-generate JSON-schema from _run() or _arun() signature. + + Returns empty dict for tools with no parameters. + Priority: prefer _arun if overridden, otherwise use _run. + """ + # Priority: prefer _arun if it's overridden by subclass, else use _run + # This allows async-first tools to define their signature via _arun + sig_src = None + + # Check if _arun is overridden (not from BaseTool) + if cls._arun is not BaseTool._arun: + sig_src = cls._arun + # Otherwise check if _run is overridden + elif cls._run is not BaseTool._run: + sig_src = cls._run + # If neither is overridden, raise error + else: + raise ValueError( + f"{cls.__name__} must implement _run() or _arun() to define its parameters schema" + ) + + sig = inspect.signature(sig_src) + fields: dict[str, Any] = {} + for name, p in sig.parameters.items(): + # Skip 'self' and **kwargs / *args + if name == "self" or p.kind in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL): + continue + typ = p.annotation if p.annotation is not inspect._empty else str + default = p.default if p.default is not inspect._empty else ... + fields[name] = (typ, Field(default)) + + if not fields: + return {} + + PModel: type[BaseModel] = create_model( + f"{cls.__name__}Params", + __config__=ConfigDict(arbitrary_types_allowed=True), + **fields + ) + return PModel.model_json_schema() + + def validate_parameters(self, params: Dict[str, Any]) -> None: + try: + self.schema.validate_parameters(params, raise_exc=True) + except jsonschema.ValidationError as ve: + raise GroundingError( + f"Invalid parameters: {ve.message}", + code=ErrorCode.TOOL_EXECUTION_FAIL, + tool_name=self.schema.name, + ) from ve + + def run(self, **kwargs): + try: + return asyncio.run(self.invoke(**kwargs)) + except RuntimeError: # already in running loop + loop = asyncio.get_running_loop() + return loop.create_task(self.invoke(**kwargs)) + + def __call__(self, **kwargs): + return self.run(**kwargs) + + async def __acall__(self, **kwargs): + return await self.arun(**kwargs) + + async def arun(self, **kwargs) -> ToolResult: + start = time.time() + try: + self.validate_parameters(kwargs) + raw = await self._arun(**kwargs) + result = self._wrap_result(raw, time.time() - start) + except Exception as e: + if self.handle_errors: + result = ToolResult( + status=ToolStatus.ERROR, + error=str(e), + metadata={"tool": self.schema.name}, + ) + else: + raise + + await self._auto_record_execution(kwargs, result, time.time() - start) + return result + + # to be implemented by subclasses + @abstractmethod + async def _arun(self, **kwargs): ... + + def bind_runtime_info( + self, + backend: BackendType, + session_name: str, + server_name: Optional[str] = None, + grounding_client: Optional['GroundingClient'] = None, + ) -> 'BaseTool': + """ + Bind runtime information to the tool instance. + Allow the tool to be invoked directly without specifying backend/session/server. + + Args: + backend: Backend type + session_name: Session name + server_name: Server name (for MCP) + grounding_client: Optional reference to GroundingClient for direct invocation + """ + self._runtime_info = ToolRuntimeInfo( + backend=backend, + session_name=session_name, + server_name=server_name, + grounding_client=grounding_client, + ) + return self + + @property + def runtime_info(self) -> Optional['ToolRuntimeInfo']: + """Get runtime information if bound""" + return self._runtime_info + + @property + def is_bound(self) -> bool: + """Check if tool has runtime information bound""" + return self._runtime_info is not None + + async def invoke( + self, + parameters: Dict[str, Any] | None = None, + keep_session: bool = True, + **kwargs + ) -> ToolResult: + """ + Invoke this tool using bound runtime information. + Requires runtime info to be bound via bind_runtime_info(). + If no runtime info is bound, the tool will be executed locally. + """ + params = parameters or kwargs + + if self.is_bound and self._runtime_info.grounding_client: + return await self._runtime_info.grounding_client.invoke_tool( + tool=self, + parameters=params, + keep_session=keep_session, + ) + + return await self.arun(**params) + + def _wrap_result(self, obj: Any, elapsed: float) -> ToolResult: + if isinstance(obj, ToolResult): + obj.execution_time = elapsed + return obj + if self.verbose: + logger.debug("[%s] done in %.2f s", self.schema.name, elapsed) + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode("utf-8", errors="replace") + return ToolResult( + status=ToolStatus.SUCCESS, + content=str(obj), + execution_time=elapsed, + metadata={"tool": self.schema.name}, + ) + + async def _auto_record_execution( + self, + parameters: Dict[str, Any], + result: ToolResult, + execution_time: float, + ): + """Auto-record tool execution to recording manager and quality manager.""" + # Record to quality manager (for quality tracking) + await self._record_to_quality_manager(result, execution_time * 1000) + + # Record to recording manager (for trajectory recording) + try: + from openspace.recording import RecordingManager + + if not RecordingManager.is_recording(): + return + + # Check if tool has disabled outer recording (e.g., GUI agent with intermediate steps) + if hasattr(self, '_disable_outer_recording') and self._disable_outer_recording: + logger.debug(f"Skipping outer recording for {self.schema.name} (intermediate steps recorded)") + return + + # Get backend and server_name from runtime_info (if bound) + backend = self.backend_type.value + server_name = None + + if self.is_bound and self._runtime_info: + # Prefer runtime_info information (more accurate) + backend = self._runtime_info.backend.value + server_name = self._runtime_info.server_name + + # Get screenshot (if GUI backend) + screenshot = None + if self.backend_type == BackendType.GUI and hasattr(self, 'connector'): + try: + screenshot = await self.connector.get_screenshot() + except Exception as e: + logger.debug(f"Failed to capture screenshot: {e}") + + # Record tool execution with complete runtime information + await RecordingManager.record_tool_execution( + tool_name=self.schema.name, + backend=backend, + parameters=parameters, + result=result.content, + server_name=server_name, + is_success=result.is_success, # Pass actual success status from ToolResult + ) + except Exception as e: + logger.warning(f"Failed to auto-record tool execution for {self.schema.name}: {e}") + + async def _record_to_quality_manager( + self, + result: ToolResult, + execution_time_ms: float, + ): + """Record execution result to quality manager for quality tracking.""" + try: + from openspace.grounding.core.quality import get_quality_manager + + manager = get_quality_manager() + if manager: + await manager.record_execution(self, result, execution_time_ms) + except Exception as e: + # Quality recording failure should not affect tool execution + logger.debug(f"Failed to record to quality manager: {e}") + + # keep _run for backward-compatibility / thread-pool fallback + def _run(self, **kwargs): + raise NotImplementedError + + def __repr__(self): + base = f"" + + def __init_subclass__(cls, **kwargs): + """ + - at least implement _run or _arun + - backend_type is NOT_SET, only give a warning, allow RemoteTool to inject at runtime + """ + super().__init_subclass__(**kwargs) + + if cls._arun is BaseTool._arun and cls._run is BaseTool._run: + raise ValueError(f"{cls.__name__} must implement _run() or _arun()") + + if cls.backend_type is BackendType.NOT_SET: + logger.debug( + "%s.backend_type is NOT_SET; remember to override or set at runtime.", + cls.__name__, + ) \ No newline at end of file diff --git a/openspace/grounding/core/tool/local_tool.py b/openspace/grounding/core/tool/local_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ac5bf9f76f57ff139345fcb8b73993e181825c --- /dev/null +++ b/openspace/grounding/core/tool/local_tool.py @@ -0,0 +1,29 @@ +""" +LocalTool. +Executes entirely inside this Python process. +""" +import asyncio +from typing import Any +from .base import BaseTool + + +class LocalTool(BaseTool): + def _run(self, **kwargs): + raise NotImplementedError + + async def _dispatch_run(self, **kwargs) -> Any: + # Prefer subclass's own _arun if it was overridden + if self.__class__._arun is not LocalTool._arun: + return await super()._arun(**kwargs) + + # Else fall back to thread-pooled _run if provided + if self.__class__._run is not LocalTool._run: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: self._run(**kwargs)) + + raise NotImplementedError( + f"{self.__class__.__name__} must implement _run() or _arun()" + ) + + async def _arun(self, **kwargs): + return await self._dispatch_run(**kwargs) \ No newline at end of file diff --git a/openspace/grounding/core/tool/remote_tool.py b/openspace/grounding/core/tool/remote_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..6e98ccfe6e24560992a129f8e9168a5dbacdbc33 --- /dev/null +++ b/openspace/grounding/core/tool/remote_tool.py @@ -0,0 +1,89 @@ +""" +RemoteTool. +Wrapper around a connector that calls a remote tool. +""" +from typing import Optional +from openspace.utils.logging import Logger +from ..types import BackendType, ToolResult, ToolSchema, ToolStatus +from .base import BaseTool +from openspace.grounding.core.transport.connectors import BaseConnector + +logger = Logger.get_logger(__name__) + + +class RemoteTool(BaseTool): + backend_type = BackendType.NOT_SET + + def __init__( + self, + schema: ToolSchema | None = None, + connector: Optional[BaseConnector] = None, + remote_name: str = "", + *, + verbose: bool = False, + backend: BackendType = BackendType.NOT_SET, + ): + self._conn = connector + self._remote_name = remote_name or (schema.name if schema else "") + self.backend_type = backend + super().__init__(schema=schema, verbose=verbose) + + async def _arun(self, **kwargs): + # If no connector, tool must be invoked via grounding_client (on-demand startup) + if self._conn is None: + raise RuntimeError( + f"Tool '{self.name}' has no connector. " + "Use grounding_client.invoke_tool() to execute it with on-demand server startup." + ) + + raw = await self._conn.invoke(self._remote_name, kwargs) + + if hasattr(raw, 'content') and hasattr(raw, 'isError'): + content_parts = [] + for item in (raw.content or []): + # Extract text from TextContent + if hasattr(item, 'text') and item.text: + content_parts.append(item.text) + # Handle ImageContent (just note its presence) + elif hasattr(item, 'data'): + content_parts.append(f"[Image data: {len(item.data) if item.data else 0} bytes]") + # Handle EmbeddedResource + elif hasattr(item, 'resource'): + content_parts.append(f"[Embedded resource: {getattr(item.resource, 'uri', 'unknown')}]") + + content = "\n".join(content_parts) if content_parts else "" + is_error = getattr(raw, 'isError', False) + + return ToolResult( + status=ToolStatus.ERROR if is_error else ToolStatus.SUCCESS, + content=content, + error=content if is_error else None, + ) + + # Handle dict response + if isinstance(raw, dict): + import json + try: + content = json.dumps(raw, ensure_ascii=False, indent=2) + except (TypeError, ValueError): + content = str(raw) + # Handle list/tuple response + elif isinstance(raw, (list, tuple)): + import json + try: + content = json.dumps(raw, ensure_ascii=False, indent=2) + except (TypeError, ValueError): + content = str(raw) + # Handle primitive types + elif isinstance(raw, (int, float, bool)): + content = str(raw) + elif isinstance(raw, str): + content = raw + # Fallback for unknown types + else: + content = str(raw) + + return ToolResult( + status=ToolStatus.SUCCESS, + content=content, + ) \ No newline at end of file diff --git a/openspace/grounding/core/transport/connectors/__init__.py b/openspace/grounding/core/transport/connectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..038b7f30e6179c6e3319cf7160e2cc9cafdbc84a --- /dev/null +++ b/openspace/grounding/core/transport/connectors/__init__.py @@ -0,0 +1,7 @@ +from .base import BaseConnector +from .aiohttp_connector import AioHttpConnector + +__all__ = [ + "BaseConnector", + "AioHttpConnector", +] \ No newline at end of file diff --git a/openspace/grounding/core/transport/connectors/aiohttp_connector.py b/openspace/grounding/core/transport/connectors/aiohttp_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..51a348faeb100d891b2b3b830eda63793d953dc6 --- /dev/null +++ b/openspace/grounding/core/transport/connectors/aiohttp_connector.py @@ -0,0 +1,145 @@ +from typing import Any +from yarl import URL +import aiohttp + +from ..task_managers import AioHttpConnectionManager +from .base import BaseConnector +from openspace.utils.logging import Logger +from pydantic import BaseModel + +logger = Logger.get_logger(__name__) + + +class AioHttpConnector(BaseConnector[aiohttp.ClientSession]): + """Generic HTTP-based connector with auto-reconnect & helper methods.""" + + def __init__(self, base_url: str, **session_kw): + connection_manager = AioHttpConnectionManager(base_url, **session_kw) + super().__init__(connection_manager) + self.base_url = base_url.rstrip("/") + + async def connect(self) -> None: + await super().connect() + try: + async with self._connection.get(self.base_url, timeout=5) as resp: + if resp.status >= 500: + raise ConnectionError(f"HTTP {resp.status}") + except Exception as e: + await self.disconnect() + raise ConnectionError(f"Ping {self.base_url} failed: {e}") + + async def _request( + self, + method: str, + path: str, + *, + json: Any | BaseModel | None = None, + data: Any | None = None, + params: dict[str, Any] | None = None, + **kw, + ) -> aiohttp.ClientResponse: + if not self.is_connected: + await self.connect() + + assert self._connection is not None # for mypy + url = URL(self.base_url) / path.lstrip("/") + logger.debug("%s %s", method.upper(), url) + return await self._connection.request( + method.upper(), + url, + json=self._to_json_compatible(json), + data=data, + params=params, + **kw, + ) + + async def get_json(self, path: str, **kw) -> Any: + response_model: type[BaseModel] | None = kw.pop("response_model", None) + resp = await self._request("GET", path, **kw) + resp.raise_for_status() + data = await resp.json() + return self._parse_as(data, response_model) + + async def get_bytes(self, path: str, **kw) -> bytes: + resp = await self._request("GET", path, **kw) + resp.raise_for_status() + return await resp.read() + + async def post_json( + self, + path: str, + payload: Any | BaseModel, + *, + response_model: type[BaseModel] | None = None, + **kw, + ) -> Any | BaseModel: + resp = await self._request("POST", path, json=payload, **kw) + + try: + data = await resp.json() + except Exception: + data = None + + if resp.status >= 400: + # Extract detailed error from response body + detail = "" + if data: + detail = data.get("output") or data.get("message") or data.get("error") or "" + error_msg = f"{resp.status}, message='{resp.reason}'" + if detail: + error_msg += f", detail='{detail}'" + raise aiohttp.ClientResponseError( + resp.request_info, + resp.history, + status=resp.status, + message=error_msg, + ) + + return self._parse_as(data, response_model) + + async def request(self, method: str, path: str, **kw) -> aiohttp.ClientResponse: + return await self._request(method, path, **kw) + + async def invoke(self, name: str, params: dict[str, Any]) -> Any: + """ + Generic tool-invocation mapping for HTTP back-ends. + + name rule (case-insensitive): + - "GET /path" -> GET, return JSON + - "GET_TEXT /path" -> GET, return str + - "GET_BYTES /path" -> GET, return bytes + - "POST /path" -> POST, payload = params (JSON) + - other -> default POST /{name}, payload = params + + If PUT/PATCH/DELETE is needed in the future, it can be reused in _handle_other_json. + """ + verb_path = name.strip().split(maxsplit=1) + verb = verb_path[0].upper() + path = verb_path[1] if len(verb_path) == 2 else verb_path[0] + + if verb == "GET_BYTES": + return await self.get_bytes(path, params=params) + + if verb == "GET_TEXT": + resp = await self._request("GET", path, params=params) + resp.raise_for_status() + return await resp.text() + + if verb in {"GET", "POST"} and len(verb_path) == 2: + if verb == "GET": + return await self.get_json(path, params=params) + return await self.post_json(path, payload=params) + + if verb in {"PUT", "PATCH", "DELETE"} and len(verb_path) == 2: + return await self._handle_other_json(verb, path, params) + + return await self.post_json(name, payload=params) + + async def _handle_other_json(self, method: str, path: str, params: dict[str, Any]): + """Fallback implementation for PUT/PATCH/DELETE returning JSON/text, can be overridden by subclasses.""" + resp = await self._request(method, path, json=params) + resp.raise_for_status() + try: + return await resp.json() + except Exception: + return await resp.text() \ No newline at end of file diff --git a/openspace/grounding/core/transport/connectors/base.py b/openspace/grounding/core/transport/connectors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7b70c563cea97cc1422a7f9d3b7305093caa068e --- /dev/null +++ b/openspace/grounding/core/transport/connectors/base.py @@ -0,0 +1,138 @@ +""" +Base connector abstraction. + +A connector is a very thin wrapper-class that owns a *connection manager* +(e.g. AioHttpConnectionManager, AsyncContextConnectionManager, …). +It exposes a unified `connect / disconnect / is_connected` lifecycle and +defines an abstract `request()` method which concrete back-ends must +implement. +""" +import asyncio +from abc import ABC, abstractmethod +from typing import Any, Generic, TypeVar, Type +from pydantic import BaseModel +from ..task_managers import BaseConnectionManager + +T = TypeVar("T") # The object returned by manager.start(): session / connection + + +class BaseConnector(ABC, Generic[T]): + """ + Generic connector that delegates the heavy lifting to the supplied + *connection manager*. Concrete subclasses only need to implement + their own `request()` method. + """ + + def __init__(self, connection_manager: BaseConnectionManager[T]): + self._connection_manager = connection_manager # e.g. AioHttpConnectionManager instance + # The raw connection object returned by the manager, for reusing the established long-term connection + self._connection: T | None = None + self._connected = False + + async def connect(self) -> None: + """Create the underlying session/connection via the manager.""" + if self._connected: + return + + try: + # Hook: before connection + await self._before_connect() + + # Start the connection manager + self._connection = await self._connection_manager.start() + + # Hook: after connection established + await self._after_connect() + + # Mark as connected + self._connected = True + except Exception: + # Clean up on failure + await self._cleanup_on_connect_failure() + raise + + async def disconnect(self) -> None: + """Close the session/connection and reset state. + + Ensures proper cleanup of all resources including aiohttp sessions. + """ + if not self._connected: + return + + # Hook: before disconnection + await self._before_disconnect() + + # Stop the connection manager + if self._connection_manager: + await self._connection_manager.stop() + self._connection = None + + # Hook: after disconnection + await self._after_disconnect() + + self._connected = False + + async def _before_connect(self) -> None: + """Hook called before establishing connection. Override in subclasses if needed.""" + pass + + async def _after_connect(self) -> None: + """Hook called after connection is established. Override in subclasses if needed.""" + pass + + async def _cleanup_on_connect_failure(self) -> None: + """Hook called when connection fails. Override in subclasses if needed.""" + if self._connection_manager: + try: + await self._connection_manager.stop() + except Exception: + pass + self._connection = None + + async def _before_disconnect(self) -> None: + """Hook called before disconnection. Override in subclasses if needed.""" + pass + + async def _after_disconnect(self) -> None: + """Hook called after disconnection. Override in subclasses if needed.""" + pass + + @property + def is_connected(self) -> bool: + """Return True iff `connect()` has completed successfully.""" + return self._connected + + @staticmethod + def _to_json_compatible(obj: Any) -> Any: + """ + Convert a Pydantic BaseModel to a JSON-serialisable dict (by_alias=True). + Leave all other types unchanged. + """ + if isinstance(obj, BaseModel): + return obj.model_dump(by_alias=True) + return obj + + @staticmethod + def _parse_as(data: Any, model_cls: "Type[BaseModel] | None" = None) -> Any: + """ + Try to parse *data* into *model_cls* (a subclass of BaseModel). + If `model_cls` is None or not a subclass of BaseModel, return the original data. + """ + if model_cls is None: + return data + if isinstance(model_cls, type) and issubclass(model_cls, BaseModel): + return model_cls.model_validate(data) + return data + + @abstractmethod + async def invoke(self, name: str, params: dict[str, Any]) -> Any: + """ + Unified RPC entry for all tools. + Sub-class maps this to its actual RPC like call_tool / run_cmd. + """ + raise NotImplementedError + + @abstractmethod + async def request(self, *args: Any, **kwargs: Any) -> Any: + """Abstract RPC / HTTP / WS request method to be implemented by child classes.""" + raise NotImplementedError("This connector has not implemented 'request'") \ No newline at end of file diff --git a/openspace/grounding/core/transport/task_managers/__init__.py b/openspace/grounding/core/transport/task_managers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6df6edccc14efcfec9d2fe616493fdd47ccb5d53 --- /dev/null +++ b/openspace/grounding/core/transport/task_managers/__init__.py @@ -0,0 +1,13 @@ +from .base import BaseConnectionManager +from .aiohttp_connection_manager import AioHttpConnectionManager +from .async_ctx import AsyncContextConnectionManager +from .placeholder import PlaceholderConnectionManager +from .noop import NoOpConnectionManager + +__all__ = [ + "BaseConnectionManager", + "AioHttpConnectionManager", + "AsyncContextConnectionManager", + "PlaceholderConnectionManager", + "NoOpConnectionManager", +] \ No newline at end of file diff --git a/openspace/grounding/core/transport/task_managers/aiohttp_connection_manager.py b/openspace/grounding/core/transport/task_managers/aiohttp_connection_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..00ea7cb3f725125a79df3f45fe551bfc71dd36e7 --- /dev/null +++ b/openspace/grounding/core/transport/task_managers/aiohttp_connection_manager.py @@ -0,0 +1,56 @@ +""" +Long-lived aiohttp ClientSession manager based on AsyncContextConnectionManager. + +It keeps a single ClientSession open during the lifetime of a backend +session, saving the overhead of creating and closing a TCP connection +for every request. +""" +from typing import Optional +import aiohttp + +from .async_ctx import AsyncContextConnectionManager + + +class AioHttpConnectionManager( + AsyncContextConnectionManager[aiohttp.ClientSession, ...] +): + """Manage a persistent aiohttp.ClientSession.""" + + def __init__( + self, + base_url: str, + headers: Optional[dict[str, str]] = None, + timeout: float = 30, + ): + self.base_url = base_url.rstrip("/") + timeout_cfg = aiohttp.ClientTimeout(total=timeout) + super().__init__( + aiohttp.ClientSession, + timeout=timeout_cfg, + headers=headers or {}, + ) + self._logger.debug( + "Init AioHttpConnectionManager base_url=%s timeout=%s", self.base_url, timeout + ) + + async def _establish_connection(self) -> aiohttp.ClientSession: + """Create and enter the aiohttp.ClientSession context.""" + session = await super()._establish_connection() + self._logger.debug("aiohttp ClientSession created") + return session + + async def _close_connection(self) -> None: + """Close the session and then call the parent cleanup. + + Ensures proper cleanup even if close() fails. + """ + if self._ctx: + try: + self._logger.debug("Closing aiohttp ClientSession") + await self._ctx.close() + # Give aiohttp time to finish its internal cleanup callbacks + import asyncio + await asyncio.sleep(0.1) + except Exception as e: + self._logger.warning(f"Error closing aiohttp ClientSession: {e}") + await super()._close_connection() \ No newline at end of file diff --git a/openspace/grounding/core/transport/task_managers/async_ctx.py b/openspace/grounding/core/transport/task_managers/async_ctx.py new file mode 100644 index 0000000000000000000000000000000000000000..a010f6bb10a2968405b37ec9ed7f45c053ef395d --- /dev/null +++ b/openspace/grounding/core/transport/task_managers/async_ctx.py @@ -0,0 +1,116 @@ +""" +Generic connection manager based on an *async context manager*. +Give it any factory that returns an async–context-manager. +""" +import sys +from typing import Any, Callable, Generic, Optional, ParamSpec, TypeVar +from .base import BaseConnectionManager + +# BaseExceptionGroup only exists in Python 3.11+ +if sys.version_info >= (3, 11): + _BaseExceptionGroup = BaseExceptionGroup +else: + # Dummy class for older Python versions + class _BaseExceptionGroup(Exception): + pass + +T = TypeVar("T") # Return type of the async context +P = ParamSpec("P") # Parameter specification of the factory + + +class AsyncContextConnectionManager(Generic[T, P], BaseConnectionManager[T]): + def __init__(self, + ctx_factory: Callable[P, Any], + *args: P.args, + **kwargs: P.kwargs): + super().__init__() + self._factory = ctx_factory + self._factory_args = args + self._factory_kwargs = kwargs + self._ctx: Optional[Any] = None + + async def _establish_connection(self) -> T: + """Create the context manager and enter it.""" + self._logger.debug("Creating context via %s", self._factory.__name__) + try: + self._ctx = self._factory(*self._factory_args, **self._factory_kwargs) + result: T = await self._ctx.__aenter__() + self._logger.debug("Context %s entered successfully", self._factory.__name__) + return result + except Exception as e: + # Check if this is a benign ExceptionGroup/TaskGroup error + # These occur during concurrent initialization and cleanup + error_msg = str(e).lower() + is_taskgroup_error = ( + "unhandled errors in a taskgroup" in error_msg or + "cancel scope in a different task" in error_msg or + "exceptiongroup" in type(e).__name__.lower() + ) + + if is_taskgroup_error: + # This is a benign race condition during concurrent connection setup + # Log at debug level and re-raise to trigger retry logic + self._logger.debug( + f"Benign TaskGroup race condition during {self._factory.__name__} connection: {type(e).__name__}" + ) + # Clean up the partially created context + if self._ctx is not None: + try: + await self._ctx.__aexit__(None, None, None) + except Exception: + pass # Ignore cleanup errors + self._ctx = None + raise + else: + # Real error - log at error level + self._logger.error(f"Error establishing connection via {self._factory.__name__}: {e}") + raise + + async def _close_connection(self) -> None: + """Exit the context manager if it exists. + + Uses try-finally to ensure ctx is cleared even if __aexit__ fails. + This prevents resource leaks when cleanup encounters errors. + """ + if self._ctx is not None: + try: + self._logger.debug("Exiting context %s", self._factory.__name__) + + # Give subprocesses a moment to flush buffers before closing + import asyncio + await asyncio.sleep(0.05) + + # Try to exit the context, but catch all possible exceptions + try: + await self._ctx.__aexit__(None, None, None) + except BaseException as e: + # Catch absolutely everything including SystemExit, KeyboardInterrupt, etc. + # Check if it's a benign error + benign_error_types = ( + BrokenPipeError, ConnectionResetError, ValueError, + OSError, IOError, ProcessLookupError, RuntimeError, + GeneratorExit + ) + + is_benign = False + + # Check direct exception type + if isinstance(e, benign_error_types): + is_benign = True + # Check for BaseExceptionGroup (Python 3.11+) + elif hasattr(e, 'exceptions'): + # It's an exception group, check all sub-exceptions + is_benign = all(isinstance(sub_e, benign_error_types) for sub_e in e.exceptions) + + if is_benign: + self._logger.debug(f"Benign cleanup error for {self._factory.__name__}: {type(e).__name__}") + else: + self._logger.warning(f"Error during context exit for {self._factory.__name__}: {type(e).__name__}: {e}") + + # Don't re-raise - we want cleanup to complete + + except Exception as e: + # Catch any other unexpected errors in the outer try block + self._logger.warning(f"Unexpected error during cleanup for {self._factory.__name__}: {e}") + finally: + self._ctx = None \ No newline at end of file diff --git a/openspace/grounding/core/transport/task_managers/base.py b/openspace/grounding/core/transport/task_managers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..965e4c7fa4d48355ebe0969292b12dea83daa5bd --- /dev/null +++ b/openspace/grounding/core/transport/task_managers/base.py @@ -0,0 +1,215 @@ +""" +Base connection manager for all backend connectors. + +This module provides an abstract base class for different types of connection +managers used in all backend connectors. + +Flow: start() → launch_connection_task() → call subclass _establish_connection() → notify ready → maintain connection until stop() → call subclass _close_connection() → cleanup +""" +import asyncio +from abc import ABC, abstractmethod +from typing import Generic, TypeVar +from openspace.utils.logging import Logger + +T = TypeVar("T") + + +class BaseConnectionManager(Generic[T], ABC): + """Abstract base class for connection managers. + + This class defines the interface for different types of connection managers + used with all backend connectors. + """ + + def __init__(self): + """Initialize a new connection manager.""" + self._ready_event = asyncio.Event() + self._done_event = asyncio.Event() + self._exception: Exception | None = None + self._connection: T | None = None + self._task: asyncio.Task | None = None + self._logger = Logger.get_logger(f"{__name__}.{self.__class__.__name__}") + + @abstractmethod + async def _establish_connection(self) -> T: + """Establish the connection. + + This method should be implemented by subclasses to establish + the specific type of connection needed. + + Returns: + The established connection. + + Raises: + Exception: If connection cannot be established. + """ + pass + + @abstractmethod + async def _close_connection(self) -> None: + """Close the connection. + + This method should be implemented by subclasses to close + the specific type of connection. + + """ + pass + + async def start(self, timeout: float | None = None) -> T: + """Start the connection manager and establish a connection. + + Args: + timeout: Optional timeout in seconds. If None, waits indefinitely. + If specified, will cancel the background task on timeout. + + Returns: + The established connection. + + Raises: + TimeoutError: If connection establishment times out. + Exception: If connection cannot be established. + """ + # Reset state + self._ready_event.clear() + self._done_event.clear() + self._exception = None + + # Create a task to establish and maintain the connection + self._task = asyncio.create_task(self._connection_task(), name=f"{self.__class__.__name__}_task") + + # Wait for the connection to be ready or fail (with optional timeout) + try: + if timeout is not None: + await asyncio.wait_for(self._ready_event.wait(), timeout=timeout) + else: + await self._ready_event.wait() + except asyncio.TimeoutError: + # Timeout! Cancel the background task + self._logger.warning(f"Connection establishment timed out after {timeout}s, cancelling...") + if self._task and not self._task.done(): + self._task.cancel() + try: + await asyncio.wait_for(self._task, timeout=2.0) # Give it 2s to cleanup + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + except Exception as e: + self._logger.debug(f"Error during task cancellation: {e}") + raise TimeoutError(f"Connection establishment timed out after {timeout}s") + + # If there was an exception, raise it + if self._exception: + # Check if this is a benign TaskGroup race condition + error_msg = str(self._exception).lower() + is_benign_taskgroup_error = ( + "unhandled errors in a taskgroup" in error_msg or + "cancel scope in a different task" in error_msg or + "exceptiongroup" in type(self._exception).__name__.lower() + ) + + if is_benign_taskgroup_error: + # Log as debug - this is expected and will be retried + self._logger.debug(f"Benign TaskGroup race condition, will retry: {type(self._exception).__name__}") + else: + # Real error - log at error level + self._logger.error(f"Failed to start connection: {self._exception}") + + raise self._exception + + # Return the connection + if self._connection is None: + error_msg = "Connection was not established" + self._logger.error(error_msg) + raise RuntimeError(error_msg) + + self._logger.info("Connection manager started successfully") + return self._connection + + async def stop(self, timeout: float = 5.0) -> None: + """Stop the connection manager and close the connection. + + Args: + timeout: Maximum time to wait for cleanup (default 5s). + + Ensures all async resources (including aiohttp sessions) are properly closed. + """ + if self._task and not self._task.done(): + self._task.cancel() + try: + await asyncio.wait_for(self._task, timeout=timeout) + except asyncio.TimeoutError: + self._logger.warning(f"Task cleanup timed out after {timeout}s") + except asyncio.CancelledError: + pass # Expected + except Exception as e: + self._logger.warning(f"Error stopping task: {e}") + + # Wait for the connection to be done (with timeout) + try: + await asyncio.wait_for(self._done_event.wait(), timeout=timeout) + except asyncio.TimeoutError: + self._logger.warning(f"Done event wait timed out after {timeout}s") + + self._logger.info("Connection manager stopped") + + def get_streams(self) -> T | None: + """Get the current connection streams. + + Returns: + The current connection (typically a tuple of read_stream, write_stream) or None if not connected. + """ + return self._connection + + async def _connection_task(self) -> None: + """Run the connection task. + + This task establishes and maintains the connection until cancelled. + """ + try: + # Establish the connection + self._connection = await self._establish_connection() + self._logger.debug("Connection established") + + # Signal that the connection is ready + self._ready_event.set() + + # Wait indefinitely until cancelled + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + raise + + except asyncio.CancelledError: + raise + except Exception as e: + # Store the exception + self._exception = e + + # Check if this is a benign TaskGroup race condition + error_msg = str(e).lower() + is_benign_taskgroup_error = ( + "unhandled errors in a taskgroup" in error_msg or + "cancel scope in a different task" in error_msg or + "exceptiongroup" in type(e).__name__.lower() + ) + + if is_benign_taskgroup_error: + # Log as debug - this is expected during concurrent connection setup + self._logger.debug(f"Benign TaskGroup race condition in connection task: {type(e).__name__}") + else: + # Real error - log at error level + self._logger.error(f"Connection task failed: {e}") + + # Signal that the connection is ready (with error) + self._ready_event.set() + + finally: + # Close the connection if it was established + if self._connection is not None: + try: + await self._close_connection() + except Exception as e: + self._logger.warning(f"Error closing connection: {e}") + self._connection = None + + # Signal that the connection is done + self._done_event.set() \ No newline at end of file diff --git a/openspace/grounding/core/transport/task_managers/noop.py b/openspace/grounding/core/transport/task_managers/noop.py new file mode 100644 index 0000000000000000000000000000000000000000..61bb45bea12afcd889dda3b1a071e77ec4c74aa0 --- /dev/null +++ b/openspace/grounding/core/transport/task_managers/noop.py @@ -0,0 +1,26 @@ +"""No-op connection manager for local (in-process) connectors. + +Local connectors execute commands directly via subprocess, so they don't +need a real network connection. This manager satisfies the +BaseConnectionManager interface that BaseConnector requires. +""" +import asyncio +from typing import Any +from .base import BaseConnectionManager + + +class NoOpConnectionManager(BaseConnectionManager[Any]): + """Connection manager that immediately reports 'ready' without + establishing any real connection. + + Used by LocalShellConnector and LocalGUIConnector. + """ + + async def _establish_connection(self) -> Any: + """No-op: return a sentinel value.""" + return True + + async def _close_connection(self) -> None: + """No-op: nothing to close.""" + pass + diff --git a/openspace/grounding/core/transport/task_managers/placeholder.py b/openspace/grounding/core/transport/task_managers/placeholder.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3c6bb66eedf1526bcea79b7174b497e4c5b279 --- /dev/null +++ b/openspace/grounding/core/transport/task_managers/placeholder.py @@ -0,0 +1,18 @@ +from typing import Any +from .base import BaseConnectionManager + + +class PlaceholderConnectionManager(BaseConnectionManager[Any]): + """A placeholder connection manager that does nothing. + + This is used by connectors that set up their real connection manager + during the connect() phase. + """ + + async def _establish_connection(self) -> Any: + """Establish the connection (placeholder implementation).""" + raise NotImplementedError("PlaceholderConnectionManager should be replaced before use") + + async def _close_connection(self) -> None: + """Close the connection (placeholder implementation).""" + pass \ No newline at end of file diff --git a/openspace/grounding/core/types.py b/openspace/grounding/core/types.py new file mode 100644 index 0000000000000000000000000000000000000000..ca239b524df418102f21cc0c56746edeeab82c33 --- /dev/null +++ b/openspace/grounding/core/types.py @@ -0,0 +1,286 @@ +from enum import Enum +from datetime import datetime +from typing import Any, Dict, Generic, List, TypeVar, Optional +import jsonschema +from pydantic import BaseModel, Field, ConfigDict + +# Pydantic v2 compatibility +try: + from pydantic import RootModel + PYDANTIC_V2 = True +except ImportError: + PYDANTIC_V2 = False + + +class BackendType(str, Enum): + MCP = "mcp" + SHELL = "shell" + WEB = "web" + GUI = "gui" + SYSTEM = "system" + NOT_SET = "not_set" + + +class ToolStatus(str, Enum): + SUCCESS = "success" + ERROR = "error" + + +class SessionStatus(str, Enum): + CONNECTED = "connected" + DISCONNECTED = "disconnected" + CONNECTING = "connecting" + + +ProgressToken = str | int +RequestId = str | int + +RequestParamsT = TypeVar("RequestParamsT", bound=BaseModel | Dict[str, Any] | None) +NotificationParamsT = TypeVar("NotificationParamsT", bound=BaseModel | Dict[str, Any] | None) +MethodT = TypeVar("MethodT", bound=str) + + +class BaseEntity(BaseModel): + metadata: Dict[str, Any] = Field(default_factory=dict) + model_config = ConfigDict(extra="allow") + + +class JsonRpcBase(BaseEntity): + jsonrpc: str = "2.0" + + +class RpcMessage(JsonRpcBase, Generic[MethodT, RequestParamsT]): + method: MethodT + params: RequestParamsT + + +class Request(RpcMessage[MethodT, RequestParamsT]): + id: RequestId | None = None # id is None means Notification + + +class Notification(RpcMessage[MethodT, NotificationParamsT]): + pass + + +class Result(JsonRpcBase): + pass + + +class ErrorData(BaseEntity): + code: int + message: str + data: Any | None = None + + +class ToolResult(Result): + """Tool execution result""" + status: ToolStatus + content: Any = "" + error: ErrorData | str | None = None + execution_time: float | None = None + + @property + def is_success(self) -> bool: return self.status == ToolStatus.SUCCESS + + @property + def is_error(self) -> bool: return self.status == ToolStatus.ERROR + + +class SecurityPolicy(BaseEntity): + allow_shell_commands: bool = True + allow_network_access: bool = True + allow_file_access: bool = True + allowed_domains: List[str] = Field(default_factory=list) + blocked_commands: List[str] = Field(default_factory=list) + sandbox_enabled: bool = False + + @classmethod + def from_dict(cls, data: Dict) -> "SecurityPolicy": + """ + Create SecurityPolicy from configuration dict. + + Supports two formats for blocked_commands: + 1. List format (applies to all OS): ["cmd1", "cmd2"] + 2. Dict format (OS-specific): + { + "common": ["cmd1", "cmd2"], + "linux": ["cmd3"], + "darwin": ["cmd4"], + "windows": ["cmd5"] + } + + When using dict format, merges 'common' commands with current OS-specific commands. + """ + import sys + import platform + + processed_data = {} + for k, v in data.items(): + if k not in cls.model_fields: + continue + + # Special handling for blocked_commands + if k == "blocked_commands": + if isinstance(v, dict): + # Dict format: merge common + OS-specific + blocked_list = list(v.get("common", [])) + + # Determine current OS + system = sys.platform + if system.startswith("linux"): + os_key = "linux" + elif system == "darwin": + os_key = "darwin" + elif system.startswith("win"): + os_key = "windows" + else: + os_key = None + + # Merge OS-specific commands + if os_key and os_key in v: + blocked_list.extend(v[os_key]) + + processed_data[k] = blocked_list + elif isinstance(v, list): + # List format: use as-is + processed_data[k] = v + else: + # Invalid format, use empty list + processed_data[k] = [] + else: + processed_data[k] = v + + return cls(**processed_data) + + def check(self, *, command: str | None = None, domain: str | None = None) -> bool: + """ + return True if allowed, False if denied. + Command check uses token-level matching to prevent simple space/escape bypasses. + """ + import shlex + + # Shell / Python command check + if command: + if not self.allow_shell_commands: + return False + + tokens = [t.lower() for t in shlex.split(command, posix=True)] + blocked_set = {b.lower() for b in self.blocked_commands} + if any(tok in blocked_set for tok in tokens): + return False + + # Network access check + if domain: + if not self.allow_network_access: + return False + if self.allowed_domains and domain not in self.allowed_domains: + return False + + return True + + def find_dangerous_tokens(self, command: str) -> List[str]: + """ + Find and return all dangerous tokens in the command. + Returns empty list if no dangerous tokens found. + """ + import shlex + + if not command: + return [] + + try: + tokens = [t.lower() for t in shlex.split(command, posix=True)] + except ValueError: + # If shlex.split fails, fall back to simple split + tokens = [t.lower() for t in command.split()] + + blocked_set = {b.lower() for b in self.blocked_commands} + dangerous = [tok for tok in tokens if tok in blocked_set] + + return dangerous + + +class ToolSchema(BaseEntity): + name: str + description: str | None = None + parameters: Dict[str, Any] = Field(default_factory=dict) # JSON Schema, optional + return_schema: Dict[str, Any] = Field(default_factory=dict) + examples: List[dict] = Field(default_factory=list) + usage_hint: str | None = None + latency_hint: str | None = None + backend_type: BackendType + security_policy: SecurityPolicy | None = None + + def validate_parameters(self, params: Dict[str, Any], *, raise_exc: bool = False) -> bool: + """use jsonschema to validate parameters + + Returns True if parameters are valid or if tool has no parameters. + """ + # If tool has no parameters defined and no parameters are provided, validation passes + if not self.parameters and not params: + return True + + # If tool has no parameters defined but parameters are provided, validation fails + if not self.parameters and params: + if raise_exc: + raise ValueError(f"Tool '{self.name}' does not accept any parameters, but got: {list(params.keys())}") + return False + + try: + jsonschema.validate(params, self.parameters) + return True + except jsonschema.ValidationError: + if raise_exc: + raise + return False + + def is_allowed(self, *, command: str | None = None, domain: str | None = None) -> bool: + """check security policy""" + return self.security_policy.check(command=command, domain=domain) if self.security_policy else True + + +class SessionConfig(BaseEntity): + session_name: str + backend_type: BackendType + connection_params: Dict[str, Any] = Field(default_factory=dict) + timeout: int = 30 + max_retries: int = 3 + auto_reconnect: bool = True + auto_connect: bool = True + health_check_interval: int = 5 + custom_settings: Dict[str, Any] = Field(default_factory=dict) + + +class SessionInfo(SessionConfig): + status: SessionStatus + created_at: datetime + last_activity: datetime + + +class SandboxOptions(BaseEntity): + api_key: str + """Direct API key for sandbox provider (e.g., E2B API key). + If not provided, will use E2B_API_KEY environment variable.""" + + sandbox_template_id: Optional[str] = None + """Template ID for the sandbox environment. + Default: 'base'""" + + supergateway_command: Optional[str] = None + """Command to run supergateway. + Default: 'npx -y supergateway'""" + + +# ClientMessage: Only available in Pydantic v2 +if PYDANTIC_V2: + class ClientMessage( + RootModel[ + Request[Any, str] | Notification[Any, str] + ] + ): + """ + Unified deserialization entry: `ClientMessage.model_validate_json(raw_bytes)` + """ +else: + # Pydantic v1 fallback: not used in current codebase + ClientMessage = None # type: ignore \ No newline at end of file diff --git a/openspace/host_detection/__init__.py b/openspace/host_detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..820dbfc8ef806eed34191648df1cb985c04878ed --- /dev/null +++ b/openspace/host_detection/__init__.py @@ -0,0 +1,91 @@ +"""Host-agent config auto-detection. + +Public API consumed by other OpenSpace subsystems (cloud, mcp_server, …): + + - ``build_llm_kwargs`` — resolve LLM credentials + - ``build_grounding_config_path`` — resolve grounding config + - ``read_host_mcp_env`` — host-agnostic skill env reader + - ``get_openai_api_key`` — OpenAI key resolution (multi-host) + +Internal / legacy re-exports (prefer the generic names above): + + - ``read_nanobot_mcp_env`` — nanobot-specific, kept for backward compat + - ``try_read_nanobot_config`` + +Supported host agents: + + - **nanobot** — ``~/.nanobot/config.json`` (``tools.mcpServers.openspace.env``) + - **openclaw** — ``~/.openclaw/openclaw.json`` (``skills.entries.openspace.env``) +""" + +import logging +from typing import Dict, Optional + +from openspace.host_detection.resolver import build_llm_kwargs, build_grounding_config_path +from openspace.host_detection.nanobot import ( + get_openai_api_key as _nanobot_get_openai_api_key, + read_nanobot_mcp_env, + try_read_nanobot_config, +) +from openspace.host_detection.openclaw import ( + get_openclaw_openai_api_key as _openclaw_get_openai_api_key, + is_openclaw_host, + read_openclaw_skill_env, +) + +logger = logging.getLogger("openspace.host_detection") + + +def read_host_mcp_env() -> Dict[str, str]: + """Read the OpenSpace env block from the current host agent config. + + Resolution order: + 1. nanobot — ``tools.mcpServers.openspace.env`` + 2. openclaw — ``skills.entries.openspace.env`` + 3. Empty dict (no host detected) + + Callers (e.g. ``cloud.auth``) use this single entry point and never + need to know which host agent is active. + """ + # Try nanobot first (most common deployment) + env = read_nanobot_mcp_env() + if env: + return env + + # Try openclaw + env = read_openclaw_skill_env("openspace") + if env: + logger.debug("read_host_mcp_env: resolved from OpenClaw config") + return env + + return {} + + +def get_openai_api_key() -> Optional[str]: + """Get OpenAI API key for embedding generation (multi-host). + + Resolution: + 1. ``OPENAI_API_KEY`` env var (checked inside nanobot reader) + 2. nanobot config ``providers.openai.apiKey`` + 3. openclaw config ``skills.entries.openspace.env.OPENAI_API_KEY`` + 4. None + """ + # nanobot reader already checks OPENAI_API_KEY env var first + key = _nanobot_get_openai_api_key() + if key: + return key + return _openclaw_get_openai_api_key() + + +__all__ = [ + "build_llm_kwargs", + "build_grounding_config_path", + "get_openai_api_key", + "read_host_mcp_env", + # legacy re-exports + "read_nanobot_mcp_env", + "try_read_nanobot_config", + # openclaw-specific (for direct use if needed) + "is_openclaw_host", + "read_openclaw_skill_env", +] diff --git a/openspace/host_detection/nanobot.py b/openspace/host_detection/nanobot.py new file mode 100644 index 0000000000000000000000000000000000000000..c06c74357e57cf3dc0242f2e16f92d33084e36e6 --- /dev/null +++ b/openspace/host_detection/nanobot.py @@ -0,0 +1,212 @@ +"""Nanobot host-agent config reader. + +Reads ``~/.nanobot/config.json`` to auto-detect: + - LLM provider credentials (``providers.*``) + - MCP env block for the ``openspace`` server + - Default model and forced provider settings + +Provider keyword → config field mapping mirrors ``nanobot/providers/registry.py``. +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +logger = logging.getLogger("openspace.host_detection") + +PROVIDER_REGISTRY: List[tuple] = [ + # Gateways + ("openrouter", ("openrouter",), "https://openrouter.ai/api/v1"), + ("aihubmix", ("aihubmix",), "https://aihubmix.com/v1"), + ("siliconflow", ("siliconflow",), "https://api.siliconflow.cn/v1"), + ("volcengine", ("volcengine", "volces", "ark"), "https://ark.cn-beijing.volces.com/api/v3"), + # Standard providers + ("anthropic", ("anthropic", "claude"), ""), + ("openai", ("openai", "gpt"), ""), + ("deepseek", ("deepseek",), ""), + ("gemini", ("gemini",), ""), + ("zhipu", ("zhipu", "glm", "zai"), ""), + ("dashscope", ("qwen", "dashscope"), ""), + ("moonshot", ("moonshot", "kimi"), "https://api.moonshot.ai/v1"), + ("minimax", ("minimax",), "https://api.minimax.io/v1"), + ("groq", ("groq",), ""), +] + +NANOBOT_CONFIG_PATH = Path.home() / ".nanobot" / "config.json" + + +def _load_nanobot_config() -> Optional[Dict[str, Any]]: + """Load and parse ``~/.nanobot/config.json``. Returns None on failure.""" + if not NANOBOT_CONFIG_PATH.is_file(): + return None + try: + with open(NANOBOT_CONFIG_PATH, encoding="utf-8") as f: + data = json.load(f) + return data if isinstance(data, dict) else None + except (json.JSONDecodeError, OSError) as e: + logger.warning("Failed to read nanobot config %s: %s", NANOBOT_CONFIG_PATH, e) + return None + + +def match_provider( + providers: Dict[str, Any], + model: str, + forced_provider: str = "auto", +) -> Optional[Dict[str, Any]]: + """Match a provider config dict from nanobot's ``providers`` section. + + Resolution order: + 1. ``forced_provider`` (if not "auto") → use that config field directly. + 2. Prefix match: model's first ``/``-separated segment == config field name. + 3. Keyword match: iterate ``PROVIDER_REGISTRY`` by priority. + 4. Fallback: first provider with a non-empty ``apiKey``. + + Returns: + ``{"api_key": ..., "api_base": ..., "extra_headers": ...}`` + (litellm-compatible), or None. + """ + def _extract(prov_dict: Dict[str, Any], default_base: str = "") -> Optional[Dict[str, Any]]: + api_key = prov_dict.get("apiKey") or prov_dict.get("api_key") or "" + if not api_key: + return None + result: Dict[str, Any] = {"api_key": api_key} + api_base = prov_dict.get("apiBase") or prov_dict.get("api_base") or default_base + if api_base: + result["api_base"] = api_base + extra = prov_dict.get("extraHeaders") or prov_dict.get("extra_headers") + if extra and isinstance(extra, dict): + result["extra_headers"] = extra + return result + + model_lower = model.lower() + model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else "" + normalized_prefix = model_prefix.replace("-", "_") + + # 1. Forced provider + if forced_provider and forced_provider != "auto": + p = providers.get(forced_provider) + if p and isinstance(p, dict): + # Look up the default api_base for this provider from the registry + forced_default_base = "" + for _name, _kws, _base in PROVIDER_REGISTRY: + if _name == forced_provider: + forced_default_base = _base + break + return _extract(p, forced_default_base) + + # 2. Prefix match + for name, _kws, default_base in PROVIDER_REGISTRY: + if model_prefix and normalized_prefix == name: + p = providers.get(name) + if p and isinstance(p, dict): + result = _extract(p, default_base) + if result: + return result + + # 3. Keyword match + for name, keywords, default_base in PROVIDER_REGISTRY: + if any(kw in model_lower for kw in keywords): + p = providers.get(name) + if p and isinstance(p, dict): + result = _extract(p, default_base) + if result: + return result + + # 4. Fallback: first provider with an api_key + for name, _kws, default_base in PROVIDER_REGISTRY: + p = providers.get(name) + if p and isinstance(p, dict): + result = _extract(p, default_base) + if result: + return result + + return None + + +def try_read_nanobot_config(model: str) -> Optional[Dict[str, Any]]: + """Read LLM credentials from ``~/.nanobot/config.json``. + + Returns litellm kwargs dict (``api_key``, ``api_base``, ``extra_headers``), + or None. May include a ``"_model"`` key with the nanobot default model. + """ + data = _load_nanobot_config() + if data is None: + return None + + providers = data.get("providers", {}) + if not isinstance(providers, dict): + return None + + agents = data.get("agents", {}) + defaults = agents.get("defaults", {}) if isinstance(agents, dict) else {} + nanobot_model = defaults.get("model", "") if isinstance(defaults, dict) else "" + forced_provider = defaults.get("provider", "auto") if isinstance(defaults, dict) else "auto" + + match_model = model or nanobot_model or "" + result = match_provider(providers, match_model, forced_provider) + + if result and nanobot_model: + result["_model"] = nanobot_model + if result and forced_provider and forced_provider != "auto": + result["_forced_provider"] = forced_provider + + if result: + logger.info( + "Auto-detected LLM credentials from nanobot config (%s), " + "provider matched for model=%r", + NANOBOT_CONFIG_PATH, match_model, + ) + + return result + + +def read_nanobot_mcp_env() -> Dict[str, str]: + """Read ``tools.mcpServers.openspace.env`` from nanobot config. + + Returns the env dict (empty if not found / parse error). + """ + data = _load_nanobot_config() + if data is None: + return {} + + tools = data.get("tools", {}) + if not isinstance(tools, dict): + return {} + mcp_servers = tools.get("mcpServers") or tools.get("mcp_servers") or {} + if not isinstance(mcp_servers, dict): + return {} + openspace_cfg = mcp_servers.get("openspace", {}) + if not isinstance(openspace_cfg, dict): + return {} + env_block = openspace_cfg.get("env", {}) + return env_block if isinstance(env_block, dict) else {} + + +def get_openai_api_key() -> Optional[str]: + """Get OpenAI API key for embedding generation. + + Resolution: + 1. ``OPENAI_API_KEY`` env var + 2. nanobot config ``providers.openai.apiKey`` + 3. None + """ + import os + key = os.environ.get("OPENAI_API_KEY") + if key: + return key + + data = _load_nanobot_config() + if data: + providers = data.get("providers", {}) + if isinstance(providers, dict): + openai_cfg = providers.get("openai", {}) + if isinstance(openai_cfg, dict): + api_key = openai_cfg.get("apiKey") + if api_key: + logger.debug("Using OpenAI API key from nanobot config for embeddings") + return api_key + return None + diff --git a/openspace/host_detection/openclaw.py b/openspace/host_detection/openclaw.py new file mode 100644 index 0000000000000000000000000000000000000000..1107b5c512252a187e713274f62e65e6426237f6 --- /dev/null +++ b/openspace/host_detection/openclaw.py @@ -0,0 +1,137 @@ +"""OpenClaw host-agent config reader. + +Reads ``~/.openclaw/openclaw.json`` to auto-detect: + - LLM provider credentials (via ``auth-profiles`` — not yet implemented) + - Skill-level env block (``skills.entries.openspace.env``) + - OpenAI API key for embedding generation + +Config path resolution mirrors OpenClaw's own logic: + 1. ``OPENCLAW_CONFIG_PATH`` env var + 2. ``OPENCLAW_STATE_DIR/openclaw.json`` + 3. ``~/.openclaw/openclaw.json`` (default) + +Fallback legacy dirs: ``~/.clawdbot``, ``~/.moldbot``, ``~/.moltbot``. +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any, Dict, Optional + +logger = logging.getLogger("openspace.host_detection") + +_STATE_DIRNAMES = [".openclaw", ".clawdbot", ".moldbot", ".moltbot"] +_CONFIG_FILENAMES = ["openclaw.json", "clawdbot.json", "moldbot.json", "moltbot.json"] + + +def _resolve_openclaw_config_path() -> Optional[Path]: + """Find the OpenClaw config file on disk.""" + import os + + # 1. Explicit env override + explicit = os.environ.get("OPENCLAW_CONFIG_PATH", "").strip() + if explicit: + p = Path(explicit).expanduser() + if p.is_file(): + return p + return None + + # 2. State dir override + state_dir = os.environ.get("OPENCLAW_STATE_DIR", "").strip() + if state_dir: + for fname in _CONFIG_FILENAMES: + p = Path(state_dir) / fname + if p.is_file(): + return p + + # 3. Default locations + home = Path.home() + for dirname in _STATE_DIRNAMES: + for fname in _CONFIG_FILENAMES: + p = home / dirname / fname + if p.is_file(): + return p + + return None + + +def _load_openclaw_config() -> Optional[Dict[str, Any]]: + """Load and parse the OpenClaw config file. Returns None on failure.""" + config_path = _resolve_openclaw_config_path() + if config_path is None: + return None + try: + with open(config_path, encoding="utf-8") as f: + data = json.load(f) + return data if isinstance(data, dict) else None + except (json.JSONDecodeError, OSError) as e: + logger.warning("Failed to read OpenClaw config %s: %s", config_path, e) + return None + + +def read_openclaw_skill_env(skill_name: str = "openspace") -> Dict[str, str]: + """Read ``skills.entries..env`` from OpenClaw config. + + This is the OpenClaw equivalent of nanobot's + ``tools.mcpServers.openspace.env``. + + Returns the env dict (empty if not found / parse error). + """ + data = _load_openclaw_config() + if data is None: + return {} + + skills = data.get("skills", {}) + if not isinstance(skills, dict): + return {} + entries = skills.get("entries", {}) + if not isinstance(entries, dict): + return {} + skill_cfg = entries.get(skill_name, {}) + if not isinstance(skill_cfg, dict): + return {} + env_block = skill_cfg.get("env", {}) + return env_block if isinstance(env_block, dict) else {} + + +def get_openclaw_openai_api_key() -> Optional[str]: + """Get OpenAI API key from OpenClaw config. + + Checks ``skills.entries.openspace.env.OPENAI_API_KEY`` first, + then any top-level env vars in the config. + + Returns the key string, or None. + """ + # Try skill-level env + env = read_openclaw_skill_env("openspace") + key = env.get("OPENAI_API_KEY", "").strip() + if key: + logger.debug("Using OpenAI API key from OpenClaw skill env config") + return key + + # Try top-level config env.vars + data = _load_openclaw_config() + if data: + env_section = data.get("env", {}) + if isinstance(env_section, dict): + vars_block = env_section.get("vars", {}) + if isinstance(vars_block, dict): + key = vars_block.get("OPENAI_API_KEY", "").strip() + if key: + logger.debug("Using OpenAI API key from OpenClaw env.vars config") + return key + + return None + + +def is_openclaw_host() -> bool: + """Detect if the current environment is running under OpenClaw.""" + import os + # Check OpenClaw-specific env vars + if os.environ.get("OPENCLAW_STATE_DIR") or os.environ.get("OPENCLAW_CONFIG_PATH"): + return True + # Check if config exists + return _resolve_openclaw_config_path() is not None + diff --git a/openspace/host_detection/resolver.py b/openspace/host_detection/resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb36116b997c8d78c0a59a5ac2526963914a044 --- /dev/null +++ b/openspace/host_detection/resolver.py @@ -0,0 +1,183 @@ +"""LLM credential and grounding config resolution. + +Resolves the model name and litellm kwargs for OpenSpace's LLM client, +and assembles grounding config from env-var overrides. +""" + +from __future__ import annotations + +import json +import logging +import os +import tempfile +from typing import Any, Dict, Optional + +logger = logging.getLogger("openspace.host_detection") + + +def build_llm_kwargs(model: str) -> tuple[str, Dict[str, Any]]: + """Build litellm kwargs and resolve model for OpenSpace's LLM client. + + Resolution order (highest → lowest priority): + + Tier 1 — Explicit ``OPENSPACE_LLM_*`` env vars:: + + OPENSPACE_LLM_API_KEY → litellm ``api_key`` + OPENSPACE_LLM_API_BASE → litellm ``api_base`` + OPENSPACE_LLM_EXTRA_HEADERS → litellm ``extra_headers`` (JSON string) + OPENSPACE_LLM_CONFIG → arbitrary litellm kwargs (JSON string) + + Tier 2 — Auto-detect from host agent config file:: + + ~/.nanobot/config.json → providers.{matched}.apiKey / apiBase + + Tier 3 — Provider-native env vars inherited from the parent process + (e.g. ``OPENROUTER_API_KEY``). Read by litellm automatically. + + Returns: + ``(resolved_model, llm_kwargs_dict)`` + """ + from openspace.host_detection.nanobot import try_read_nanobot_config + + kwargs: Dict[str, Any] = {} + resolved_model = model + source = "inherited env" + + # --- Tier 2: auto-detect from host config (filled first, may be overridden) --- + host_config = try_read_nanobot_config(model) + if host_config: + host_model = host_config.pop("_model", None) + forced_provider = host_config.pop("_forced_provider", None) + if not resolved_model and host_model: + resolved_model = host_model + # If the host config forces a gateway provider (e.g. openrouter) + # and the model name doesn't already carry that prefix, prepend + # it so that litellm uses the correct request format (OpenAI- + # compatible for gateways vs native for direct providers). + _GATEWAY_PROVIDERS = {"openrouter", "aihubmix", "siliconflow"} + if ( + forced_provider + and forced_provider in _GATEWAY_PROVIDERS + and resolved_model + and not resolved_model.lower().startswith(f"{forced_provider}/") + ): + resolved_model = f"{forced_provider}/{resolved_model}" + logger.info( + "Prepended gateway prefix: model=%r (forced_provider=%s)", + resolved_model, forced_provider, + ) + kwargs.update(host_config) + source = "nanobot config" + + # --- Tier 1: explicit env vars override everything --- + api_key = os.environ.get("OPENSPACE_LLM_API_KEY") + if api_key: + kwargs["api_key"] = api_key + source = "OPENSPACE_LLM_* env" + + api_base = os.environ.get("OPENSPACE_LLM_API_BASE") + if api_base: + kwargs["api_base"] = api_base + + extra_headers_raw = os.environ.get("OPENSPACE_LLM_EXTRA_HEADERS") + if extra_headers_raw: + try: + headers = json.loads(extra_headers_raw) + if isinstance(headers, dict): + kwargs["extra_headers"] = headers + except json.JSONDecodeError: + logger.warning("Invalid JSON in OPENSPACE_LLM_EXTRA_HEADERS: %r", extra_headers_raw) + + llm_config_raw = os.environ.get("OPENSPACE_LLM_CONFIG") + if llm_config_raw: + try: + llm_config = json.loads(llm_config_raw) + if isinstance(llm_config, dict): + kwargs.update(llm_config) + source = "OPENSPACE_LLM_CONFIG env" + except json.JSONDecodeError: + logger.warning("Invalid JSON in OPENSPACE_LLM_CONFIG: %r", llm_config_raw) + + # Default model fallback + if not resolved_model: + resolved_model = "openrouter/anthropic/claude-sonnet-4.5" + + if kwargs: + safe = { + k: (v[:8] + "..." if k == "api_key" and isinstance(v, str) and len(v) > 8 else v) + for k, v in kwargs.items() + } + logger.info("LLM kwargs resolved (source=%s): %s", source, safe) + + return resolved_model, kwargs + + +def build_grounding_config_path() -> Optional[str]: + """Resolve grounding config: inline JSON > file path > None. + + Supports: + * ``OPENSPACE_CONFIG_JSON`` — inline JSON string (written to a temp file) + * ``OPENSPACE_CONFIG_PATH`` — path to a JSON config file + + Granular env-var overrides (``OPENSPACE_SHELL_*``, ``OPENSPACE_SKILLS_*``, + etc.) are merged before writing. + + Returns: + Path to the resolved config file, or None. + """ + config_json_raw = os.environ.get("OPENSPACE_CONFIG_JSON", "").strip() + overrides: Dict[str, Any] = {} + if config_json_raw: + try: + overrides = json.loads(config_json_raw) + if not isinstance(overrides, dict): + logger.warning("OPENSPACE_CONFIG_JSON is not a dict, ignoring") + overrides = {} + else: + logger.info("Loaded inline config from OPENSPACE_CONFIG_JSON") + except json.JSONDecodeError as e: + logger.warning("Invalid JSON in OPENSPACE_CONFIG_JSON: %s", e) + + # --- Granular env-var overrides --- + conda_env = os.environ.get("OPENSPACE_SHELL_CONDA_ENV", "").strip() + if conda_env: + overrides.setdefault("shell", {})["conda_env"] = conda_env + + shell_wd = os.environ.get("OPENSPACE_SHELL_WORKING_DIR", "").strip() + if shell_wd: + overrides.setdefault("shell", {})["working_dir"] = shell_wd + + skills_dirs_raw = os.environ.get("OPENSPACE_SKILLS_DIRS", "").strip() + if skills_dirs_raw: + dirs = [d.strip() for d in skills_dirs_raw.split(",") if d.strip()] + if dirs: + overrides.setdefault("skills", {})["skill_dirs"] = dirs + + mcp_servers_raw = os.environ.get("OPENSPACE_MCP_SERVERS_JSON", "").strip() + if mcp_servers_raw: + try: + servers = json.loads(mcp_servers_raw) + if isinstance(servers, dict): + overrides["mcpServers"] = servers + except json.JSONDecodeError as e: + logger.warning("Invalid JSON in OPENSPACE_MCP_SERVERS_JSON: %s", e) + + log_level = os.environ.get("OPENSPACE_LOG_LEVEL", "").strip().upper() + if log_level and log_level in ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"): + overrides["log_level"] = log_level + + if overrides: + try: + fd, tmp_path = tempfile.mkstemp(suffix=".json", prefix="openspace_cfg_") + with os.fdopen(fd, "w", encoding="utf-8") as f: + json.dump(overrides, f, ensure_ascii=False) + logger.info( + "Grounding config overrides written to %s (%d keys)", + tmp_path, len(overrides), + ) + return tmp_path + except Exception as e: + logger.warning("Failed to write config overrides: %s", e) + + return os.environ.get("OPENSPACE_CONFIG_PATH") + diff --git a/openspace/host_skills/README.md b/openspace/host_skills/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0a555ede34038745212bb1d13d3af8c2566d023c --- /dev/null +++ b/openspace/host_skills/README.md @@ -0,0 +1,115 @@ +# Host Skills Integration Guide + +This guide covers **agent-specific setup** for integrating OpenSpace. For installation and general concepts, see the [main README](../../README.md#-quick-start). + +**Pick your agent:** + +| Agent | Setup Guide | +|------------|-------------| +| **[nanobot](https://github.com/HKUDS/nanobot)** | [Setup for nanobot](#setup-for-nanobot) | +| **[openclaw](https://github.com/openclaw/openclaw)** | [Setup for openclaw](#setup-for-openclaw) | +| **Other agents** | Follow the [generic setup](../../README.md#-path-a-empower-your-agent-with-openspace) in the main README | + +--- + +## Setup for nanobot + +### 1. Copy host skills + +```bash +cp -r host_skills/skill-discovery/ /path/to/nanobot/nanobot/skills/ +cp -r host_skills/delegate-task/ /path/to/nanobot/nanobot/skills/ +``` + +### 2. Add MCP server to `~/.nanobot/config.json` + +```json +{ + "tools": { + "mcpServers": { + "openspace": { + "command": "openspace-mcp", + "toolTimeout": 1200, + "env": { + "OPENSPACE_HOST_SKILL_DIRS": "/path/to/nanobot/nanobot/skills", + "OPENSPACE_WORKSPACE": "/path/to/OpenSpace", + "OPENSPACE_API_KEY": "sk-xxx" + } + } + } + } +} +``` + +> [!TIP] +> LLM credentials are auto-detected from nanobot's `providers.*` config — no need to set `OPENSPACE_LLM_API_KEY`. + +--- + +## Setup for openclaw + +### 1. Copy host skills + +```bash +cp -r host_skills/skill-discovery/ /path/to/openclaw/skills/ +cp -r host_skills/delegate-task/ /path/to/openclaw/skills/ +``` + +### 2. Register MCP server with env vars + +openclaw uses [mcporter](https://github.com/steipete/mcporter) as its MCP runtime. Register the server and pass env vars in one command: + +```bash +mcporter config add openspace --command "openspace-mcp" \ + --env OPENSPACE_HOST_SKILL_DIRS=/path/to/openclaw/skills \ + --env OPENSPACE_WORKSPACE=/path/to/OpenSpace \ + --env OPENSPACE_API_KEY=sk-xxx +``` + +--- + +## Environment Variables (Agent-Specific) + +The three env vars in each agent's setup above are the most important. For the **full env var list**, config files reference, and advanced settings, see the [Configuration Guide](../../README.md#configuration-guide) in the main README. + +
+What needs OPENSPACE_API_KEY? + +| Capability | Without API Key | With API Key | +|-----------|----------------|--------------| +| `execute_task` | ✅ works (local skills only) | ✅ + cloud skill search | +| `search_skills` | ✅ works (local results only) | ✅ + cloud results | +| `fix_skill` | ✅ works | ✅ works | +| `upload_skill` | ❌ fails | ✅ uploads to cloud | + +All tools default to `"all"` (local + cloud) and **automatically fall back** to local-only if no API key is configured. No need to change tool parameters. + +
+ +--- + +## How It Works + +``` +Your Agent (nanobot / openclaw / ...) + │ + │ MCP protocol (stdio) + ▼ +openspace-mcp ← 4 tools exposed + ├── execute_task ← multi-step grounding agent loop + ├── search_skills ← local + cloud skill search + ├── fix_skill ← repair a broken SKILL.md + └── upload_skill ← push skill to cloud community +``` + +The two host skills teach the agent **when and how** to call these tools: + +| Skill | MCP Tools | Purpose | +|-------|-----------|---------| +| **skill-discovery** | `search_skills` | Search local + cloud skills → decide: follow it yourself, delegate, or skip | +| **delegate-task** | `execute_task` `search_skills` `fix_skill` `upload_skill` | Delegate tasks, search skills, repair broken skills, upload evolved skills | + +Skills auto-evolve inside `execute_task` (**FIX** / **DERIVED** / **CAPTURED**). After every call, your agent reports results to the user via its messaging tool. + +> [!NOTE] +> For full parameter tables, examples, and decision trees, see each skill's SKILL.md directly. \ No newline at end of file diff --git a/openspace/host_skills/delegate-task/SKILL.md b/openspace/host_skills/delegate-task/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..ce9406107070721d73a60a8b9cf0716443121496 --- /dev/null +++ b/openspace/host_skills/delegate-task/SKILL.md @@ -0,0 +1,131 @@ +--- +name: delegate-task +description: Delegate tasks to OpenSpace — a full-stack autonomous worker for coding, DevOps, web research, and desktop automation, backed by an extensive MCP tool and skill library. Skills auto-improve through use, reducing token consumption over time. A cloud community lets agents share and collectively evolve reusable skills. +--- + +# Delegate Tasks to OpenSpace + +OpenSpace is connected as an MCP server. You have 4 tools available: `execute_task`, `search_skills`, `fix_skill`, `upload_skill`. + +## When to use + +- **You lack the capability** — the task requires tools or capabilities beyond what you can access +- **You tried and failed** — you produced incorrect results; OpenSpace may have a tested skill for it +- **Complex multi-step task** — the task involves many steps, tools, or environments that benefit from OpenSpace's skill library and orchestration +- **User explicitly asks** — user requests delegation to OpenSpace + +## Tools + +### execute_task + +Delegate a task to OpenSpace. It will search for relevant skills, execute, and auto-evolve skills if needed. + +``` +execute_task(task="Monitor Docker containers, find the highest memory one, restart it gracefully", search_scope="all") +``` + +| Parameter | Required | Default | Description | +|-----------|----------|---------|-------------| +| `task` | yes | — | Task instruction in natural language | +| `search_scope` | no | `"all"` | Local + cloud; falls back to local-only if no API key | +| `max_iterations` | no | `20` | Max agent iterations — increase for complex tasks, decrease for simple ones | + +Check response for `evolved_skills`. If present with `upload_ready: true`, decide whether to upload (see "When to upload" below). + +```json +{ + "status": "success", + "response": "Task completed successfully", + "evolved_skills": [ + { + "skill_dir": "/path/to/skills/new-skill", + "name": "new-skill", + "origin": "captured", + "change_summary": "Captured reusable workflow pattern", + "upload_ready": true + } + ] +} +``` + +### search_skills + +Search for available skills before deciding whether to handle a task yourself or delegate. + +``` +search_skills(query="docker container monitoring", source="all") +``` + +| Parameter | Required | Default | Description | +|-----------|----------|---------|-------------| +| `query` | yes | — | Search query (natural language or keywords) | +| `source` | no | `"all"` | Local + cloud; falls back to local-only if no API key | +| `limit` | no | `20` | Max results | +| `auto_import` | no | `true` | Auto-download top cloud skills locally | + +### fix_skill + +Manually fix a broken skill. + +``` +fix_skill( + skill_dir="/path/to/skills/weather-api", + direction="The API endpoint changed from v1 to v2, update all URLs and add the new 'units' parameter" +) +``` + +| Parameter | Required | Description | +|-----------|----------|-------------| +| `skill_dir` | yes | Path to skill directory (must contain SKILL.md) | +| `direction` | yes | What's broken and how to fix — be specific | + +Response has `upload_ready: true` → decide whether to upload. + +### upload_skill + +Upload a skill to the cloud community. For evolved/fixed skills, metadata is pre-saved — just provide `skill_dir` and `visibility`. + +``` +upload_skill( + skill_dir="/path/to/skills/weather-api", + visibility="public" +) +``` + +For new skills (no auto metadata — defaults apply, but richer metadata improves discoverability): + +``` +upload_skill( + skill_dir="/path/to/skills/my-new-skill", + visibility="public", + origin="imported", + tags=["weather", "api"], + created_by="my-bot", + change_summary="Initial upload of weather API skill" +) +``` + +| Parameter | Required | Default | Description | +|-----------|----------|---------|-------------| +| `skill_dir` | yes | — | Path to skill directory (must contain SKILL.md) | +| `visibility` | no | `"public"` | `"public"` or `"private"` | +| `origin` | no | auto | How the skill was created | +| `parent_skill_ids` | no | auto | Parent skill IDs | +| `tags` | no | auto | Tags | +| `created_by` | no | auto | Creator | +| `change_summary` | no | auto | What changed | + +### When to upload + +| Situation | Action | +|-----------|--------| +| Skill was originally from the cloud | Upload back as `"public"` — return the improvement to the community | +| Fix/evolution is generally useful | Upload as `"public"` | +| Fix/evolution is project-specific | Upload as `"private"`, or skip | +| User says to share | Upload with the visibility the user wants | + +## Notes + +- `execute_task` may take minutes — this is expected for multi-step tasks. +- `upload_skill` requires a cloud API key; if it fails, the evolved skill is still saved locally. +- After every OpenSpace call, **tell the user** what happened: task result, any evolved skills, and your upload decision. diff --git a/openspace/host_skills/skill-discovery/SKILL.md b/openspace/host_skills/skill-discovery/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..cd554b56023a4d05dac71be88ca388b722ffd54f --- /dev/null +++ b/openspace/host_skills/skill-discovery/SKILL.md @@ -0,0 +1,47 @@ +--- +name: skill-discovery +description: Search for reusable skills across OpenSpace's local registry and cloud community. Reusing proven skills saves tokens, improves reliability, and extends your capabilities beyond built-in tools. +--- + +# Skill Discovery + +Discover and browse skills from OpenSpace's local and cloud skill library. + +## When to use + +- User asks "what skills are available?" or "is there a skill for X?" +- You encounter an unfamiliar task — a proven skill can save significant tokens over trial-and-error +- You need to decide: handle a task yourself, or delegate to OpenSpace + +## search_skills + +``` +search_skills(query="automated deployment with rollback", source="all") +``` + +| Parameter | Required | Default | Description | +|-----------|----------|---------|-------------| +| `query` | yes | — | Natural language or keywords | +| `source` | no | `"all"` | Local + cloud; falls back to local-only if no API key | +| `limit` | no | `20` | Max results | +| `auto_import` | no | `true` | Auto-download top cloud hits locally | + +## After search + +Results are returned to you (not executed). Cloud hits with `auto_imported: true` include a `local_path`. + +``` +Found a matching skill? +├── YES, and I can follow it myself +│ → read SKILL.md at local_path, follow the instructions +├── YES, but I lack the capability +│ → delegate via execute_task (see delegate-task skill) +└── NO match + → handle it yourself, or delegate via execute_task +``` + +## Notes + +- This is for **discovery** — you see results and decide. For direct execution, use `execute_task` from the `delegate-task` skill. +- Cloud skills have been evolved through real use — more reliable than skills written from scratch. +- Always tell the user what you found (or didn't find) and what you recommend. diff --git a/openspace/llm/__init__.py b/openspace/llm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..370955dec4dc6b11bed87518d1cc77866a04ded8 --- /dev/null +++ b/openspace/llm/__init__.py @@ -0,0 +1 @@ +from .client import LLMClient \ No newline at end of file diff --git a/openspace/llm/client.py b/openspace/llm/client.py new file mode 100644 index 0000000000000000000000000000000000000000..19a166491a4365fa534849e0b37251b464e096fc --- /dev/null +++ b/openspace/llm/client.py @@ -0,0 +1,769 @@ +import litellm +import json +import asyncio +import time +from pathlib import Path +from typing import List, Sequence, Union, Dict, Optional +from dotenv import load_dotenv +from openai.types.chat import ChatCompletionToolParam + +from openspace.grounding.core.types import ToolSchema, ToolResult, ToolStatus +from openspace.grounding.core.tool import BaseTool +from openspace.utils.logging import Logger + +# Load .env from openspace package root (works regardless of CWD), +# then fall back to CWD/.env. override=False (default) means first-loaded wins. +_PKG_ENV = Path(__file__).resolve().parent.parent / ".env" # openspace/.env +if _PKG_ENV.is_file(): + load_dotenv(_PKG_ENV) +load_dotenv() # also try CWD/.env for any remaining vars + +# Disable LiteLLM verbose logging to prevent stdout blocking with large tool schemas +litellm.set_verbose = False +litellm.suppress_debug_info = True + +logger = Logger.get_logger(__name__) + + +def _sanitize_schema(params: Dict) -> Dict: + """Sanitize tool parameter schema to comply with Claude API requirements. + + Fixes common issues: + - Empty object schemas (no properties, no required) + - Missing required fields for Claude compatibility + """ + if not params: + return {"type": "object", "properties": {}, "required": []} + + # Deep copy to avoid modifying the original + import copy + sanitized = copy.deepcopy(params) + + # Anthropic API requires top-level type to be 'object' + # If it's not an object, wrap the schema as a property of an object + top_level_type = sanitized.get("type") + if top_level_type and top_level_type != "object": + # Wrap non-object schema as a single property called "value" + logger.debug(f"[SCHEMA_SANITIZE] Wrapping non-object schema (type={top_level_type}) into object") + wrapped = { + "type": "object", + "properties": { + "value": sanitized # The original schema becomes a property + }, + "required": ["value"] # Make it required + } + sanitized = wrapped + + # If type is object but missing properties/required, add them + if sanitized.get("type") == "object": + if "properties" not in sanitized: + sanitized["properties"] = {} + if "required" not in sanitized: + sanitized["required"] = [] + + # Remove non-standard fields that may cause issues (like 'title') + sanitized.pop("title", None) + + # Recursively sanitize nested properties + if "properties" in sanitized and isinstance(sanitized["properties"], dict): + for prop_name, prop_schema in list(sanitized["properties"].items()): + if isinstance(prop_schema, dict): + # Remove title from nested properties + prop_schema.pop("title", None) + + return sanitized + + +def _schema_to_openai(schema: ToolSchema) -> ChatCompletionToolParam: + """Convert ToolSchema to OpenAI ChatCompletion tool format""" + function_def = { + "name": schema.name, + "description": schema.description or "", + } + + # Sanitize and add parameters + if schema.parameters: + sanitized = _sanitize_schema(schema.parameters) + function_def["parameters"] = sanitized + # Debug: verify sanitization worked + if "title" in schema.parameters and "title" not in sanitized: + logger.debug(f"Sanitized tool '{schema.name}': removed title") + else: + # Claude requires parameters field even if empty + function_def["parameters"] = {"type": "object", "properties": {}, "required": []} + + return { + "type": "function", + "function": function_def + } + +def _prepare_tools_for_llmclient( + tools: List[BaseTool] | None, + fmt: str = "openai", +) -> tuple[Sequence[Union[ToolSchema, ChatCompletionToolParam]], Dict[str, BaseTool]]: + """Convert BaseTool list to LLMClient usable format, with deduplication. + + Args: + tools: BaseTool instance list (should be obtained from GroundingClient and bound to runtime_info) + if None or empty list, return empty list + fmt: output format, "openai" for OpenAI format + """ + if not tools: + return [], {} + + if fmt == "openai": + result = [] + tool_map = {} # llm_name -> BaseTool + name_count = {} + + for tool in tools: + name = tool.schema.name + name_count[name] = name_count.get(name, 0) + 1 + + + seen_names = set() + for tool in tools: + original_name = tool.schema.name + + if name_count[original_name] > 1: + server_name = "unknown" + if tool.is_bound and tool.runtime_info and tool.runtime_info.server_name: + server_name = tool.runtime_info.server_name + llm_name = f"{server_name}__{original_name}" + else: + llm_name = original_name + + if llm_name in seen_names: + logger.warning(f"[TOOL_DEDUP] Skipping duplicate tool: {llm_name}") + continue + seen_names.add(llm_name) + + tool_param = _schema_to_openai(tool.schema) + tool_param["function"]["name"] = llm_name + + # Tag the description with backend type so the LLM knows each + # tool's origin (e.g. "[MCP] ...", "[Shell] ..."). + backend_type = getattr(tool.schema, "backend_type", None) + if backend_type and backend_type.value not in ("not_set",): + _BACKEND_LABELS = { + "mcp": "MCP", + "shell": "Shell", + "gui": "GUI", + "web": "Web", + "system": "System", + } + label = _BACKEND_LABELS.get(backend_type.value, backend_type.value) + desc = tool_param["function"].get("description", "") + tool_param["function"]["description"] = f"[{label}] {desc}" + + result.append(tool_param) + + tool_map[llm_name] = tool + + if llm_name != original_name: + logger.info(f"[TOOL_RENAME] {original_name} -> {llm_name}") + + logger.info(f"[SCHEMA_SANITIZE] Prepared {len(result)} tools for LLM (from {len(tools)} total)") + return result, tool_map + + tool_map = {tool.schema.name: tool for tool in tools} + return [tool.schema for tool in tools], tool_map + + +def _infer_backend_from_tool_name(tool_name: str) -> Optional[str]: + """Infer backend when tool_results would otherwise have no backend (name mismatch or unbound tools).""" + if not tool_name or not isinstance(tool_name, str): + return None + name = tool_name.strip() + # Dedup format: "server__toolname" -> use suffix + if "__" in name: + name = name.split("__", 1)[-1] + shell_tools = {"shell_agent", "read_file", "write_file", "list_dir", "run_shell"} + if name in shell_tools: + return "shell" + if name in ("gui_agent",) or "gui" in name.lower(): + return "gui" + if "mcp" in name.lower() or ("." in name and "__" not in name): + return "mcp" + if name in ("deep_research_agent", "deep_research"): + return "web" + return None + + +DEFAULT_SUMMARIZE_THRESHOLD_CHARS = 200000 # ~50K tokens, lowered from 400K to prevent context overflow +MAX_TOOL_RESULT_CHARS = 200000 # Fallback truncation limit when summarization fails (~50K tokens) + +async def _summarize_tool_result( + content: str, + tool_name: str, + task: str = "", + model: str = "openrouter/anthropic/claude-sonnet-4.5", + timeout: float = 120.0 +) -> str: + """Use LLM to summarize large tool results.""" + try: + from gdpval_bench.token_tracker import set_call_source, reset_call_source + _src_tok = set_call_source("summarizer") + except ImportError: + _src_tok = None + + try: + logger.info(f"Summarizing tool result from '{tool_name}': {len(content):,} chars") + + # Pre-truncate if content is too large for the model (leave room for prompt + output) + # Assuming ~4 chars per token, 200K tokens limit, 8K output, ~500 tokens for prompt + # Safe input limit: (200K - 8K - 0.5K) * 4 = ~766K chars, but be conservative at 400K + max_input_chars = 200000 + if len(content) > max_input_chars: + logger.warning(f"Pre-truncating content for summarization: {len(content):,} -> {max_input_chars:,} chars") + content = content[:max_input_chars] + f"\n\n[TRUNCATED for summarization: original was {len(content):,} chars]" + + task_hint = f"\n\nUser's task: {task}\nSummarize with focus on information relevant to this task." if task else "" + + prompt = f"""Tool '{tool_name}' returned a large result ({len(content):,} chars). Summarize it concisely.{task_hint} + +**Guidelines:** +- Structured data (coordinates, steps, etc.): Keep key summary (totals, start/end), omit repetitive details. +- Markup content (HTML, XML): Extract text and key data only, ignore tags/scripts. +- Long documents: Keep structure outline and essential sections. +- Lists/arrays: Summarize count and most relevant items. +- Always preserve: numbers, URLs, file paths, IDs, key identifiers. + +Content: +{content} + +Concise summary:""" + + response = await asyncio.wait_for( + litellm.acompletion( + model=model, + messages=[{"role": "user", "content": prompt}], + timeout=timeout + ), + timeout=timeout + 5 + ) + + summary = response.choices[0].message.content.strip() + result = f"[SUMMARY of {len(content):,} chars]\n{summary}" + + logger.info(f"Tool result summarized: {len(content):,} -> {len(result):,} chars") + return result + + except Exception as e: + logger.warning(f"Summarization failed for '{tool_name}': {e}") + return None + finally: + if _src_tok is not None: + reset_call_source(_src_tok) + + +async def _tool_result_to_message_async( + result: ToolResult, + *, + tool_call_id: str, + tool_name: str, + task: str = "", + summarize_threshold: int = DEFAULT_SUMMARIZE_THRESHOLD_CHARS, + summarize_model: str = "openrouter/anthropic/claude-sonnet-4.5", + enable_summarization: bool = True +) -> Dict: + """Convert ToolResult to LLMClient usable message format with LLM summarization for large results. + + Args: + result: Tool execution result + tool_call_id: OpenAI tool_call ID + tool_name: Tool name + task: User's original task for context-aware summarization + summarize_threshold: If content exceeds this, use LLM summarization + summarize_model: Model to use for summarization + enable_summarization: Whether to enable LLM summarization + + Returns: + OpenAI ChatCompletion tool message (text only) + """ + if result.is_error: + text_content = f"[ERROR] {result.error or 'unknown error'}" + else: + text_content = ( + result.content + if isinstance(result.content, str) + else json.dumps(result.content, ensure_ascii=False, default=str) + ) + + original_len = len(text_content) + + # Use LLM summarization if content exceeds threshold + if original_len > summarize_threshold and enable_summarization: + summary = await _summarize_tool_result(text_content, tool_name, task, summarize_model) + if summary: + text_content = summary + elif original_len > MAX_TOOL_RESULT_CHARS: + # Fallback: truncate if summarization failed and content is too large + truncate_msg = f"\n\n[TRUNCATED: Original content was {original_len:,} chars, showing first {MAX_TOOL_RESULT_CHARS:,}]" + text_content = text_content[:MAX_TOOL_RESULT_CHARS - len(truncate_msg)] + truncate_msg + logger.warning(f"Tool result truncated for '{tool_name}': {original_len:,} -> {len(text_content):,} chars (summarization failed)") + + return { + "role": "tool", + "name": tool_name, + "content": text_content, + "tool_call_id": tool_call_id, + } + +async def _execute_tool_call( + tool: BaseTool, + openai_tool_call: Dict, +) -> ToolResult: + """Execute LLMClient returned tool_call + + Args: + tool: BaseTool instance (must be obtained from GroundingClient and bound to runtime_info) + openai_tool_call: LLMClient usable tool_call object, contains id, type, function etc. fields + """ + if not tool.is_bound: + raise ValueError( + f"Tool '{tool.schema.name}' is not bound to runtime_info. " + f"Please ensure tools are obtained from GroundingClient.list_tools() " + f"with bind_runtime_info=True" + ) + + func = openai_tool_call["function"] + arguments = func.get("arguments", "{}") + if isinstance(arguments, str): + arguments = json.loads(arguments or "{}") + + # Filter out parameters that are not in the tool's schema + if isinstance(arguments, dict) and tool.schema.parameters: + # Get valid parameter names from tool schema (JSON Schema format) + schema_params = tool.schema.parameters + valid_params = set() + + if isinstance(schema_params, dict) and "properties" in schema_params: + valid_params = set(schema_params["properties"].keys()) + + # Check for invalid parameters + invalid_params = [] + for param_name in list(arguments.keys()): + if param_name == "skip_visual_analysis": + invalid_params.append(param_name) + continue + + # Check if parameter is in the tool's schema + if valid_params and param_name not in valid_params: + invalid_params.append(param_name) + + # Remove invalid parameters + for param in invalid_params: + arguments.pop(param) + logger.debug( + f"Removed parameter '{param}' from {tool.schema.name} " + f"(not in tool schema)" + ) + + return await tool.invoke( + parameters=arguments, + keep_session=True + ) + + +class LLMClient: + """LLMClient class for single round call""" + def __init__( + self, + model: str = "openrouter/anthropic/claude-sonnet-4.5", + enable_thinking: bool = False, + rate_limit_delay: float = 0.0, + max_retries: int = 3, + retry_delay: float = 1.0, + timeout: float = 120.0, + summarize_threshold_chars: int = DEFAULT_SUMMARIZE_THRESHOLD_CHARS, + enable_tool_result_summarization: bool = True, + **litellm_kwargs + ): + """ + Args: + model: LLM model identifier + enable_thinking: Whether to enable extended thinking mode + rate_limit_delay: Minimum delay between API calls in seconds (0 = no delay) + max_retries: Maximum number of retries on rate limit errors + retry_delay: Initial delay between retries in seconds (exponential backoff) + timeout: Request timeout in seconds (default: 120s) + summarize_threshold_chars: If tool result exceeds this threshold, use LLM to + summarize the result (default: 50000 chars ≈ 12.5K tokens) + enable_tool_result_summarization: Whether to enable LLM-based summarization for + large tool results (default: True) + **litellm_kwargs: Additional litellm parameters + """ + self.model = model + self.enable_thinking = enable_thinking + self.rate_limit_delay = rate_limit_delay + self.max_retries = max_retries + self.retry_delay = retry_delay + self.timeout = timeout + self.summarize_threshold_chars = summarize_threshold_chars + self.enable_tool_result_summarization = enable_tool_result_summarization + self.litellm_kwargs = litellm_kwargs + self._logger = Logger.get_logger(__name__) + self._last_call_time = 0.0 + + async def _rate_limit(self): + """Apply rate limiting by adding delay between API calls""" + if self.rate_limit_delay > 0: + current_time = time.time() + time_since_last_call = current_time - self._last_call_time + + if time_since_last_call < self.rate_limit_delay: + sleep_time = self.rate_limit_delay - time_since_last_call + self._logger.debug(f"Rate limiting: waiting {sleep_time:.2f}s before next API call") + await asyncio.sleep(sleep_time) + + self._last_call_time = time.time() + + async def _call_with_retry(self, **completion_kwargs): + """Call LLM with backoff retry on rate limit errors + + Timeout and retry strategy: + - Single call timeout: self.timeout (default 120s) + - Rate limit retry delays: 60s, 90s, 120s + - Total max time: timeout * max_retries + sum(retry_delays) + """ + last_exception = None + + for attempt in range(self.max_retries): + try: + # Add timeout to the completion call + response = await asyncio.wait_for( + litellm.acompletion(**completion_kwargs), + timeout=self.timeout + ) + return response + except asyncio.TimeoutError: + self._logger.error( + f"LLM call timed out after {self.timeout}s (attempt {attempt + 1}/{self.max_retries})" + ) + last_exception = TimeoutError(f"LLM call timed out after {self.timeout}s") + if attempt < self.max_retries - 1: + # Retry on timeout with shorter delay + self._logger.info(f"Retrying after {self.retry_delay}s delay...") + await asyncio.sleep(self.retry_delay) + continue + else: + raise last_exception + except Exception as e: + last_exception = e + error_str = str(e).lower() + + # Check if it's a retryable error + is_rate_limit = any( + keyword in error_str + for keyword in ['rate limit', 'rate_limit', 'too many requests', '429'] + ) + + is_overloaded = any( + keyword in error_str + for keyword in ['overloaded', '500', '502', '503', '504', 'internal server error', 'service unavailable'] + ) + + is_connection_error = any( + keyword in error_str + for keyword in ['cannot connect', 'connection refused', 'connection reset', + 'connectionerror', 'timeout', 'name resolution', + 'temporary failure', 'network unreachable'] + ) + + if attempt < self.max_retries - 1 and (is_rate_limit or is_overloaded or is_connection_error): + if is_rate_limit: + backoff_delay = 60 + (attempt * 30) # 60s, 90s, 120s + error_type = "Rate limit" + elif is_connection_error: + backoff_delay = min(10 * (2 ** attempt), 60) # 10s, 20s, 40s, max 60s + error_type = "Connection" + else: + backoff_delay = min(5 * (2 ** attempt), 60) # 5s, 10s, 20s, max 60s + error_type = "Server overload" + + self._logger.warning( + f"{error_type} error (attempt {attempt + 1}/{self.max_retries}), " + f"waiting {backoff_delay}s before retry..." + ) + await asyncio.sleep(backoff_delay) + continue + else: + # Not a retryable error, or max retries reached + if attempt >= self.max_retries - 1: + self._logger.error(f"Max retries ({self.max_retries}) reached, giving up") + raise + + raise last_exception + + async def complete( + self, + messages: List[Dict] | str, + tools: List[BaseTool] | None = None, + execute_tools: bool = True, + summary_prompt: Optional[str] = None, + tool_result_callback: Optional[callable] = None, + **kwargs + ) -> Dict: + """ + Single-round LLM call with optional tool execution. + + Args: + messages: conversation history (List[Dict] for standard OpenAI format, or str for text format) + tools: BaseTool instance list (must be obtained from GroundingClient and bound to runtime_info) + if None or empty list, only perform conversation, no tools + execute_tools: if LLM returns tool_calls, whether to automatically execute tools + summary_prompt: Optional custom prompt for requesting iteration summary. + If provided, will request summary after tool execution. + If None, no summary will be requested. + tool_result_callback: Optional async callback to process tool results after execution. + Signature: async def callback(result: ToolResult, tool_name: str, tool_call: Dict, backend: str) -> ToolResult + **kwargs: additional parameters for litellm completion + """ + # 1. Process messages + if isinstance(messages, str): + current_messages = [{"role": "user", "content": messages}] + user_task = messages + elif isinstance(messages, list): + current_messages = messages.copy() + # Extract first user message as task for context-aware summarization + user_task = next( + (m.get("content", "") for m in messages if m.get("role") == "user"), + "" + ) + else: + raise ValueError("messages must be List[Dict] or str") + + # 2. prepare base litellm completion kwargs + completion_kwargs = { + "model": kwargs.get("model", self.model), + **self.litellm_kwargs, + } + + # Add thinking/reasoning_effort only if explicitly enabled and not using tools + enable_thinking = kwargs.get("enable_thinking", self.enable_thinking) + + # 3. if tools are provided, add them to the request + llm_tools = None + tool_map = {} # llm_name -> BaseTool + if tools: + llm_tools, tool_map = _prepare_tools_for_llmclient(tools, fmt="openai") + if llm_tools: + completion_kwargs["tools"] = llm_tools + completion_kwargs["tool_choice"] = kwargs.get("tool_choice", "auto") + # Disable thinking when using tools to avoid format conflicts + enable_thinking = False + self._logger.debug(f"Prepared {len(llm_tools)} tools for LLM") + else: + self._logger.warning("Tools provided but none could be prepared for LLM") + + # Add thinking parameters if enabled + if enable_thinking: + completion_kwargs["reasoning_effort"] = kwargs.get("reasoning_effort", "medium") + + # 4. Apply rate limiting + await self._rate_limit() + + # 5. Call LLM with retry (single round) + completion_kwargs["messages"] = current_messages + response = await self._call_with_retry(**completion_kwargs) + + if not response.choices: + raise ValueError("LLM response has no choices") + + response_message = response.choices[0].message + + # 6. Build assistant message + assistant_message = { + "role": "assistant", + "content": response_message.content or "", + } + + tool_calls = getattr(response_message, 'tool_calls', None) + if tool_calls: + assistant_message["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments + } + } + for tc in tool_calls + ] + + # Add assistant message to conversation + current_messages.append(assistant_message) + + # 7. Execute tools if requested + tool_results = [] + if execute_tools and tool_calls and tools: + self._logger.info(f"Executing {len(tool_calls)} tool calls...") + + for tool_call in tool_calls: + tool_name = tool_call.function.name + + # Resolve tool instance: key might differ from model response (e.g. API returns + # "read_file" while we stored "server__read_file" for dedup), so fallback by schema.name + tool_obj = tool_map.get(tool_name) + if tool_obj is None and tool_name: + for _k, _t in tool_map.items(): + if getattr(getattr(_t, "schema", None), "name", None) == tool_name: + tool_obj = _t + break + + backend = None + server_name = None + + if tool_obj: + try: + # Prefer runtime_info if bound + if getattr(tool_obj, 'is_bound', False) and getattr(tool_obj, 'runtime_info', None): + backend = tool_obj.runtime_info.backend.value + server_name = tool_obj.runtime_info.server_name + else: + bt = getattr(tool_obj, 'backend_type', None) + bv = getattr(bt, 'value', None) if bt is not None else None + if bv and bv not in ("not_set",): + backend = bv + except Exception as e: + self._logger.warning(f"Failed to resolve backend for tool '{tool_name}': {e}") + + # Ensure backend is set for recording: API may return different tool name, or + # runtime_info/backend_type can be missing or raise + if backend is None and tool_name: + backend = _infer_backend_from_tool_name(tool_name) + if backend is None: + self._logger.warning( + f"Could not resolve backend for tool '{tool_name}', " + f"recording will be skipped" + ) + + # Log tool execution + try: + if isinstance(tool_call.function.arguments, str): + safe_args_str = tool_call.function.arguments.strip() or "{}" + args = json.loads(safe_args_str) + else: + args = tool_call.function.arguments + + args_str = json.dumps(args, ensure_ascii=False)[:200] + self._logger.info(f"Calling {tool_name} with args: {args_str}") + except: + pass + + if tool_name not in tool_map: + result = ToolResult( + status=ToolStatus.ERROR, + error=f"Tool '{tool_name}' not found" + ) + else: + try: + result = await _execute_tool_call( + tool=tool_map[tool_name], + openai_tool_call={ + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments + } + } + ) + + # Apply tool result callback if provided + if tool_result_callback and not result.is_error: + try: + result = await tool_result_callback( + result=result, + tool_name=tool_name, + tool_call=tool_call, + backend=backend + ) + except Exception as e: + self._logger.warning(f"Tool result callback failed for {tool_name}: {e}") + except Exception as e: + result = ToolResult( + status=ToolStatus.ERROR, + error=str(e) + ) + + # Use async version with LLM summarization for large results + tool_message = await _tool_result_to_message_async( + result, + tool_call_id=tool_call.id, + tool_name=tool_name, + task=user_task, + summarize_threshold=self.summarize_threshold_chars, + summarize_model=self.model, + enable_summarization=self.enable_tool_result_summarization + ) + current_messages.append(tool_message) + + # Store result + tool_results.append({ + "tool_call": tool_call, + "result": result, + "message": tool_message, + "backend": backend, + "server_name": server_name, + }) + + self._logger.info(f"Tool execution completed, {len(tool_results)} tools executed") + + # 8. Request summary if provided and tools were executed + iteration_summary = None + + if summary_prompt and tool_results: + self._logger.debug("Requesting iteration summary from LLM") + summary_message = { + "role": "system", + "content": summary_prompt + } + current_messages.append(summary_message) + + # Apply rate limiting before summary call + await self._rate_limit() + + # Call LLM to generate summary (without tools) + summary_kwargs = { + **self.litellm_kwargs, + "model": self.model, + "messages": current_messages, + "tools": [], + "tool_choice": "none", + } + + summary_response = await self._call_with_retry(**summary_kwargs) + + if summary_response.choices: + summary_message = summary_response.choices[0].message + iteration_summary = summary_message.content or "" + + # Add summary response to messages + current_messages.append({ + "role": "assistant", + "content": iteration_summary + }) + + self._logger.debug(f"Generated iteration summary: {iteration_summary[:100]}...") + + # 9. Return single-round result + return { + "message": assistant_message, + "tool_results": tool_results, + "messages": current_messages, + "has_tool_calls": bool(tool_calls), + "iteration_summary": iteration_summary + } + + @staticmethod + def format_messages_to_text(messages: List[Dict]) -> str: + """Format conversation history to readable text (for logging/debugging)""" + formatted = "" + for msg in messages: + role = msg.get("role", "unknown").upper() + content = msg.get("content", "") + formatted += f"[{role}]\n{content}\n\n" + return formatted \ No newline at end of file diff --git a/openspace/local_server/README.md b/openspace/local_server/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0e0f8d414615682ce8fbdb73175ee089570e0e5f --- /dev/null +++ b/openspace/local_server/README.md @@ -0,0 +1,121 @@ +# OpenSpace Local Server + +The local server is a **lightweight Flask service** that runs on the host machine and exposes HTTP endpoints for shell execution and GUI automation. It is only needed in **server mode** — most users should use the default **local mode** instead. + +## When to Use Server Mode + +| | Local Mode (default) | Server Mode | +|---|---|---| +| **Setup** | Zero — just run OpenSpace | Start `local_server` first | +| **Use case** | Same-machine development | Remote VMs, sandboxing, multi-machine | +| **Shell** | `asyncio.subprocess` in-process | HTTP → Flask → `subprocess` | +| **GUI** | Direct pyautogui | HTTP → Flask → pyautogui | +| **Network** | None required | HTTP between agent ↔ server | + +Use server mode when: +- **Controlling a remote VM** — the agent runs on your host, the server runs inside the VM +- **Process isolation / sandboxing** — script execution in a separate process +- **Multi-machine deployments** — agent and execution environment on different machines + +## Enable Server Mode + +Set `"mode": "server"` in `openspace/config/config_grounding.json`: + +```jsonc +{ + "shell": { "mode": "server", ... }, // default: "local" + "gui": { "mode": "server", ... } // default: "local" +} +``` + +## Platform-Specific Dependencies + +> [!IMPORTANT] +> Install platform-specific dependencies **on the machine running the server** (not the agent). + +
+macOS + +```bash +pip install pyobjc-core pyobjc-framework-cocoa pyobjc-framework-quartz atomacos +``` + +**Permissions required** (macOS will prompt automatically on first run): +- **Accessibility** (for GUI control) +- **Screen Recording** (for screenshots and video capture) + +> If prompts don't appear, grant manually in System Settings → Privacy & Security. + +
+ +
+Linux + +```bash +pip install python-xlib pyatspi numpy +sudo apt install at-spi2-core python3-tk scrot +``` + +> **Optional:** `wmctrl` (window management), `libx11-dev` + `libxfixes-dev` (cursor in screenshots) + +
+ +
+Windows + +```bash +pip install pywinauto pywin32 PyGetWindow +``` + +
+ +## Launch + +```bash +# Python entry point +python -m openspace.local_server.main --host 127.0.0.1 --port 5000 + +# Or via helper script +./openspace/local_server/run.sh +``` + +Press `Ctrl+C` to stop. + +## Configuration + +Runtime options in `openspace/local_server/config.json`: + +```json +{ + "server": { + "host": "127.0.0.1", + "port": 5000, + "debug": false + } +} +``` + +## Architecture + +- **PlatformAdapter** — abstracts OS-specific primitives (Windows, macOS, Linux) +- **Accessibility Helper** — queries the UI accessibility tree +- **Screenshot Helper** — captures full or partial screenshots (PNG) +- **Recorder** — streams screen recordings for analysis +- **Health / Feature Checker** — validates runtime capabilities and permissions + +## REST Endpoints + +| Path | Method | Description | +|------|--------|-------------| +| `/` | GET | Liveness probe | +| `/platform` | GET | Host OS metadata | +| `/execute` | POST | Execute a PyAutoGUI script fragment | +| `/execute_with_verification` | POST | Execute + verify via template matching | +| `/run_python` | POST | Run Python in sandbox | +| `/run_bash_script` | POST | Run shell script (optional conda activation) | +| `/screenshot` | GET | PNG screenshot (full or ROI) | +| `/cursor_position` | GET | Current mouse coordinates | +| `/screen_size` | GET/POST | Query or set virtual screen resolution | +| `/list_directory` | POST | List directory contents | + +See `main.py` for ~20 additional endpoints. diff --git a/openspace/local_server/__init__.py b/openspace/local_server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9850657186eaee2c2f34bbbaa9d14df3796618d8 --- /dev/null +++ b/openspace/local_server/__init__.py @@ -0,0 +1,3 @@ +from .main import app, run_server + +__all__ = ["app", "run_server"] \ No newline at end of file diff --git a/openspace/local_server/config.json b/openspace/local_server/config.json new file mode 100644 index 0000000000000000000000000000000000000000..4961b1f9851d5569d1020e2318b83039225d7724 --- /dev/null +++ b/openspace/local_server/config.json @@ -0,0 +1,8 @@ +{ + "server": { + "host": "127.0.0.1", + "port": 5000, + "debug": false, + "threaded": true + } +} diff --git a/openspace/local_server/feature_checker.py b/openspace/local_server/feature_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..f86a3d8bfdeca83e7fbedba5054472849bc16914 --- /dev/null +++ b/openspace/local_server/feature_checker.py @@ -0,0 +1,261 @@ +import platform +import subprocess +import tempfile +from typing import Dict, Any + +from openspace.utils.logging import Logger +logger = Logger.get_logger(__name__) + +platform_name = platform.system() + + +class FeatureChecker: + def __init__(self, platform_adapter=None, accessibility_helper=None): + self.platform_adapter = platform_adapter + self.accessibility_helper = accessibility_helper + self.platform = platform_name + self._cache = {} + + def check_screenshot_available(self, use_cache: bool = True) -> bool: + if use_cache and 'screenshot' in self._cache: + return self._cache['screenshot'] + + try: + import pyautogui + from PIL import Image + + size = pyautogui.size() + result = size.width > 0 and size.height > 0 + + self._cache['screenshot'] = result + logger.info(f"Screenshot check: {'available' if result else 'unavailable'}") + return result + + except ImportError as e: + logger.warning(f"Screenshot unavailable - missing dependency: {e}") + self._cache['screenshot'] = False + return False + except Exception as e: + logger.error(f"Screenshot check failed: {e}") + self._cache['screenshot'] = False + return False + + def check_shell_available(self, use_cache: bool = True) -> bool: + if use_cache and 'shell' in self._cache: + return self._cache['shell'] + + try: + if self.platform == "Windows": + cmd = ['cmd', '/c', 'echo', 'test'] + else: + cmd = ['echo', 'test'] + + result = subprocess.run( + cmd, + capture_output=True, + timeout=2, + text=True + ) + + available = result.returncode == 0 + self._cache['shell'] = available + logger.info(f"Shell check: {'available' if available else 'unavailable'}") + return available + + except FileNotFoundError as e: + logger.warning(f"Shell check failed - command not found: {e}") + self._cache['shell'] = False + return False + except Exception as e: + logger.error(f"Shell check failed: {e}") + self._cache['shell'] = False + return False + + def check_python_available(self, use_cache: bool = True) -> bool: + if use_cache and 'python' in self._cache: + return self._cache['python'] + + python_commands = [] + if self.platform == "Windows": + python_commands = ['py', 'python', 'python3'] + else: + python_commands = ['python3', 'python'] + + for python_cmd in python_commands: + try: + result = subprocess.run( + [python_cmd, '--version'], + capture_output=True, + timeout=2, + text=True + ) + + if result.returncode == 0: + version = result.stdout.strip() or result.stderr.strip() + self._cache['python'] = True + logger.info(f"Python check: available ({python_cmd} - {version})") + return True + + except FileNotFoundError: + continue + except Exception as e: + logger.debug(f"Error testing {python_cmd}: {e}") + continue + + logger.warning("Python check failed - no valid Python interpreter found") + self._cache['python'] = False + return False + + def check_file_ops_available(self, use_cache: bool = True) -> bool: + if use_cache and 'file_ops' in self._cache: + return self._cache['file_ops'] + + try: + with tempfile.NamedTemporaryFile(mode='w+b', delete=True) as tmp: + test_data = b'test data' + tmp.write(test_data) + tmp.flush() + + tmp.seek(0) + read_data = tmp.read() + + available = read_data == test_data + self._cache['file_ops'] = available + logger.info(f"File operations check: {'available' if available else 'unavailable'}") + return available + + except PermissionError as e: + logger.warning(f"File operations check failed - permission denied: {e}") + self._cache['file_ops'] = False + return False + except Exception as e: + logger.error(f"File operations check failed: {e}") + self._cache['file_ops'] = False + return False + + def check_window_mgmt_available(self, use_cache: bool = True) -> bool: + if use_cache and 'window_mgmt' in self._cache: + return self._cache['window_mgmt'] + + try: + if not self.platform_adapter: + logger.warning("Window management check failed - no platform adapter loaded") + self._cache['window_mgmt'] = False + return False + + required_methods = ['activate_window', 'close_window', 'list_windows'] + available_methods = [ + method for method in required_methods + if hasattr(self.platform_adapter, method) + ] + + available = len(available_methods) > 0 + self._cache['window_mgmt'] = available + + if available: + logger.info(f"Window management check: {'available' if available else 'unavailable'} - supported methods: {', '.join(available_methods)}") + else: + logger.warning(f"Window management check failed - platform adapter missing required methods") + + return available + + except Exception as e: + logger.error(f"Window management check failed: {e}") + self._cache['window_mgmt'] = False + return False + + def check_recording_available(self, use_cache: bool = True) -> bool: + if use_cache and 'recording' in self._cache: + return self._cache['recording'] + + try: + if not self.platform_adapter: + logger.warning("Recording check failed - no platform adapter loaded") + self._cache['recording'] = False + return False + + available = ( + hasattr(self.platform_adapter, 'start_recording') and + hasattr(self.platform_adapter, 'stop_recording') + ) + + self._cache['recording'] = available + logger.info(f"Recording check: {'available' if available else 'unavailable'}") + return available + + except Exception as e: + logger.error(f"Recording check failed: {e}") + self._cache['recording'] = False + return False + + def check_accessibility_available(self, use_cache: bool = True) -> bool: + if use_cache and 'accessibility' in self._cache: + return self._cache['accessibility'] + + try: + if not self.accessibility_helper: + logger.warning("Accessibility check failed - no accessibility helper loaded") + self._cache['accessibility'] = False + return False + + available = self.accessibility_helper.is_available() + self._cache['accessibility'] = available + logger.info(f"Accessibility check: {'available' if available else 'unavailable'}") + return available + + except Exception as e: + logger.error(f"Accessibility check failed: {e}") + self._cache['accessibility'] = False + return False + + def check_platform_adapter_available(self, use_cache: bool = True) -> bool: + if use_cache and 'platform_adapter' in self._cache: + return self._cache['platform_adapter'] + + available = self.platform_adapter is not None + self._cache['platform_adapter'] = available + logger.info(f"Platform adapter check: {'available' if available else 'unavailable'}") + return available + + def check_all_features(self, use_cache: bool = True) -> Dict[str, bool]: + logger.info(f"Checking all features (platform: {self.platform})") + + results = { + 'accessibility': self.check_accessibility_available(use_cache), + 'screenshot': self.check_screenshot_available(use_cache), + 'recording': self.check_recording_available(use_cache), + 'shell': self.check_shell_available(use_cache), + 'python': self.check_python_available(use_cache), + 'file_ops': self.check_file_ops_available(use_cache), + 'window_mgmt': self.check_window_mgmt_available(use_cache), + 'platform_adapter': self.check_platform_adapter_available(use_cache), + } + + available_count = sum(1 for v in results.values() if v) + total_count = len(results) + logger.info(f"Feature check completed: {available_count}/{total_count} features available") + + return results + + def clear_cache(self): + self._cache.clear() + logger.debug("Feature check cache cleared") + + def get_feature_report(self) -> Dict[str, Any]: + results = self.check_all_features() + + return { + 'platform': { + 'system': self.platform, + 'release': platform.release(), + 'version': platform.version(), + 'machine': platform.machine(), + 'processor': platform.processor(), + }, + 'features': results, + 'summary': { + 'total': len(results), + 'available': sum(1 for v in results.values() if v), + 'unavailable': sum(1 for v in results.values() if not v), + } + } \ No newline at end of file diff --git a/openspace/local_server/health_checker.py b/openspace/local_server/health_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..87e11474ad4b895493bd5476e95f7d8518ac9562 --- /dev/null +++ b/openspace/local_server/health_checker.py @@ -0,0 +1,586 @@ +import requests +import os +from pathlib import Path +from typing import Dict, Tuple, Optional +from openspace.utils.logging import Logger +from openspace.local_server.feature_checker import FeatureChecker + +logger = Logger.get_logger(__name__) + +from openspace.utils.display import colorize as _c + + +class HealthStatus: + """Health status""" + def __init__(self, feature_available: bool, endpoint_available: Optional[bool], + endpoint_detail: str = ""): + self.feature_available = feature_available + self.endpoint_available = endpoint_available + self.endpoint_detail = endpoint_detail + + @property + def fully_available(self) -> bool: + """Fully available: feature and endpoint are available""" + return self.feature_available and (self.endpoint_available == True) + + def __str__(self): + if not self.feature_available: + return "Feature N/A" + elif self.endpoint_available is None: + return "Feature OK (endpoint not tested)" + elif self.endpoint_available: + return f"OK ({self.endpoint_detail})" + else: + return f"Endpoint failed: {self.endpoint_detail}" + + +class HealthChecker: + """Health checker with functional testing""" + + def __init__(self, feature_checker: FeatureChecker, + base_url: str = "http://127.0.0.1:5000", + auto_cleanup: bool = True, + test_output_dir: str = None): + self.feature_checker = feature_checker + self.base_url = base_url + self.results = {} + self.auto_cleanup = auto_cleanup + + # set the test output directory + if test_output_dir: + self.test_output_dir = Path(test_output_dir) + else: + current_dir = Path(__file__).parent + self.test_output_dir = current_dir / "temp" + + # create the directory + self.test_output_dir.mkdir(exist_ok=True) + + self.temp_files = [] # Track temporary files for cleanup + + logger.info(f"Health checker initialized. Test output: {self.test_output_dir}, Auto-cleanup: {auto_cleanup}") + + def _get_test_file_path(self, filename: str) -> str: + """Get path for a test file""" + filepath = str(self.test_output_dir / filename) + self._register_temp_file(filepath) + return filepath + + def _register_temp_file(self, filepath: str): + """Register a temporary file for later cleanup""" + if filepath and filepath not in self.temp_files: + self.temp_files.append(filepath) + + def cleanup_temp_files(self): + """Clean up all temporary test files""" + if not self.auto_cleanup: + logger.info(f"Auto-cleanup disabled. Test files kept in: {self.test_output_dir}") + return + + cleaned = 0 + for filepath in self.temp_files: + try: + if os.path.exists(filepath): + os.remove(filepath) + cleaned += 1 + logger.debug(f"Cleaned up: {filepath}") + except Exception as e: + logger.warning(f"Failed to clean up {filepath}: {e}") + + self.temp_files.clear() + + # if the directory is empty, delete it + try: + if self.test_output_dir.exists() and not any(self.test_output_dir.iterdir()): + self.test_output_dir.rmdir() + logger.debug(f"Removed empty directory: {self.test_output_dir}") + except: + pass + + if cleaned > 0: + logger.info(f"Cleaned up {cleaned} test files") + + def check_screenshot(self) -> Tuple[bool, str]: + """Functionally test screenshot - actually take a screenshot and verify""" + # 1. Check feature first + if not self.feature_checker.check_screenshot_available(): + return False, "Feature N/A" + + # 2. Save screenshot to test directory + screenshot_path = self._get_test_file_path("test_screenshot.png") + + try: + response = requests.get(f"{self.base_url}/screenshot", timeout=10) + + if response.status_code != 200: + return False, f"HTTP {response.status_code}" + + # 3. Save to file + with open(screenshot_path, 'wb') as f: + f.write(response.content) + + # 4. Verify it's actually an image + content_type = response.headers.get('Content-Type', '') + if 'image' not in content_type: + return False, f"Invalid content type: {content_type}" + + # 5. Check file size (should be > 1KB) + size_kb = len(response.content) / 1024 + if size_kb < 1: + return False, "Image too small" + + logger.info(f"Screenshot saved: {screenshot_path} ({size_kb:.1f}KB)") + return True, f"OK ({size_kb:.1f}KB)" + + except requests.exceptions.Timeout: + return False, "Timeout" + except Exception as e: + return False, f"Error: {str(e)[:30]}" + + def check_cursor_position(self) -> Tuple[bool, str]: + """Test cursor position""" + if not self.feature_checker.check_screenshot_available(): + return False, "Feature N/A" + + try: + response = requests.get(f"{self.base_url}/cursor_position", timeout=5) + if response.status_code != 200: + return False, f"HTTP {response.status_code}" + + data = response.json() + if 'x' in data and 'y' in data: + return True, f"({data['x']}, {data['y']})" + return False, "Invalid response" + except Exception as e: + return False, str(e)[:30] + + def check_screen_size(self) -> Tuple[bool, str]: + """Test screen size""" + if not self.feature_checker.check_screenshot_available(): + return False, "Feature N/A" + + try: + response = requests.get(f"{self.base_url}/screen_size", timeout=5) + if response.status_code != 200: + return False, f"HTTP {response.status_code}" + + data = response.json() + if 'width' in data and 'height' in data: + return True, f"{data['width']}x{data['height']}" + return False, "Invalid response" + except Exception as e: + return False, str(e)[:30] + + def check_shell_command(self) -> Tuple[bool, str]: + """Functionally test shell command execution""" + if not self.feature_checker.check_shell_available(): + return False, "Feature N/A" + + try: + response = requests.post( + f"{self.base_url}/execute", + json={"command": "echo hello_test", "shell": True}, + timeout=5 + ) + + if response.status_code != 200: + return False, f"HTTP {response.status_code}" + + data = response.json() + output = data.get('output', '').strip() + + # Verify the command actually executed + if 'hello_test' in output: + return True, "Command executed" + return False, "Command failed" + + except Exception as e: + return False, str(e)[:30] + + def check_python_execution(self) -> Tuple[bool, str]: + """Functionally test Python code execution""" + if not self.feature_checker.check_python_available(): + return False, "Feature N/A" + + try: + test_code = 'print("test_output_123")' + response = requests.post( + f"{self.base_url}/run_python", + json={"code": test_code}, + timeout=5 + ) + + if response.status_code != 200: + return False, f"HTTP {response.status_code}" + + data = response.json() + content = data.get('content', '') + + # Verify Python executed correctly + if 'test_output_123' in content: + return True, "Python executed" + return False, "Execution failed" + + except Exception as e: + return False, str(e)[:30] + + def check_bash_script(self) -> Tuple[bool, str]: + """Functionally test Bash script execution""" + if not self.feature_checker.check_shell_available(): + return False, "Feature N/A" + + try: + response = requests.post( + f"{self.base_url}/run_bash_script", + json={"script": "echo bash_test_456"}, + timeout=5 + ) + + if response.status_code != 200: + return False, f"HTTP {response.status_code}" + + data = response.json() + output = data.get('output', '') + + if 'bash_test_456' in output: + return True, "Bash executed" + return False, "Execution failed" + + except Exception as e: + return False, str(e)[:30] + + def check_file_operations(self) -> Tuple[bool, str]: + """Test file operations""" + if not self.feature_checker.check_file_ops_available(): + return False, "Feature N/A" + + try: + # Test list directory + response = requests.post( + f"{self.base_url}/list_directory", + json={"path": "."}, + timeout=5 + ) + + if response.status_code != 200: + return False, f"HTTP {response.status_code}" + + data = response.json() + if 'items' in data and isinstance(data['items'], list): + return True, f"{len(data['items'])} items" + return False, "Invalid response" + + except Exception as e: + return False, str(e)[:30] + + def check_desktop_path(self) -> Tuple[bool, str]: + """Test desktop path""" + if not self.feature_checker.check_file_ops_available(): + return False, "Feature N/A" + + try: + response = requests.get(f"{self.base_url}/desktop_path", timeout=5) + if response.status_code != 200: + return False, f"HTTP {response.status_code}" + + data = response.json() + path = data.get('path', '') + if path and os.path.exists(path): + return True, "Path valid" + return False, "Path not found" + except Exception as e: + return False, str(e)[:30] + + def check_window_management(self) -> Tuple[bool, str]: + """Test window management""" + if not self.feature_checker.check_window_mgmt_available(): + return False, "Feature N/A" + + try: + # Just test if endpoint responds (window may not exist) + response = requests.post( + f"{self.base_url}/setup/activate_window", + json={"window_name": "NonExistentWindow"}, + timeout=5 + ) + + # 200 (success), 404 (not found), 501 (not supported) are all acceptable + if response.status_code in [200, 404, 501]: + return True, f"API available" + return False, f"HTTP {response.status_code}" + except Exception as e: + return False, str(e)[:30] + + def check_recording(self) -> Tuple[bool, str]: + """Functionally test recording - actually start and stop recording""" + if not self.feature_checker.check_recording_available(): + return False, "Feature N/A" + + recording_path = self._get_test_file_path("test_recording.mp4") + + try: + # 1. Start recording + response = requests.post(f"{self.base_url}/start_recording", json={}, timeout=10) + + if response.status_code == 501: + return False, "Not supported" + + if response.status_code != 200: + return False, f"Start failed: {response.status_code}" + + # 2. Wait a bit + import time + time.sleep(3.0) # Record for 3 seconds + + # 3. Stop recording + response = requests.post(f"{self.base_url}/end_recording", json={}, timeout=15) + + if response.status_code == 200: + # Save the recording file + with open(recording_path, 'wb') as f: + f.write(response.content) + + size_kb = len(response.content) / 1024 + logger.info(f"Recording saved: {recording_path} ({size_kb:.1f}KB)") + return True, f"OK ({size_kb:.1f}KB)" + else: + return False, f"Stop failed: {response.status_code}" + + except Exception as e: + # Try to stop recording in case of error + try: + requests.post(f"{self.base_url}/end_recording", json={}, timeout=5) + except: + pass + return False, str(e)[:30] + + def check_accessibility(self) -> Tuple[bool, str]: + """Test accessibility tree""" + if not self.feature_checker.check_accessibility_available(): + return False, "Feature N/A" + + try: + response = requests.get(f"{self.base_url}/accessibility?max_depth=1", timeout=10) + + if response.status_code != 200: + return False, f"HTTP {response.status_code}" + + data = response.json() + if 'error' in data: + return False, "Permission denied" + + # Should have some tree structure + if 'platform' in data or 'children' in data: + return True, "Tree available" + return False, "Invalid response" + + except Exception as e: + return False, str(e)[:30] + + def check_health_endpoint(self) -> Tuple[bool, str]: + """Test health check endpoint""" + try: + response = requests.get(f"{self.base_url}/", timeout=5) + if response.status_code == 200: + data = response.json() + if data.get('status') == 'ok': + return True, "OK" + return False, f"HTTP {response.status_code}" + except Exception as e: + return False, str(e)[:30] + + def check_platform_info(self) -> Tuple[bool, str]: + """Test platform info endpoint""" + try: + response = requests.get(f"{self.base_url}/platform", timeout=5) + if response.status_code == 200: + data = response.json() + if 'system' in data: + return True, data['system'] + return False, f"HTTP {response.status_code}" + except Exception as e: + return False, str(e)[:30] + + def check_all(self, test_endpoints: bool = True) -> Dict[str, HealthStatus]: + """ + Check all features with functional testing + + Args: + test_endpoints: Whether to test endpoints (False only checks features) + + Returns: + {Feature name: HealthStatus} + """ + results = {} + + if not test_endpoints: + # Only check features, not endpoints + feature_results = self.feature_checker.check_all_features() + for name, available in feature_results.items(): + results[name] = HealthStatus(available, None, "") + self.results = results + return results + + # Functional tests + test_functions = { + 'Health Check': self.check_health_endpoint, + 'Platform Info': self.check_platform_info, + 'Screenshot': self.check_screenshot, + 'Cursor Position': self.check_cursor_position, + 'Screen Size': self.check_screen_size, + 'Shell Command': self.check_shell_command, + 'Python Execution': self.check_python_execution, + 'Bash Script': self.check_bash_script, + 'File Operations': self.check_file_operations, + 'Desktop Path': self.check_desktop_path, + 'Window Management': self.check_window_management, + 'Recording': self.check_recording, + 'Accessibility': self.check_accessibility, + } + + for name, test_func in test_functions.items(): + success, detail = test_func() + + # Determine feature availability + if detail == "Feature N/A": + feature_available = False + endpoint_available = None + else: + feature_available = True + endpoint_available = success + + results[name] = HealthStatus(feature_available, endpoint_available, detail) + + # Clean up temporary files + self.cleanup_temp_files() + + self.results = results + return results + + def print_results(self, results: Dict[str, HealthStatus] = None, + show_endpoint_details: bool = False): + """Print check results""" + if results is None: + results = self.results + + if not results: + return + + total = len(results) + feature_available = sum(1 for s in results.values() if s.feature_available) + fully_available = sum(1 for s in results.values() if s.fully_available) + + # Categorize + basic = ['Health Check', 'Platform Info'] + + # Basic Features + print() + print(_c(" - Basic", 'c', bold=True)) + basic_items = [] + for name in basic: + if name in results: + status = results[name] + # Use colored dot instead of emoji + if status.fully_available: + icon = _c("●", 'g') + elif not status.feature_available: + icon = _c("●", 'rd') + elif status.endpoint_available is None: + icon = _c("●", 'y') + else: + icon = _c("●", 'y') + + text = _c(name, 'gr' if not status.feature_available else '') + basic_items.append((icon, text, status)) + + # Display in rows of 4 + for i in range(0, len(basic_items), 4): + line_items = [] + for j in range(4): + if i + j < len(basic_items): + icon, text, status = basic_items[i + j] + line_items.append(f"{icon} {text:<15}") + print(" " + " ".join(line_items)) + + # Show details if requested + if show_endpoint_details: + for name in basic: + if name in results: + status = results[name] + print(f" {_c('·', 'gr')} {name}: {_c(str(status), 'gr')}") + + # Advanced Features + print() + print(_c(" - Advanced", 'c', bold=True)) + advanced_items = [] + for name, status in results.items(): + if name not in basic: + # Use colored dot instead of emoji + if status.fully_available: + icon = _c("●", 'g') + elif not status.feature_available: + icon = _c("●", 'rd') + elif status.endpoint_available is None: + icon = _c("●", 'y') + else: + icon = _c("●", 'y') + + text = _c(name, 'gr' if not status.feature_available else '') + advanced_items.append((icon, text, status)) + + # Display in rows of 4 + for i in range(0, len(advanced_items), 4): + line_items = [] + for j in range(4): + if i + j < len(advanced_items): + icon, text, _ = advanced_items[i + j] + line_items.append(f"{icon} {text:<15}") + print(" " + " ".join(line_items)) + + # Show details if requested + if show_endpoint_details: + for name, status in results.items(): + if name not in basic: + print(f" {_c('·', 'gr')} {name}: {_c(str(status), 'gr')}") + + # Summary + from openspace.utils.display import print_separator + print() + print_separator() + print(f" {_c('Summary:', 'c', bold=True)} {_c(str(feature_available) + '/' + str(total), 'g' if feature_available == total else 'y')} features available", end='') + if any(s.endpoint_available is not None for s in results.values()): + print(f", {_c(str(fully_available) + '/' + str(total), 'g' if fully_available == total else 'y')} fully functional") + else: + print() + print_separator() + + # Legend + print(f" {_c('Legend:', 'gr')} {_c('●', 'g')} Available {_c('●', 'y')} Partial/Untested {_c('●', 'rd')} Unavailable") + + # Test files info + if self.temp_files and not self.auto_cleanup: + print() + print(f" {_c('Test files saved:', 'y')} {self.test_output_dir}") + print(f" {_c(str(len(self.temp_files)) + ' file(s) available for inspection', 'gr')}") + + print() + + def get_summary(self) -> dict: + """Get summary""" + if not self.results: + return {} + + total = len(self.results) + feature_available = sum(1 for s in self.results.values() if s.feature_available) + fully_available = sum(1 for s in self.results.values() if s.fully_available) + + return { + 'total': total, + 'feature_available': feature_available, + 'fully_available': fully_available, + 'details': {k: str(v) for k, v in self.results.items()} + } + + def get_simple_features_dict(self) -> Dict[str, bool]: + """Get simple feature dict (for banner display)""" + return self.feature_checker.check_all_features() \ No newline at end of file diff --git a/openspace/local_server/main.py b/openspace/local_server/main.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec6f45ee54aaf4e5652df7a4257742113819666 --- /dev/null +++ b/openspace/local_server/main.py @@ -0,0 +1,1152 @@ +import os +import platform +import shlex +import subprocess +import signal +import time +import json +import uuid +from datetime import datetime +from flask import Flask, request, jsonify, send_file, abort +import pyautogui +import threading +from io import BytesIO +import tempfile + +from openspace.utils.logging import Logger +from openspace.local_server.utils import AccessibilityHelper, ScreenshotHelper +from openspace.local_server.platform_adapters import get_platform_adapter +from openspace.local_server.health_checker import HealthChecker +from openspace.local_server.feature_checker import FeatureChecker + +platform_name = platform.system() + +app = Flask(__name__) +app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024 # 500MB + +pyautogui.PAUSE = 0 +if platform_name == "Darwin": + pyautogui.DARWIN_CATCH_UP_TIME = 0 + +logger = Logger.get_logger(__name__) + +TIMEOUT = 1800 +recording_process = None + +if platform_name == "Windows": + recording_path = os.path.join(os.environ.get('TEMP', 'C:\\Temp'), 'recording.mp4') +else: + recording_path = "/tmp/recording.mp4" + +accessibility_helper = AccessibilityHelper() +screenshot_helper = ScreenshotHelper() +platform_adapter = get_platform_adapter() + +feature_checker = FeatureChecker( + platform_adapter=platform_adapter, + accessibility_helper=accessibility_helper +) + + +def get_conda_activation_prefix(conda_env: str = None) -> str: + """ + Generate platform-specific conda activation command prefix + + Args: + conda_env: Conda environment name (e.g., 'myenv') + + Returns: + Activation command prefix string, empty if no conda_env + """ + if not conda_env: + return "" + + if platform_name == "Windows": + # Windows: use conda.bat or conda.exe + # Try common conda installation paths + conda_paths = [ + os.path.expandvars("%USERPROFILE%\\miniconda3\\Scripts\\activate.bat"), + os.path.expandvars("%USERPROFILE%\\anaconda3\\Scripts\\activate.bat"), + "C:\\ProgramData\\Miniconda3\\Scripts\\activate.bat", + "C:\\ProgramData\\Anaconda3\\Scripts\\activate.bat", + ] + + # Find first existing conda activate script + activate_script = None + for path in conda_paths: + if os.path.exists(path): + activate_script = path + break + + if activate_script: + return f'call "{activate_script}" {conda_env} && ' + else: + # Fallback: assume conda is in PATH + return f'conda activate {conda_env} && ' + + else: + # Linux/macOS: source conda.sh then activate + conda_paths = [ + os.path.expanduser("~/miniconda3/etc/profile.d/conda.sh"), + os.path.expanduser("~/anaconda3/etc/profile.d/conda.sh"), + "/opt/conda/etc/profile.d/conda.sh", + "/usr/local/miniconda3/etc/profile.d/conda.sh", + "/usr/local/anaconda3/etc/profile.d/conda.sh", + ] + + # Find first existing conda.sh + conda_sh = None + for path in conda_paths: + if os.path.exists(path): + conda_sh = path + break + + if conda_sh: + return f'source "{conda_sh}" && conda activate {conda_env} && ' + else: + # Fallback: assume conda is already initialized in shell + return f'conda activate {conda_env} && ' + + +def wrap_script_with_conda(script: str, conda_env: str = None) -> str: + """ + Wrap script with conda activation command. + If conda is not available, returns original script without conda activation. + """ + if not conda_env: + return script + + if platform_name == "Windows": + activation_prefix = get_conda_activation_prefix(conda_env) + return f"{activation_prefix}{script}" + else: + conda_paths = [ + os.path.expanduser("~/miniconda3/etc/profile.d/conda.sh"), + os.path.expanduser("~/anaconda3/etc/profile.d/conda.sh"), + os.path.expanduser("~/opt/anaconda3/etc/profile.d/conda.sh"), + "/opt/conda/etc/profile.d/conda.sh", + ] + + conda_sh = None + for path in conda_paths: + if os.path.exists(path): + conda_sh = path + break + + if conda_sh: + # Use bash -i -c to run interactively, or directly source conda.sh + wrapped_script = f"""#!/bin/bash +# Initialize conda +if [ -f "{conda_sh}" ]; then + . "{conda_sh}" + conda activate {conda_env} 2>/dev/null || true +fi + +# Run user script +{script} +""" + return wrapped_script + else: + # Conda not found - log warning and execute script directly without conda + logger.warning(f"Conda environment '{conda_env}' requested but conda not found. Executing with system Python.") + return script + + +health_checker = None + +@app.route('/', methods=['GET']) +def health_check(): + """Health check interface - return features information""" + # Get features from health_checker + if health_checker: + features = health_checker.get_simple_features_dict() + else: + # Initial startup of health_checker may not have been initialized, fallback to feature_checker + features = feature_checker.check_all_features(use_cache=True) + + return jsonify({ + 'status': 'ok', + 'service': 'OpenSpace Desktop Server', + 'version': '1.0.0', + 'platform': platform_name, + 'features': features, + 'timestamp': datetime.now().isoformat() + }) + +@app.route('/platform', methods=['GET']) +def get_platform(): + info = { + 'system': platform_name, + 'release': platform.release(), + 'version': platform.version(), + 'machine': platform.machine(), + 'processor': platform.processor() + } + + if platform_adapter and hasattr(platform_adapter, 'get_system_info'): + info.update(platform_adapter.get_system_info()) + + return jsonify(info) + +@app.route('/execute', methods=['POST']) +@app.route('/setup/execute', methods=['POST']) +def execute_command(): + data = request.json + # The 'command' key in the JSON request should contain the command to be executed. + shell = data.get('shell', False) + command = data.get('command', "" if shell else []) + timeout = data.get('timeout', 120) + + if isinstance(command, str) and not shell: + command = shlex.split(command) + + # Expand user directory + if isinstance(command, list): + for i, arg in enumerate(command): + if arg.startswith("~/"): + command[i] = os.path.expanduser(arg) + + try: + if platform_name == "Windows": + result = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=shell, + text=True, + timeout=timeout, + creationflags=subprocess.CREATE_NO_WINDOW, + ) + else: + result = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=shell, + text=True, + timeout=timeout, + ) + + return jsonify({ + 'status': 'success', + 'output': result.stdout, + 'error': result.stderr, + 'returncode': result.returncode + }) + except subprocess.TimeoutExpired: + return jsonify({ + 'status': 'error', + 'message': f'Command timeout after {timeout} seconds' + }), 408 + except Exception as e: + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + +@app.route('/execute_with_verification', methods=['POST']) +@app.route('/setup/execute_with_verification', methods=['POST']) +def execute_command_with_verification(): + """Execute command and verify the result based on provided verification criteria""" + data = request.json + shell = data.get('shell', False) + command = data.get('command', "" if shell else []) + verification = data.get('verification', {}) + max_wait_time = data.get('max_wait_time', 10) # Maximum wait time in seconds + check_interval = data.get('check_interval', 1) # Check interval in seconds + + if isinstance(command, str) and not shell: + command = shlex.split(command) + + # Expand user directory + if isinstance(command, list): + for i, arg in enumerate(command): + if arg.startswith("~/"): + command[i] = os.path.expanduser(arg) + + # Execute the main command + try: + if platform_name == "Windows": + result = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=shell, + text=True, + timeout=120, + creationflags=subprocess.CREATE_NO_WINDOW, + ) + else: + result = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=shell, + text=True, + timeout=120, + ) + + # If no verification is needed, return immediately + if not verification: + return jsonify({ + 'status': 'success', + 'output': result.stdout, + 'error': result.stderr, + 'returncode': result.returncode + }) + + # Wait and verify the result + start_time = time.time() + while time.time() - start_time < max_wait_time: + verification_passed = True + + # Check window existence if specified + if 'window_exists' in verification: + window_name = verification['window_exists'] + try: + if platform_name == 'Linux': + wmctrl_result = subprocess.run( + ['wmctrl', '-l'], + capture_output=True, + text=True, + check=True + ) + if window_name.lower() not in wmctrl_result.stdout.lower(): + verification_passed = False + elif platform_adapter: + # Use platform adapter to check window existence + windows = platform_adapter.list_windows() if hasattr(platform_adapter, 'list_windows') else [] + if not any(window_name.lower() in str(w).lower() for w in windows): + verification_passed = False + except: + verification_passed = False + + # Check command execution if specified + if 'command_success' in verification: + verify_cmd = verification['command_success'] + try: + verify_result = subprocess.run( + verify_cmd, + shell=True, + capture_output=True, + text=True, + timeout=5 + ) + if verify_result.returncode != 0: + verification_passed = False + except: + verification_passed = False + + if verification_passed: + return jsonify({ + 'status': 'success', + 'output': result.stdout, + 'error': result.stderr, + 'returncode': result.returncode, + 'verification': 'passed', + 'wait_time': time.time() - start_time + }) + + time.sleep(check_interval) + + # Verification failed + return jsonify({ + 'status': 'verification_failed', + 'output': result.stdout, + 'error': result.stderr, + 'returncode': result.returncode, + 'verification': 'failed', + 'wait_time': max_wait_time + }), 500 + + except Exception as e: + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + +def _get_machine_architecture() -> str: + """Get the machine architecture, e.g., x86_64, arm64, aarch64, i386, etc. + Returns 'amd' for x86/AMD architectures, 'arm' for ARM architectures, or 'unknown'. + """ + architecture = platform.machine().lower() + if architecture in ['amd32', 'amd64', 'x86', 'x86_64', 'x86-64', 'x64', 'i386', 'i686']: + return 'amd' + elif architecture in ['arm64', 'aarch64', 'aarch32']: + return 'arm' + else: + return 'unknown' + +@app.route('/setup/launch', methods=["POST"]) +def launch_app(): + data = request.json + shell = data.get("shell", False) + command = data.get("command", "" if shell else []) + + if isinstance(command, str) and not shell: + command = shlex.split(command) + + # Expand user directory + if isinstance(command, list): + for i, arg in enumerate(command): + if arg.startswith("~/"): + command[i] = os.path.expanduser(arg) + + try: + # ARM architecture compatibility: replace google-chrome with chromium + # ARM64 Chrome is not available yet, can only use Chromium + if isinstance(command, list) and 'google-chrome' in command and _get_machine_architecture() == 'arm': + index = command.index('google-chrome') + command[index] = 'chromium' + logger.info("ARM architecture detected: replacing 'google-chrome' with 'chromium'") + + subprocess.Popen(command, shell=shell) + cmd_str = command if shell else " ".join(command) + logger.info(f"Application launched successfully: {cmd_str}") + return jsonify({ + 'status': 'success', + 'message': f'{cmd_str} launched successfully' + }) + except Exception as e: + logger.error(f"Application launch failed: {str(e)}") + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + +@app.route("/run_python", methods=['POST']) +def run_python(): + data = request.json + code = data.get('code', None) + timeout = data.get('timeout', 30) + working_dir = data.get('working_dir', None) + env = data.get('env', None) + conda_env = data.get('conda_env', None) + + if not code: + return jsonify({'status': 'error', 'message': 'Code not supplied!'}), 400 + + # Generate unique filename + if platform_name == "Windows": + temp_filename = os.path.join(tempfile.gettempdir(), f"python_exec_{uuid.uuid4().hex}.py") + else: + temp_filename = f"/tmp/python_exec_{uuid.uuid4().hex}.py" + + try: + with open(temp_filename, 'w') as f: + f.write(code) + + # Prepare environment variables + exec_env = os.environ.copy() + if env: + exec_env.update(env) + + # If conda_env is specified, try to use bash/cmd to activate and run + # If conda is not available, fall back to system Python + if conda_env: + activation_cmd = get_conda_activation_prefix(conda_env) + # Check if conda activation command is empty (conda not found) + if not activation_cmd: + logger.warning(f"Conda environment '{conda_env}' requested but conda not found. Using system Python.") + conda_env = None # Disable conda and use default path + + if conda_env and get_conda_activation_prefix(conda_env): + if platform_name == "Windows": + # Windows: use cmd with activation + activation_cmd = get_conda_activation_prefix(conda_env) + full_cmd = f'{activation_cmd}python "{temp_filename}"' + result = subprocess.run( + ['cmd', '/c', full_cmd], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=timeout, + cwd=working_dir or os.getcwd(), + env=exec_env + ) + else: + # Linux/macOS: use bash with activation + activation_cmd = get_conda_activation_prefix(conda_env) + full_cmd = f'{activation_cmd}python3 "{temp_filename}"' + result = subprocess.run( + ['/bin/bash', '-c', full_cmd], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=timeout, + cwd=working_dir or os.getcwd(), + env=exec_env + ) + else: + # No conda activation needed + python_cmd = 'python' if platform_name == "Windows" else 'python3' + result = subprocess.run( + [python_cmd, temp_filename], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=timeout, + cwd=working_dir or os.getcwd(), + env=exec_env + ) + + os.remove(temp_filename) + + output = result.stdout + result.stderr + + return jsonify({ + 'status': 'success' if result.returncode == 0 else 'error', + 'content': output or "Code executed successfully (no output)", + 'returncode': result.returncode + }) + + except subprocess.TimeoutExpired: + if os.path.exists(temp_filename): + os.remove(temp_filename) + return jsonify({ + 'status': 'error', + 'message': f'Execution timeout after {timeout} seconds' + }), 408 + except Exception as e: + if os.path.exists(temp_filename): + os.remove(temp_filename) + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + +@app.route("/run_bash_script", methods=['POST']) +def run_bash_script(): + data = request.json + script = data.get('script', None) + timeout = data.get('timeout', 30) + working_dir = data.get('working_dir', None) + env = data.get('env', None) + conda_env = data.get('conda_env', None) + + if not script: + return jsonify({'status': 'error', 'message': 'Script not supplied!'}), 400 + + # Generate unique filename + if platform_name == "Windows": + temp_filename = os.path.join(tempfile.gettempdir(), f"bash_exec_{uuid.uuid4().hex}.sh") + else: + temp_filename = f"/tmp/bash_exec_{uuid.uuid4().hex}.sh" + + try: + # Wrap script with conda activation if needed + final_script = wrap_script_with_conda(script, conda_env) + + with open(temp_filename, 'w') as f: + f.write(final_script) + + os.chmod(temp_filename, 0o755) + + if platform_name == "Windows": + shell_cmd = ['bash', temp_filename] + else: + shell_cmd = ['/bin/bash', temp_filename] + + # Prepare environment variables + exec_env = os.environ.copy() + if env: + exec_env.update(env) + + result = subprocess.run( + shell_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + timeout=timeout, + cwd=working_dir or os.getcwd(), + env=exec_env + ) + + os.unlink(temp_filename) + + return jsonify({ + 'status': 'success' if result.returncode == 0 else 'error', + 'output': result.stdout, + 'error': "", + 'returncode': result.returncode + }) + + except subprocess.TimeoutExpired: + if os.path.exists(temp_filename): + os.unlink(temp_filename) + return jsonify({ + 'status': 'error', + 'output': f'Script execution timed out after {timeout} seconds', + 'error': "", + 'returncode': -1 + }), 500 + except Exception as e: + if os.path.exists(temp_filename): + try: + os.unlink(temp_filename) + except: + pass + return jsonify({ + 'status': 'error', + 'output': f'Failed to execute script: {str(e)}', + 'error': "", + 'returncode': -1 + }), 500 + +@app.route('/screenshot', methods=['GET']) +def capture_screen_with_cursor(): + """Capture screenshot (including mouse cursor)""" + try: + buf = BytesIO() + tmp_path = os.path.join(tempfile.gettempdir(), f"screenshot_{uuid.uuid4().hex}.png") + if screenshot_helper.capture(tmp_path, with_cursor=True): + with open(tmp_path, 'rb') as f: + buf.write(f.read()) + os.remove(tmp_path) + buf.seek(0) + return send_file(buf, mimetype='image/png') + else: + return jsonify({'status':'error','message':'Screenshot failed'}), 500 + + except Exception as e: + logger.error(f"Screenshot failed: {str(e)}") + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + +@app.route('/cursor_position', methods=['GET']) +def get_cursor_position(): + """Get cursor position""" + try: + x, y = screenshot_helper.get_cursor_position() + return jsonify({'x': x, 'y': y, 'status': 'success'}) + except Exception as e: + return jsonify({'status': 'error', 'message': str(e)}), 500 + +@app.route('/screen_size', methods=['POST', 'GET']) +def get_screen_size(): + """Get screen size""" + try: + width, height = screenshot_helper.get_screen_size() + return jsonify({'width': width, 'height': height, 'status': 'success'}) + except Exception as e: + return jsonify({'status': 'error', 'message': str(e)}), 500 + +# Accessibility Tree +@app.route("/accessibility", methods=["GET"]) +def get_accessibility_tree(): + """Get accessibility tree""" + try: + max_depth = request.args.get('max_depth', 10, type=int) + tree = accessibility_helper.get_tree(max_depth=max_depth) + return jsonify(tree) + except Exception as e: + logger.error(f"Failed to get accessibility tree: {str(e)}") + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + +# File Operations +@app.route('/list_directory', methods=['POST']) +def list_directory(): + """List directory contents""" + data = request.json + path = data.get('path', '.') + + try: + path = os.path.expanduser(path) + items = [] + + for item in os.listdir(path): + item_path = os.path.join(path, item) + items.append({ + 'name': item, + 'is_dir': os.path.isdir(item_path), + 'is_file': os.path.isfile(item_path), + 'size': os.path.getsize(item_path) if os.path.isfile(item_path) else None + }) + + return jsonify({ + 'status': 'success', + 'path': path, + 'items': items + }) + except Exception as e: + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + +@app.route('/file', methods=['POST']) +def file_operation(): + """File operations""" + data = request.json + operation = data.get('operation', 'read') + path = data.get('path') + + if not path: + return jsonify({'status': 'error', 'message': 'Path required'}), 400 + + path = os.path.expanduser(path) + + try: + if operation == 'read': + with open(path, 'r') as f: + content = f.read() + return jsonify({ + 'status': 'success', + 'content': content + }) + elif operation == 'exists': + exists = os.path.exists(path) + return jsonify({ + 'status': 'success', + 'exists': exists + }) + else: + return jsonify({ + 'status': 'error', + 'message': f'Unknown operation: {operation}' + }), 400 + except Exception as e: + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + +@app.route('/desktop_path', methods=['POST', 'GET']) +def get_desktop_path(): + """Get desktop path""" + try: + desktop = os.path.expanduser("~/Desktop") + return jsonify({ + 'status': 'success', + 'path': desktop + }) + except Exception as e: + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + +@app.route("/setup/activate_window", methods=['POST']) +def activate_window(): + """Activate window""" + data = request.json + window_name = data.get("window_name") + strict = data.get("strict", False) + by_class_name = data.get("by_class", False) + + if not window_name: + return jsonify({'status': 'error', 'message': 'window_name required'}), 400 + + try: + if platform_adapter and hasattr(platform_adapter, 'activate_window'): + result = platform_adapter.activate_window(window_name, strict=strict) + if result['status'] == 'success': + return jsonify(result) + else: + return jsonify(result), 400 + else: + return jsonify({ + 'status': 'error', + 'message': f'Window activation not supported on {platform_name}' + }), 501 + except Exception as e: + logger.error(f"Window activation failed: {str(e)}") + return jsonify({'status': 'error', 'message': str(e)}), 500 + +@app.route("/setup/close_window", methods=["POST"]) +def close_window(): + """Close window""" + data = request.json + window_name = data.get("window_name") + strict = data.get("strict", False) + by_class_name = data.get("by_class", False) + + if not window_name: + return jsonify({'status': 'error', 'message': 'window_name required'}), 400 + + try: + if platform_adapter and hasattr(platform_adapter, 'close_window'): + result = platform_adapter.close_window(window_name, strict=strict) + if result['status'] == 'success': + return jsonify(result) + else: + return jsonify(result), 404 + else: + return jsonify({ + 'status': 'error', + 'message': f'Window closing not supported on {platform_name}' + }), 501 + except Exception as e: + logger.error(f"Window closing failed: {str(e)}") + return jsonify({'status': 'error', 'message': str(e)}), 500 + +@app.route('/window_size', methods=['POST']) +def get_window_size(): + """Get window size""" + try: + width, height = screenshot_helper.get_screen_size() + return jsonify({ + 'status': 'success', + 'width': width, + 'height': height + }) + except Exception as e: + return jsonify({'status': 'error', 'message': str(e)}), 500 + +@app.route('/wallpaper', methods=['POST']) +@app.route('/setup/change_wallpaper', methods=['POST']) +def set_wallpaper(): + """Set wallpaper""" + data = request.json + image_path = data.get('path') + + if not image_path: + return jsonify({'status': 'error', 'message': 'path required'}), 400 + + try: + if platform_adapter and hasattr(platform_adapter, 'set_wallpaper'): + result = platform_adapter.set_wallpaper(image_path) + if result['status'] == 'success': + return jsonify(result) + else: + return jsonify(result), 400 + else: + return jsonify({ + 'status': 'error', + 'message': f'Wallpaper setting not supported on {platform_name}' + }), 501 + except Exception as e: + logger.error(f"Failed to set wallpaper: {str(e)}") + return jsonify({'status': 'error', 'message': str(e)}), 500 + +# Screen Recording +@app.route('/start_recording', methods=['POST']) +def start_recording(): + """Start screen recording (supports Linux, macOS, Windows)""" + global recording_process + + # Check if platform adapter supports recording + if not platform_adapter or not hasattr(platform_adapter, 'start_recording'): + return jsonify({ + 'status': 'error', + 'message': f'Recording not supported on {platform_name}' + }), 501 + + # Check if recording is already in progress + if recording_process and recording_process.poll() is None: + return jsonify({ + 'status': 'error', + 'message': 'Recording is already in progress.' + }), 400 + + # Clean up old recording file + if os.path.exists(recording_path): + try: + os.remove(recording_path) + except OSError as e: + logger.error(f"Cannot delete old recording file: {e}") + + try: + # Use platform adapter to start recording + result = platform_adapter.start_recording(recording_path) + + if result['status'] == 'success': + recording_process = result.get('process') + logger.info("Recording started successfully") + return jsonify({ + 'status': 'success', + 'message': 'Recording started' + }) + else: + logger.error(f"Failed to start recording: {result.get('message', 'Unknown error')}") + return jsonify({ + 'status': 'error', + 'message': result.get('message', 'Failed to start recording') + }), 500 + + except Exception as e: + logger.error(f"Failed to start recording: {str(e)}") + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + +@app.route('/end_recording', methods=['POST']) +def end_recording(): + """End screen recording (supports Linux, macOS, Windows)""" + global recording_process + + # Check if recording is in progress + if not recording_process or recording_process.poll() is not None: + recording_process = None + return jsonify({ + 'status': 'error', + 'message': 'No recording in progress' + }), 400 + + try: + # Use platform adapter to stop recording + if platform_adapter and hasattr(platform_adapter, 'stop_recording'): + result = platform_adapter.stop_recording(recording_process) + recording_process = None + + if result['status'] != 'success': + logger.error(f"Failed to stop recording: {result.get('message', 'Unknown error')}") + return jsonify(result), 500 + else: + # Fallback: terminate process directly + recording_process.send_signal(signal.SIGINT) + try: + recording_process.wait(timeout=15) + except subprocess.TimeoutExpired: + logger.warning("ffmpeg not responding, force terminating") + recording_process.kill() + recording_process.wait() + recording_process = None + + # Check if recording file exists + # wait for ffmpeg to write the file header + for _ in range(10): + if os.path.exists(recording_path) and os.path.getsize(recording_path) > 0: + break + time.sleep(0.5) + + if os.path.exists(recording_path) and os.path.getsize(recording_path) > 0: + logger.info("Recording ended, file saved") + return send_file(recording_path, as_attachment=True) + else: + logger.error("Recording file is missing or empty") + return abort(500, description="Recording file is missing or empty") + + except Exception as e: + logger.error(f"Failed to end recording: {str(e)}") + if recording_process: + try: + recording_process.kill() + recording_process.wait() + except: + pass + recording_process = None + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + +@app.route('/terminal', methods=['GET']) +def get_terminal_output(): + """Get terminal output (supports Linux, macOS, Windows)""" + try: + if platform_adapter and hasattr(platform_adapter, 'get_terminal_output'): + output = platform_adapter.get_terminal_output() + if output: + return jsonify({'output': output, 'status': 'success'}) + else: + return jsonify({ + 'status': 'error', + 'message': f'No terminal output available on {platform_name}', + 'platform_note': 'Make sure a terminal window is open and active' + }), 404 + else: + return jsonify({ + 'status': 'error', + 'message': f'Terminal output not supported on {platform_name}' + }), 501 + except Exception as e: + logger.error(f"Failed to get terminal output: {str(e)}") + return jsonify({'status': 'error', 'message': str(e)}), 500 + + +@app.route("/setup/upload", methods=["POST"]) +def upload_file(): + """Upload file""" + if 'file' not in request.files: + return jsonify({'status': 'error', 'message': 'No file provided'}), 400 + + file = request.files['file'] + if file.filename == '': + return jsonify({'status': 'error', 'message': 'No file selected'}), 400 + + try: + # Get target path + target_path = request.form.get('path', os.path.expanduser('~/Desktop')) + target_path = os.path.expanduser(target_path) + + # Ensure directory exists + os.makedirs(target_path, exist_ok=True) + + # Save file + file_path = os.path.join(target_path, file.filename) + file.save(file_path) + + logger.info(f"File uploaded successfully: {file_path}") + return jsonify({ + 'status': 'success', + 'path': file_path, + 'message': 'File uploaded successfully' + }) + except Exception as e: + logger.error(f"File upload failed: {str(e)}") + return jsonify({'status': 'error', 'message': str(e)}), 500 + +@app.route("/setup/download_file", methods=["POST"]) +def download_file(): + """Download file""" + data = request.json + path = data.get('path') + + if not path: + return jsonify({'status': 'error', 'message': 'path required'}), 400 + + try: + path = os.path.expanduser(path) + + if not os.path.exists(path): + return jsonify({'status': 'error', 'message': f'File not found: {path}'}), 404 + + return send_file(path, as_attachment=True) + except Exception as e: + logger.error(f"File download failed: {str(e)}") + return jsonify({'status': 'error', 'message': str(e)}), 500 + +@app.route("/setup/open_file", methods=['POST']) +def open_file(): + """Open file (using system default application)""" + data = request.json + path = data.get('path') + + if not path: + return jsonify({'status': 'error', 'message': 'path required'}), 400 + + try: + path = os.path.expanduser(path) + + if not os.path.exists(path): + return jsonify({'status': 'error', 'message': f'File not found: {path}'}), 404 + + if platform_name == "Darwin": + subprocess.Popen(['open', path]) + elif platform_name == "Linux": + subprocess.Popen(['xdg-open', path]) + elif platform_name == "Windows": + os.startfile(path) + + logger.info(f"File opened successfully: {path}") + return jsonify({ + 'status': 'success', + 'message': f'File opened: {path}' + }) + except Exception as e: + logger.error(f"File opening failed: {str(e)}") + return jsonify({'status': 'error', 'message': str(e)}), 500 + +def print_banner(host: str = "127.0.0.1", port: int = 5000, debug: bool = False): + """Print startup banner with server information""" + from openspace.utils.display import print_banner as display_banner, print_section, print_separator, colorize + + # STARTUP INFORMATION + display_banner("OpenSpace · Local Server") + + server_url = f"http://{host}:{port}" + + # Server section + info_lines = [ + colorize(server_url, 'g', bold=True), + ] + if host == '0.0.0.0': + info_lines.append(f"{colorize('Listening on all interfaces', 'gr')} {colorize('(0.0.0.0:' + str(port) + ')', 'y')}") + info_lines.append(f"{colorize(platform_name, 'gr')} · {colorize('Debug' if debug else 'Production', 'y' if debug else 'g')}") + + print_section("Server", info_lines) + + print() + print_separator() + print(f" {colorize('Press Ctrl+C to stop', 'gr')}") + print() + +def run_health_check_async(): + """Asynchronous running health check""" + def _run(): + from openspace.utils.display import colorize + time.sleep(2) + + print(colorize("\n - Starting health check...\n", 'c', bold=True)) + + results = health_checker.check_all(test_endpoints=True) + + health_checker.print_results(results, show_endpoint_details=False) + + summary = health_checker.get_summary() + logger.info(f"Health check completed: {summary['fully_available']}/{summary['total']} fully available") + + thread = threading.Thread(target=_run, daemon=True) + thread.start() + +def run_server(host: str = "127.0.0.1", port: int = 5000, debug: bool = False): + """ + Start desktop control server + + Args: + host: Listening address (127.0.0.1 for local, 0.0.0.0 for all interfaces) + port: Listening port + debug: Debug mode (display detailed logs) + """ + global health_checker + + # Initialize health_checker + base_url = f"http://{host if host != '0.0.0.0' else '127.0.0.1'}:{port}" + health_checker = HealthChecker(feature_checker, base_url, auto_cleanup=False) + + print_banner(host, port, debug) + + if not debug: + run_health_check_async() + + app.run(host=host, port=port, debug=debug, threaded=True) + +def main(): + import argparse + from openspace.config.utils import get_config_value + + parser = argparse.ArgumentParser( + description='OpenSpace Local Server - Desktop Control Server' + ) + parser.add_argument('--host', type=str, default='127.0.0.1', + help='Server host (default: 127.0.0.1)') + parser.add_argument('--port', type=int, default=5000, + help='Server port (default: 5000)') + parser.add_argument('--debug', action='store_true', + help='Enable debug mode') + parser.add_argument('--config', type=str, + help='Path to config.json file') + + args = parser.parse_args() + + config_path = args.config + if not config_path: + config_path = os.path.join(os.path.dirname(__file__), 'config.json') + + if os.path.exists(config_path): + try: + with open(config_path, 'r') as f: + config = json.load(f) + server_config = get_config_value(config, 'server', {}) + + host = args.host if args.host != '127.0.0.1' else get_config_value(server_config, 'host', '127.0.0.1') + port = args.port if args.port != 5000 else get_config_value(server_config, 'port', 5000) + debug = args.debug or get_config_value(server_config, 'debug', False) + + run_server(host=host, port=port, debug=debug) + except Exception as e: + logger.error(f"Failed to load config: {e}") + run_server(host=args.host, port=args.port, debug=args.debug) + else: + run_server(host=args.host, port=args.port, debug=args.debug) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/openspace/local_server/platform_adapters/__init__.py b/openspace/local_server/platform_adapters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..21b585eb83a559de101293dbf4d3b8e68096fe4e --- /dev/null +++ b/openspace/local_server/platform_adapters/__init__.py @@ -0,0 +1,37 @@ +import platform +from typing import Optional, Any + +platform_name = platform.system() + +if platform_name == "Darwin": + try: + from .macos_adapter import MacOSAdapter as PlatformAdapter + ADAPTER_AVAILABLE = True + except ImportError: + PlatformAdapter = None + ADAPTER_AVAILABLE = False +elif platform_name == "Linux": + try: + from .linux_adapter import LinuxAdapter as PlatformAdapter + ADAPTER_AVAILABLE = True + except ImportError: + PlatformAdapter = None + ADAPTER_AVAILABLE = False +elif platform_name == "Windows": + try: + from .windows_adapter import WindowsAdapter as PlatformAdapter + ADAPTER_AVAILABLE = True + except ImportError: + PlatformAdapter = None + ADAPTER_AVAILABLE = False +else: + PlatformAdapter = None + ADAPTER_AVAILABLE = False + +def get_platform_adapter() -> Optional[Any]: + if ADAPTER_AVAILABLE and PlatformAdapter: + return PlatformAdapter() + return None + +__all__ = ["PlatformAdapter", "get_platform_adapter", "ADAPTER_AVAILABLE"] + diff --git a/openspace/local_server/platform_adapters/linux_adapter.py b/openspace/local_server/platform_adapters/linux_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..74fe309d41a43b43a14d8860d47962d4132386ac --- /dev/null +++ b/openspace/local_server/platform_adapters/linux_adapter.py @@ -0,0 +1,585 @@ +import subprocess +import os +from typing import Dict, Any, Optional, List +from openspace.utils.logging import Logger +from PIL import Image +import pyautogui + +try: + import pyatspi + from pyatspi import Accessible, StateType, STATE_SHOWING + import Xlib + from Xlib import display, X + LINUX_LIBS_AVAILABLE = True +except ImportError: + LINUX_LIBS_AVAILABLE = False + +logger = Logger.get_logger(__name__) + + +class LinuxAdapter: + + def __init__(self): + if not LINUX_LIBS_AVAILABLE: + logger.warning("Linux libraries are not fully installed, some features may not be available") + self.available = LINUX_LIBS_AVAILABLE + + def capture_screenshot_with_cursor(self, output_path: str) -> bool: + """ + Use pyautogui + pyxcursor to capture screenshot (including cursor) + + Args: + output_path: Output file path + + Returns: + Whether the screenshot is successful + """ + try: + # Use pyautogui to capture screenshot + screenshot = pyautogui.screenshot() + + # Try to add cursor + try: + # Import pyxcursor (should be in the same directory) + import sys + import os + sys.path.insert(0, os.path.dirname(__file__)) + + from pyxcursor import Xcursor + + cursor_obj = Xcursor() + imgarray = cursor_obj.getCursorImageArrayFast() + cursor_img = Image.fromarray(imgarray) + cursor_x, cursor_y = pyautogui.position() + screenshot.paste(cursor_img, (cursor_x, cursor_y), cursor_img) + logger.info("Linux screenshot successfully (with cursor)") + except Exception as e: + logger.warning(f"Failed to add cursor to screenshot: {e}") + logger.info("Linux screenshot successfully (without cursor)") + + screenshot.save(output_path) + return True + + except Exception as e: + logger.error(f"Linux screenshot failed: {e}") + return False + + def activate_window(self, window_name: str, strict: bool = False, by_class: bool = False) -> Dict[str, Any]: + """ + Activate window (Linux uses wmctrl) + + Args: + window_name: Window name + strict: Whether to strictly match + by_class: Whether to match by class name + + Returns: + Result dictionary + """ + try: + # Build wmctrl command + flags = f"-{'x' if by_class else ''}{'F' if strict else ''}a" + cmd = ["wmctrl", flags, window_name] + + subprocess.run(cmd, check=True, timeout=5) + logger.info(f"Linux window activated successfully: {window_name}") + return {'status': 'success', 'message': 'Window activated'} + + except subprocess.CalledProcessError as e: + logger.warning(f"wmctrl command execution failed: {e}") + return {'status': 'error', 'message': f'Window {window_name} not found or wmctrl failed'} + except FileNotFoundError: + logger.error("wmctrl not installed, please install: sudo apt install wmctrl") + return {'status': 'error', 'message': 'wmctrl not installed'} + except Exception as e: + logger.error(f"Linux window activation failed: {e}") + return {'status': 'error', 'message': str(e)} + + def close_window(self, window_name: str, strict: bool = False, by_class: bool = False) -> Dict[str, Any]: + """ + Close window (Linux uses wmctrl) + + Args: + window_name: Window name + strict: Whether to strictly match + by_class: Whether to match by class name + + Returns: + Result dictionary + """ + try: + # Build wmctrl command + flags = f"-{'x' if by_class else ''}{'F' if strict else ''}c" + cmd = ["wmctrl", flags, window_name] + + subprocess.run(cmd, check=True, timeout=5) + logger.info(f"Linux window closed successfully: {window_name}") + return {'status': 'success', 'message': 'Window closed'} + + except subprocess.CalledProcessError as e: + logger.warning(f"wmctrl command execution failed: {e}") + return {'status': 'error', 'message': f'Window {window_name} not found or wmctrl failed'} + except FileNotFoundError: + logger.error("wmctrl not installed") + return {'status': 'error', 'message': 'wmctrl not installed'} + except Exception as e: + logger.error(f"Linux window close failed: {e}") + return {'status': 'error', 'message': str(e)} + + def get_accessibility_tree(self, max_depth: int = 10, max_width: int = 50) -> Dict[str, Any]: + """ + Get Linux accessibility tree (using AT-SPI) + + Args: + max_depth: Maximum depth + max_width: Maximum number of child elements per level + + Returns: + Accessibility tree data + """ + if not LINUX_LIBS_AVAILABLE: + return {'error': 'Linux accessibility libraries not available'} + + try: + # Get desktop root node + desktop = pyatspi.Registry.getDesktop(0) + + # Serialize accessibility tree + tree = self._serialize_atspi_element( + desktop, + depth=0, + max_depth=max_depth, + max_width=max_width + ) + + return { + 'tree': tree, + 'platform': 'Linux' + } + + except Exception as e: + logger.error(f"Linux get accessibility tree failed: {e}") + return {'error': str(e)} + + def _serialize_atspi_element( + self, + element: Accessible, + depth: int = 0, + max_depth: int = 10, + max_width: int = 50 + ) -> Optional[Dict[str, Any]]: + """ + Serialize AT-SPI element to dictionary + + Args: + element: AT-SPI accessible element + depth: Current depth + max_depth: Maximum depth + max_width: Maximum width + + Returns: + Serialized dictionary + """ + if depth > max_depth: + return None + + try: + result = { + 'depth': depth, + 'role': element.getRoleName(), + 'name': element.name, + } + + # Get states + try: + states = element.getState().get_states() + result['states'] = [StateType._enum_lookup[st].split('_', 1)[1].lower() + for st in states if st in StateType._enum_lookup] + except: + result['states'] = [] + + # Get attributes + try: + attributes = element.get_attributes() + if attributes: + result['attributes'] = dict(attributes) + except: + result['attributes'] = {} + + # Get position and size (if visible) + if STATE_SHOWING in element.getState().get_states(): + try: + component = element.queryComponent() + bbox = component.getExtents(pyatspi.XY_SCREEN) + result['position'] = {'x': bbox[0], 'y': bbox[1]} + result['size'] = {'width': bbox[2], 'height': bbox[3]} + except: + pass + + # Get text content + try: + text_obj = element.queryText() + text = text_obj.getText(0, text_obj.characterCount) + if text: + result['text'] = text.replace("\ufffc", "").replace("\ufffd", "") + except: + pass + + # Recursively get child elements + result['children'] = [] + try: + child_count = min(element.childCount, max_width) + for i in range(child_count): + try: + child = element.getChildAtIndex(i) + child_data = self._serialize_atspi_element( + child, + depth + 1, + max_depth, + max_width + ) + if child_data: + result['children'].append(child_data) + except Exception as e: + logger.debug(f"Cannot serialize child element {i}: {e}") + continue + except Exception as e: + logger.debug(f"Cannot get child elements: {e}") + + return result + + except Exception as e: + logger.debug(f"Failed to serialize element (depth={depth}): {e}") + return None + + def get_screen_size(self) -> Dict[str, int]: + """ + Get screen size + + Returns: + Screen size dictionary + """ + try: + if LINUX_LIBS_AVAILABLE: + d = display.Display() + screen = d.screen() + return { + 'width': screen.width_in_pixels, + 'height': screen.height_in_pixels + } + else: + # Use pyautogui as fallback + size = pyautogui.size() + return {'width': size.width, 'height': size.height} + + except Exception as e: + logger.error(f"Failed to get screen size: {e}") + return {'width': 1920, 'height': 1080} # Default value + + def list_windows(self) -> List[Dict[str, Any]]: + """ + List all windows + + Returns: + Window list + """ + try: + result = subprocess.run( + ['wmctrl', '-l'], + capture_output=True, + text=True, + check=True + ) + + windows = [] + for line in result.stdout.strip().split('\n'): + if line: + parts = line.split(None, 3) + if len(parts) >= 4: + windows.append({ + 'id': parts[0], + 'desktop': parts[1], + 'hostname': parts[2], + 'title': parts[3] + }) + + return windows + + except FileNotFoundError: + logger.error("wmctrl not installed") + return [] + except Exception as e: + logger.error(f"List windows failed: {e}") + return [] + + def get_terminal_output(self) -> Optional[str]: + """ + Get terminal output (GNOME Terminal) + + Returns: + Terminal output content + """ + if not LINUX_LIBS_AVAILABLE: + return None + + try: + desktop = pyatspi.Registry.getDesktop(0) + + # Find gnome-terminal-server + for app in desktop: + if app.getRoleName() == "application" and app.name == "gnome-terminal-server": + for frame in app: + if frame.getRoleName() == "frame" and frame.getState().contains(pyatspi.STATE_ACTIVE): + # Find terminal component + for component in self._find_terminals(frame): + try: + text_obj = component.queryText() + output = text_obj.getText(0, text_obj.characterCount) + return output.rstrip() if output else None + except: + continue + + return None + + except Exception as e: + logger.error(f"Failed to get terminal output: {e}") + return None + + def _find_terminals(self, element) -> List[Accessible]: + """Recursively find terminal components""" + terminals = [] + try: + if element.getRoleName() == "terminal": + terminals.append(element) + + for i in range(element.childCount): + child = element.getChildAtIndex(i) + terminals.extend(self._find_terminals(child)) + except: + pass + + return terminals + + def set_wallpaper(self, image_path: str) -> Dict[str, Any]: + """ + Set desktop wallpaper (GNOME) + + Args: + image_path: Image path + + Returns: + Result dictionary + """ + try: + image_path = os.path.expanduser(image_path) + image_path = os.path.abspath(image_path) + + if not os.path.exists(image_path): + return {'status': 'error', 'message': f'Image not found: {image_path}'} + + # Use gsettings to set wallpaper (GNOME) + subprocess.run([ + 'gsettings', 'set', + 'org.gnome.desktop.background', + 'picture-uri', + f'file://{image_path}' + ], check=True, timeout=5) + + logger.info(f"Linux wallpaper set successfully: {image_path}") + return {'status': 'success', 'message': 'Wallpaper set successfully'} + + except Exception as e: + logger.error(f"Linux set wallpaper failed: {e}") + return {'status': 'error', 'message': str(e)} + + def get_system_info(self) -> Dict[str, Any]: + """ + Get Linux system information + + Returns: + System information dictionary + """ + try: + # Get distribution information + try: + with open('/etc/os-release', 'r') as f: + os_info = {} + for line in f: + if '=' in line: + key, value = line.strip().split('=', 1) + os_info[key] = value.strip('"') + distro = os_info.get('PRETTY_NAME', 'Unknown Linux') + except: + distro = 'Unknown Linux' + + # Get kernel version + kernel = subprocess.run( + ['uname', '-r'], + capture_output=True, + text=True + ).stdout.strip() + + return { + 'platform': 'Linux', + 'distro': distro, + 'kernel': kernel, + 'available': self.available + } + + except Exception as e: + logger.error(f"Failed to get system information: {e}") + return { + 'platform': 'Linux', + 'error': str(e) + } + + def start_recording(self, output_path: str) -> Dict[str, Any]: + try: + try: + subprocess.run(['ffmpeg', '-version'], + capture_output=True, + check=True, + timeout=5) + except (subprocess.CalledProcessError, FileNotFoundError): + return { + 'status': 'error', + 'message': 'ffmpeg not installed. Install with: sudo apt install ffmpeg' + } + + try: + if LINUX_LIBS_AVAILABLE: + from Xlib import display as xdisplay + d = xdisplay.Display() + screen_width = d.screen().width_in_pixels + screen_height = d.screen().height_in_pixels + else: + # use pyautogui as fallback + size = pyautogui.size() + screen_width = size.width + screen_height = size.height + except: + screen_width, screen_height = 1920, 1080 + + command = [ + 'ffmpeg', + '-y', + '-f', 'x11grab', + '-draw_mouse', '1', + '-s', f'{screen_width}x{screen_height}', + '-i', ':0.0', + '-c:v', 'libx264', + '-preset', 'ultrafast', + '-r', '30', + output_path + ] + + process = subprocess.Popen( + command, + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + text=True + ) + + import time + time.sleep(1) + + if process.poll() is not None: + error_output = process.stderr.read() if process.stderr else "Unknown error" + return { + 'status': 'error', + 'message': f'Failed to start recording: {error_output}' + } + + logger.info(f"Linux recording started: {output_path}") + return { + 'status': 'success', + 'message': 'Recording started', + 'process': process + } + + except Exception as e: + logger.error(f"Linux start recording failed: {e}") + return { + 'status': 'error', + 'message': str(e) + } + + def stop_recording(self, process) -> Dict[str, Any]: + try: + import signal + + if not process or process.poll() is not None: + return { + 'status': 'error', + 'message': 'No recording in progress' + } + + process.send_signal(signal.SIGINT) + + try: + process.wait(timeout=15) + except subprocess.TimeoutExpired: + logger.warning("ffmpeg did not respond to SIGINT, killing process") + process.kill() + process.wait() + + logger.info("Linux recording stopped successfully") + return { + 'status': 'success', + 'message': 'Recording stopped' + } + + except Exception as e: + logger.error(f"Linux stop recording failed: {e}") + return { + 'status': 'error', + 'message': str(e) + } + + def get_running_applications(self) -> List[Dict[str, str]]: + """ + Get list of all running applications + + Returns: + Application list + """ + try: + import psutil + + apps = [] + seen_names = set() + + for proc in psutil.process_iter(['pid', 'name', 'exe', 'cmdline']): + try: + pinfo = proc.info + name = pinfo['name'] + exe = pinfo['exe'] + + # Skip kernel processes and system daemons + if not exe or name.startswith('['): + continue + + # Skip duplicates + if name in seen_names: + continue + + seen_names.add(name) + + apps.append({ + 'name': name, + 'pid': pinfo['pid'], + 'path': exe or '', + 'cmdline': ' '.join(pinfo.get('cmdline', [])) + }) + + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + pass + + return apps + + except ImportError: + logger.warning("psutil not installed, cannot get running applications") + return [] + except Exception as e: + logger.error(f"Failed to get running applications list: {e}") + return [] \ No newline at end of file diff --git a/openspace/local_server/platform_adapters/macos_adapter.py b/openspace/local_server/platform_adapters/macos_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..db7d5f6ab79902dd4144c65c3993336012922d99 --- /dev/null +++ b/openspace/local_server/platform_adapters/macos_adapter.py @@ -0,0 +1,722 @@ +import subprocess +import os +from typing import Dict, Any, Optional, List +from openspace.utils.logging import Logger + +try: + import AppKit + import atomacos + MACOS_LIBS_AVAILABLE = True +except ImportError: + MACOS_LIBS_AVAILABLE = False + +logger = Logger.get_logger(__name__) + +_warning_shown = False + + +class MacOSAdapter: + def __init__(self): + global _warning_shown + if not MACOS_LIBS_AVAILABLE and not _warning_shown: + logger.warning("macOS libraries are not fully installed, some features may not be available") + logger.info("To install missing libraries, run: pip install pyobjc-framework-Cocoa atomacos") + _warning_shown = True + self.available = MACOS_LIBS_AVAILABLE + + def capture_screenshot_with_cursor(self, output_path: str) -> bool: + """ + Capture screenshot with cursor using macOS native screencapture command + + Args: + output_path: Output file path + + Returns: + Whether successful + """ + try: + # -C parameter includes cursor, -x disables sound, -m captures main display + subprocess.run(["screencapture", "-C", "-x", "-m", output_path], check=True) + logger.info(f"macOS screenshot successfully: {output_path}") + return True + except Exception as e: + logger.error(f"macOS screenshot failed: {e}") + return False + + def activate_window(self, window_name: str, strict: bool = False) -> Dict[str, Any]: + """ + Activate window (macOS uses AppleScript) + + Args: + window_name: Window name or application name + strict: Whether to strictly match + + Returns: + Result dictionary + """ + try: + # Try to activate application + script = f''' + tell application "System Events" + set appName to "{window_name}" + try + -- Try to activate application by name + set frontmost of first process whose name is appName to true + return "success" + on error + -- Try to find window by title + set foundWindow to false + repeat with theProcess in (every process whose visible is true) + try + tell theProcess + repeat with theWindow in windows + if name of theWindow contains appName then + set frontmost of theProcess to true + set foundWindow to true + exit repeat + end if + end repeat + end tell + end try + if foundWindow then exit repeat + end repeat + + if foundWindow then + return "success" + else + return "not found" + end if + end try + end tell + ''' + + result = subprocess.run( + ['osascript', '-e', script], + capture_output=True, + text=True, + timeout=10 + ) + + if "success" in result.stdout: + logger.info(f"macOS window activated successfully: {window_name}") + return {'status': 'success', 'message': 'Window activated'} + else: + logger.warning(f"macOS window not found: {window_name}") + return {'status': 'error', 'message': f'Window {window_name} not found'} + + except Exception as e: + logger.error(f"macOS window activation failed: {e}") + return {'status': 'error', 'message': str(e)} + + def close_window(self, window_name: str, strict: bool = False) -> Dict[str, Any]: + """ + Close window or application (macOS uses AppleScript) + + Args: + window_name: Window name or application name + strict: Whether to strictly match + + Returns: + Result dictionary + """ + try: + # Try to exit application + script = f''' + tell application "{window_name}" + quit + end tell + ''' + + subprocess.run(['osascript', '-e', script], check=True, timeout=5) + logger.info(f"macOS window/application closed successfully: {window_name}") + return {'status': 'success', 'message': 'Window/Application closed'} + + except subprocess.TimeoutExpired: + # If timeout, try to force terminate + try: + script_force = f''' + tell application "{window_name}" + quit + end tell + do shell script "killall '{window_name}'" + ''' + subprocess.run(['osascript', '-e', script_force], timeout=5) + logger.info(f"macOS application force closed: {window_name}") + return {'status': 'success', 'message': 'Application force closed'} + except Exception as e2: + logger.error(f"macOS force close failed: {e2}") + return {'status': 'error', 'message': str(e2)} + + except Exception as e: + logger.error(f"macOS close window failed: {e}") + return {'status': 'error', 'message': str(e)} + + def get_accessibility_tree(self, max_depth: int = 10) -> Dict[str, Any]: + """ + Get macOS accessibility tree + + Args: + max_depth: Maximum depth + + Returns: + Accessibility tree data + """ + if not MACOS_LIBS_AVAILABLE: + return {'error': 'macOS accessibility libraries not available'} + + try: + # Get frontmost application + workspace = AppKit.NSWorkspace.sharedWorkspace() + active_app = workspace.activeApplication() + + if not active_app: + return {'error': 'No active application'} + + app_name = active_app.get('NSApplicationName', 'Unknown') + bundle_id = active_app.get('NSApplicationBundleIdentifier', '') + + logger.info(f"Getting accessibility tree: {app_name} ({bundle_id})") + + # Use atomacos to get application reference + try: + if bundle_id: + app_ref = atomacos.getAppRefByBundleId(bundle_id) + else: + # If no bundle_id, try to find by name + return {'error': 'Cannot find application without bundle ID'} + + # Serialize accessibility tree + tree = self._serialize_ax_element(app_ref, depth=0, max_depth=max_depth) + + return { + 'app_name': app_name, + 'bundle_id': bundle_id, + 'tree': tree, + 'platform': 'macOS' + } + + except Exception as e: + logger.error(f"Cannot get app reference: {e}") + return { + 'error': f'Cannot get app reference: {e}', + 'app_name': app_name, + 'bundle_id': bundle_id + } + + except Exception as e: + logger.error(f"macOS get accessibility tree failed: {e}") + return {'error': str(e)} + + def _serialize_ax_element(self, element, depth: int = 0, max_depth: int = 10) -> Optional[Dict[str, Any]]: + """ + Serialize macOS accessibility element to dictionary + + Args: + element: AX element + depth: Current depth + max_depth: Maximum depth + + Returns: + Serialized dictionary + """ + if depth > max_depth: + return None + + try: + result = { + 'depth': depth + } + + # Get common attributes + try: + result['role'] = element.AXRole if hasattr(element, 'AXRole') else 'unknown' + except: + result['role'] = 'unknown' + + try: + result['title'] = element.AXTitle if hasattr(element, 'AXTitle') else '' + except: + result['title'] = '' + + try: + result['description'] = element.AXDescription if hasattr(element, 'AXDescription') else '' + except: + result['description'] = '' + + try: + result['value'] = str(element.AXValue) if hasattr(element, 'AXValue') else '' + except: + result['value'] = '' + + try: + result['enabled'] = element.AXEnabled if hasattr(element, 'AXEnabled') else False + except: + result['enabled'] = False + + try: + result['focused'] = element.AXFocused if hasattr(element, 'AXFocused') else False + except: + result['focused'] = False + + # Position and size + try: + if hasattr(element, 'AXPosition'): + pos = element.AXPosition + result['position'] = {'x': pos.x, 'y': pos.y} + except: + pass + + try: + if hasattr(element, 'AXSize'): + size = element.AXSize + result['size'] = {'width': size.width, 'height': size.height} + except: + pass + + # Recursively get child elements (with limit) + result['children'] = [] + try: + if hasattr(element, 'AXChildren') and element.AXChildren: + for i, child in enumerate(element.AXChildren[:30]): # Limit to max 30 child elements + try: + child_data = self._serialize_ax_element(child, depth + 1, max_depth) + if child_data: + result['children'].append(child_data) + except Exception as e: + logger.debug(f"Cannot serialize child element {i}: {e}") + continue + except Exception as e: + logger.debug(f"Cannot get child elements: {e}") + + return result + + except Exception as e: + logger.debug(f"Failed to serialize element (depth={depth}): {e}") + return None + + def get_running_applications(self) -> List[Dict[str, str]]: + """ + Get list of all running applications + + Returns: + Application list + """ + try: + workspace = AppKit.NSWorkspace.sharedWorkspace() + running_apps = workspace.runningApplications() + + apps = [] + for app in running_apps: + if app.activationPolicy() == AppKit.NSApplicationActivationPolicyRegular: + apps.append({ + 'name': app.localizedName() or 'Unknown', + 'bundle_id': app.bundleIdentifier() or '', + 'pid': app.processIdentifier(), + 'active': app.isActive() + }) + + return apps + + except Exception as e: + logger.error(f"Failed to get running applications list: {e}") + return [] + + def set_wallpaper(self, image_path: str) -> Dict[str, Any]: + """ + Set desktop wallpaper + + Args: + image_path: Image path + + Returns: + Result dictionary + """ + try: + image_path = os.path.expanduser(image_path) + + if not os.path.exists(image_path): + return {'status': 'error', 'message': f'Image not found: {image_path}'} + + # Use AppleScript to set wallpaper + script = f''' + tell application "System Events" + tell every desktop + set picture to "{image_path}" + end tell + end tell + ''' + + subprocess.run(['osascript', '-e', script], check=True, timeout=10) + logger.info(f"macOS wallpaper set successfully: {image_path}") + return {'status': 'success', 'message': 'Wallpaper set successfully'} + + except Exception as e: + logger.error(f"macOS set wallpaper failed: {e}") + return {'status': 'error', 'message': str(e)} + + def get_system_info(self) -> Dict[str, Any]: + """ + Get macOS system information + + Returns: + System information dictionary + """ + try: + # Get macOS version + version = subprocess.run( + ['sw_vers', '-productVersion'], + capture_output=True, + text=True + ).stdout.strip() + + # Get hardware information + model = subprocess.run( + ['sysctl', '-n', 'hw.model'], + capture_output=True, + text=True + ).stdout.strip() + + return { + 'platform': 'macOS', + 'version': version, + 'model': model, + 'available': self.available + } + + except Exception as e: + logger.error(f"Failed to get system information: {e}") + return { + 'platform': 'macOS', + 'error': str(e) + } + + def _detect_screen_device(self) -> str: + """ + Return the screen device number of avfoundation, like '1:none' + + On macOS, ffmpeg -f avfoundation -list_devices true -i "" will list all devices: + - AVFoundation video devices (usually the camera is [0]) + - AVFoundation audio devices + - The screen capture device usually displays as "Capture screen X", numbered from [1] + """ + try: + probe = subprocess.run( + ['ffmpeg', '-f', 'avfoundation', '-list_devices', 'true', '-i', ''], + stderr=subprocess.PIPE, text=True, timeout=5 + ) + + # Find all "Capture screen" devices + screen_devices = [] + for line in probe.stderr.splitlines(): + # Match lines like "[AVFoundation indev @ 0x...] [1] Capture screen 0" + if 'Capture screen' in line and '[AVFoundation' in line: + # Extract device number from square brackets + import re + # Find pattern like "] [number] Capture screen" + match = re.search(r'\]\s*\[(\d+)\]\s*Capture screen', line) + if match: + device_id = match.group(1) + screen_devices.append(device_id) + logger.info(f"Found screen capture device: {device_id} - {line.strip()}") + + # Use first found screen capture device + if screen_devices: + device = f'{screen_devices[0]}:none' + logger.info(f"Using screen capture device: {device}") + return device + else: + logger.warning("No screen capture device found, using default '1:none'") + return '1:none' # Usually screen capture is device 1 + + except Exception as e: + logger.warning(f"Failed to detect screen device: {e}, using default '1:none'") + return '1:none' + + def start_recording(self, output_path: str) -> Dict[str, Any]: + try: + # Check if libx264 encoder is available + result = subprocess.run( + ['ffmpeg', '-encoders'], + capture_output=True, + text=True, + timeout=5 + ) + has_libx264 = 'libx264' in result.stdout + + # Get screen resolution + try: + if MACOS_LIBS_AVAILABLE: + from AppKit import NSScreen + screen = NSScreen.mainScreen() + frame = screen.frame() + width = int(frame.size.width) + height = int(frame.size.height) + logger.info(f"Screen resolution: {width}x{height}") + else: + width, height = 1920, 1080 + logger.info(f"Using default resolution: {width}x{height}") + except: + width, height = 1920, 1080 + logger.info(f"Using default resolution: {width}x{height}") + + # Detect screen capture device + screen_dev = self._detect_screen_device() + logger.info(f"Screen capture device: {screen_dev}") + + # Build ffmpeg command + command = [ + 'ffmpeg', '-y', + '-f', 'avfoundation', + '-capture_cursor', '1', + '-capture_mouse_clicks', '1', + '-framerate', '30', + '-i', screen_dev, # Use detected screen device + ] + + if has_libx264: + command.extend(['-c:v', 'libx264', '-pix_fmt', 'yuv420p']) + logger.info("Using libx264 encoder") + else: + command.extend(['-c:v', 'mpeg4']) + logger.info("Using mpeg4 encoder") + + command.extend(['-r', '30', output_path]) + + logger.info(f"Starting recording with command: {' '.join(command)}") + + process = subprocess.Popen( + command, + stdin=subprocess.PIPE, + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + text=True + ) + + import time + time.sleep(1.5) # Wait for a longer time to ensure ffmpeg starts + + # Check if process exited early + if process.poll() is not None: + err = process.stderr.read() if process.stderr else "" + logger.error(f"FFmpeg exited early with stderr: {err}") + + if "Operation not permitted" in err or "Screen Recording" in err: + return { + "status": "error", + "message": "Screen-recording permission denied. Please grant permission in System Settings → Privacy & Security → Screen Recording." + } + + # Check if it's a device error + if "Input/output error" in err or "Invalid argument" in err or "does not exist" in err: + return { + "status": "error", + "message": f"Invalid screen capture device. Please ensure screen recording is enabled. Error: {err[:200]}" + } + + error_output = err or "Unknown error" + return { + 'status': 'error', + 'message': f'Failed to start recording: {error_output[:300]}' + } + + logger.info(f"macOS recording started successfully: {output_path}") + return { + 'status': 'success', + 'message': 'Recording started', + 'process': process + } + + except Exception as e: + logger.error(f"macOS start recording failed: {e}") + return { + 'status': 'error', + 'message': str(e) + } + + def stop_recording(self, process) -> Dict[str, Any]: + try: + import signal + import time + + if not process or process.poll() is not None: + return { + 'status': 'error', + 'message': 'No recording in progress' + } + + try: + process.stdin.write('q') + process.stdin.flush() + logger.info("Sent 'q' command to ffmpeg") + + process.wait(timeout=5) + logger.info("ffmpeg exited gracefully") + time.sleep(0.2) # give ffmpeg time to flush the file + + except subprocess.TimeoutExpired: + logger.warning("ffmpeg did not respond to 'q', trying SIGINT") + + process.send_signal(signal.SIGINT) + try: + process.wait(timeout=20) + logger.info("ffmpeg responded to SIGINT") + except subprocess.TimeoutExpired: + logger.warning("ffmpeg did not respond to SIGINT, killing process") + process.kill() + process.wait() + + except Exception as e: + logger.warning(f"Failed to send 'q': {e}, trying SIGINT") + process.send_signal(signal.SIGINT) + try: + process.wait(timeout=20) + except subprocess.TimeoutExpired: + logger.warning("Killing ffmpeg") + process.kill() + process.wait() + + time.sleep(0.5) + + logger.info("macOS recording stopped successfully") + return { + 'status': 'success', + 'message': 'Recording stopped' + } + + except Exception as e: + logger.error(f"macOS stop recording failed: {e}") + return { + 'status': 'error', + 'message': str(e) + } + + def list_windows(self) -> List[Dict[str, Any]]: + """ + List all windows + + Returns: + Window list + """ + try: + # Use AppleScript to get window list + script = ''' + tell application "System Events" + set windowList to {} + repeat with theProcess in (every process whose visible is true) + try + set processName to name of theProcess + tell theProcess + repeat with theWindow in windows + try + set windowTitle to name of theWindow + set windowInfo to {processName, windowTitle} + set end of windowList to windowInfo + end try + end repeat + end tell + end try + end repeat + return windowList + end tell + ''' + + result = subprocess.run( + ['osascript', '-e', script], + capture_output=True, + text=True, + timeout=10 + ) + + windows = [] + if result.returncode == 0 and result.stdout: + # Parse AppleScript output: "app1, window1, app2, window2" + output = result.stdout.strip() + if output: + # AppleScript returns comma-separated list + items = [item.strip() for item in output.split(',')] + # Group by pairs (app, window) + for i in range(0, len(items), 2): + if i + 1 < len(items): + windows.append({ + 'app_name': items[i], + 'window_title': items[i + 1] + }) + + return windows + + except Exception as e: + logger.error(f"List windows failed: {e}") + return [] + + def get_terminal_output(self) -> Optional[str]: + """ + Get terminal output (macOS Terminal.app or iTerm2) + + Returns: + Terminal output content + """ + try: + # Try to get Terminal.app output first + script = ''' + tell application "Terminal" + if (count of windows) > 0 then + try + set currentTab to selected tab of front window + set terminalOutput to contents of currentTab + return terminalOutput + on error + return "" + end try + else + return "" + end if + end tell + ''' + + result = subprocess.run( + ['osascript', '-e', script], + capture_output=True, + text=True, + timeout=5 + ) + + if result.returncode == 0 and result.stdout: + output = result.stdout.strip() + if output: + return output + + # Try iTerm2 if Terminal.app failed + iterm_script = ''' + tell application "iTerm" + if (count of windows) > 0 then + try + tell current session of current window + set terminalOutput to contents + return terminalOutput + end tell + on error + return "" + end try + else + return "" + end if + end tell + ''' + + result = subprocess.run( + ['osascript', '-e', iterm_script], + capture_output=True, + text=True, + timeout=5 + ) + + if result.returncode == 0 and result.stdout: + output = result.stdout.strip() + if output: + return output + + return None + + except Exception as e: + logger.error(f"Failed to get terminal output: {e}") + return None \ No newline at end of file diff --git a/openspace/local_server/platform_adapters/pyxcursor.py b/openspace/local_server/platform_adapters/pyxcursor.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe11def391711be5daa759b2ae49c2816ae950f --- /dev/null +++ b/openspace/local_server/platform_adapters/pyxcursor.py @@ -0,0 +1,146 @@ +import os +import ctypes +import ctypes.util +import numpy as np + +# A helper function to convert data from Xlib to byte array. +import struct, array + +# Define ctypes version of XFixesCursorImage structure. +PIXEL_DATA_PTR = ctypes.POINTER(ctypes.c_ulong) +Atom = ctypes.c_ulong + + +class XFixesCursorImage(ctypes.Structure): + """ + See /usr/include/X11/extensions/Xfixes.h + + typedef struct { + short x, y; + unsigned short width, height; + unsigned short xhot, yhot; + unsigned long cursor_serial; + unsigned long *pixels; + if XFIXES_MAJOR >= 2 + Atom atom; /* Version >= 2 only */ + const char *name; /* Version >= 2 only */ + endif + } XFixesCursorImage; + """ + _fields_ = [('x', ctypes.c_short), + ('y', ctypes.c_short), + ('width', ctypes.c_ushort), + ('height', ctypes.c_ushort), + ('xhot', ctypes.c_ushort), + ('yhot', ctypes.c_ushort), + ('cursor_serial', ctypes.c_ulong), + ('pixels', PIXEL_DATA_PTR), + ('atom', Atom), + ('name', ctypes.c_char_p)] + + +class Display(ctypes.Structure): + pass + + +class Xcursor: + display = None + + def __init__(self, display=None): + if not display: + try: + display = os.environ["DISPLAY"].encode("utf-8") + except KeyError: + raise Exception("$DISPLAY not set.") + + # XFixeslib = ctypes.CDLL('libXfixes.so') + XFixes = ctypes.util.find_library("Xfixes") + if not XFixes: + raise Exception("No XFixes library found.") + self.XFixeslib = ctypes.cdll.LoadLibrary(XFixes) + + # xlib = ctypes.CDLL('libX11.so.6') + x11 = ctypes.util.find_library("X11") + if not x11: + raise Exception("No X11 library found.") + self.xlib = ctypes.cdll.LoadLibrary(x11) + + # Define ctypes' version of XFixesGetCursorImage function + XFixesGetCursorImage = self.XFixeslib.XFixesGetCursorImage + XFixesGetCursorImage.restype = ctypes.POINTER(XFixesCursorImage) + XFixesGetCursorImage.argtypes = [ctypes.POINTER(Display)] + self.XFixesGetCursorImage = XFixesGetCursorImage + + XOpenDisplay = self.xlib.XOpenDisplay + XOpenDisplay.restype = ctypes.POINTER(Display) + XOpenDisplay.argtypes = [ctypes.c_char_p] + + if not self.display: + self.display = self.xlib.XOpenDisplay(display) # (display) or (None) + + def argbdata_to_pixdata(self, data, len): + if data == None or len < 1: return None + + # Create byte array + b = array.array('b', b'\x00' * 4 * len) + + offset, i = 0, 0 + while i < len: + argb = data[i] & 0xffffffff + rgba = (argb << 8) | (argb >> 24) + b1 = (rgba >> 24) & 0xff + b2 = (rgba >> 16) & 0xff + b3 = (rgba >> 8) & 0xff + b4 = rgba & 0xff + + struct.pack_into("=BBBB", b, offset, b1, b2, b3, b4) + offset = offset + 4 + i = i + 1 + + return b + + def getCursorImageData(self): + # Call the function. Read data of cursor/mouse-pointer. + cursor_data = self.XFixesGetCursorImage(self.display) + + if not (cursor_data and cursor_data[0]): + raise Exception("Cannot read XFixesGetCursorImage()") + + # Note: cursor_data is a pointer, take cursor_data[0] + return cursor_data[0] + + def getCursorImageArray(self): + data = self.getCursorImageData() + # x, y = data.x, data.y + height, width = data.height, data.width + + bytearr = self.argbdata_to_pixdata(data.pixels, height * width) + + imgarray = np.array(bytearr, dtype=np.uint8) + imgarray = imgarray.reshape(height, width, 4) + del bytearr + + return imgarray + + def getCursorImageArrayFast(self): + data = self.getCursorImageData() + # x, y = data.x, data.y + height, width = data.height, data.width + + bytearr = ctypes.cast(data.pixels, ctypes.POINTER(ctypes.c_ulong * height * width))[0] + imgarray = np.array(bytearray(bytearr)) + imgarray = imgarray.reshape(height, width, 8)[:, :, (0, 1, 2, 3)] + del bytearr + + return imgarray + + def saveImage(self, imgarray, text): + from PIL import Image + img = Image.fromarray(imgarray) + img.save(text) + + +if __name__ == "__main__": + cursor = Xcursor() + imgarray = cursor.getCursorImageArrayFast() + cursor.saveImage(imgarray, 'cursor_image.png') diff --git a/openspace/local_server/platform_adapters/windows_adapter.py b/openspace/local_server/platform_adapters/windows_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..cd09ad74b63f79c50834003048d1d48453073d6f --- /dev/null +++ b/openspace/local_server/platform_adapters/windows_adapter.py @@ -0,0 +1,641 @@ +import os +import ctypes +import subprocess +from typing import Dict, Any, Optional, List +from openspace.utils.logging import Logger +from PIL import Image, ImageGrab + +try: + from pywinauto import Desktop + import win32ui + import win32gui + import win32con + import pygetwindow as gw + WINDOWS_LIBS_AVAILABLE = True +except ImportError: + WINDOWS_LIBS_AVAILABLE = False + +logger = Logger.get_logger(__name__) + + +class WindowsAdapter: + """Windows platform-specific functionality adapter""" + + def __init__(self): + if not WINDOWS_LIBS_AVAILABLE: + logger.warning("Windows libraries are not fully installed, some features may not be available") + self.available = WINDOWS_LIBS_AVAILABLE + + def capture_screenshot_with_cursor(self, output_path: str) -> bool: + """ + Capture screenshot using ImageGrab (including cursor) + + Args: + output_path: Output file path + + Returns: + Whether successful + """ + try: + # Use ImageGrab to capture screenshot + img = ImageGrab.grab(bbox=None, include_layered_windows=True) + + # Try to add cursor + try: + if WINDOWS_LIBS_AVAILABLE: + cursor, hotspot = self._get_cursor() + if cursor: + # Get scaling ratio + ratio = ctypes.windll.shcore.GetScaleFactorForDevice(0) / 100 + pos_win = win32gui.GetCursorPos() + pos = ( + round(pos_win[0] * ratio - hotspot[0]), + round(pos_win[1] * ratio - hotspot[1]) + ) + img.paste(cursor, pos, cursor) + logger.info("Windows screenshot successfully (with cursor)") + else: + logger.info("Windows screenshot successfully (without cursor)") + except Exception as e: + logger.warning(f"Cannot add cursor to screenshot: {e}") + logger.info("Windows screenshot successfully (without cursor)") + + img.save(output_path) + return True + + except Exception as e: + logger.error(f"Windows screenshot failed: {e}") + return False + + def _get_cursor(self) -> tuple: + """ + Get current cursor image and hotspot + + Returns: + (cursor_image, (hotspot_x, hotspot_y)) + """ + try: + hcursor = win32gui.GetCursorInfo()[1] + hdc = win32ui.CreateDCFromHandle(win32gui.GetDC(0)) + hbmp = win32ui.CreateBitmap() + hbmp.CreateCompatibleBitmap(hdc, 36, 36) + hdc_compatible = hdc.CreateCompatibleDC() + hdc_compatible.SelectObject(hbmp) + hdc_compatible.DrawIcon((0, 0), hcursor) + + bmpinfo = hbmp.GetInfo() + bmpstr = hbmp.GetBitmapBits(True) + cursor = Image.frombuffer( + 'RGB', + (bmpinfo['bmWidth'], bmpinfo['bmHeight']), + bmpstr, 'raw', 'BGRX', 0, 1 + ).convert("RGBA") + + win32gui.DestroyIcon(hcursor) + win32gui.DeleteObject(hbmp.GetHandle()) + hdc_compatible.DeleteDC() + + # Make black pixels transparent + pixdata = cursor.load() + width, height = cursor.size + for y in range(height): + for x in range(width): + if pixdata[x, y] == (0, 0, 0, 255): + pixdata[x, y] = (0, 0, 0, 0) + + hotspot = win32gui.GetIconInfo(hcursor)[1:3] + + return (cursor, hotspot) + + except Exception as e: + logger.debug(f"Failed to get cursor image: {e}") + return (None, (0, 0)) + + def activate_window(self, window_name: str, strict: bool = False) -> Dict[str, Any]: + """ + Activate window (Windows uses pygetwindow) + + Args: + window_name: Window title + strict: Whether to strictly match + + Returns: + Result dictionary + """ + if not WINDOWS_LIBS_AVAILABLE: + return {'status': 'error', 'message': 'Windows libraries not available'} + + try: + windows = gw.getWindowsWithTitle(window_name) + + if not windows: + logger.warning(f"Window not found: {window_name}") + return {'status': 'error', 'message': f'Window {window_name} not found'} + + window = None + if strict: + # Strict match + for wnd in windows: + if wnd.title == window_name: + window = wnd + break + if not window: + return {'status': 'error', 'message': f'Window {window_name} not found (strict mode)'} + else: + window = windows[0] + + window.activate() + logger.info(f"Windows window activated successfully: {window_name}") + return {'status': 'success', 'message': 'Window activated'} + + except Exception as e: + logger.error(f"Windows window activation failed: {e}") + return {'status': 'error', 'message': str(e)} + + def close_window(self, window_name: str, strict: bool = False) -> Dict[str, Any]: + """ + Close window (Windows uses pygetwindow) + + Args: + window_name: Window title + strict: Whether to strictly match + + Returns: + Result dictionary + """ + if not WINDOWS_LIBS_AVAILABLE: + return {'status': 'error', 'message': 'Windows libraries not available'} + + try: + windows = gw.getWindowsWithTitle(window_name) + + if not windows: + logger.warning(f"Window not found: {window_name}") + return {'status': 'error', 'message': f'Window {window_name} not found'} + + window = None + if strict: + for wnd in windows: + if wnd.title == window_name: + window = wnd + break + if not window: + return {'status': 'error', 'message': f'Window {window_name} not found (strict mode)'} + else: + window = windows[0] + + window.close() + logger.info(f"Windows window closed successfully: {window_name}") + return {'status': 'success', 'message': 'Window closed'} + + except Exception as e: + logger.error(f"Windows window close failed: {e}") + return {'status': 'error', 'message': str(e)} + + def get_accessibility_tree(self, max_depth: int = 10, max_width: int = 50) -> Dict[str, Any]: + """ + Get Windows accessibility tree (using pywinauto) + + Args: + max_depth: Maximum depth + max_width: Maximum number of child elements per level + + Returns: + Accessibility tree data + """ + if not WINDOWS_LIBS_AVAILABLE: + return {'error': 'Windows accessibility libraries not available'} + + try: + # Get desktop + desktop = Desktop(backend="uia") + + # Serialize accessibility tree + tree = self._serialize_uia_element( + desktop, + depth=0, + max_depth=max_depth, + max_width=max_width, + visited=set() + ) + + return { + 'tree': tree, + 'platform': 'Windows' + } + + except Exception as e: + logger.error(f"Windows get accessibility tree failed: {e}") + return {'error': str(e)} + + def _serialize_uia_element( + self, + element, + depth: int = 0, + max_depth: int = 10, + max_width: int = 50, + visited: set = None + ) -> Optional[Dict[str, Any]]: + """ + Serialize Windows UIA element to dictionary + + Args: + element: UIA element + depth: Current depth + max_depth: Maximum depth + max_width: Maximum width + visited: Set of visited elements + + Returns: + Serialized dictionary + """ + if visited is None: + visited = set() + + if depth > max_depth or element in visited: + return None + + visited.add(element) + + try: + result = { + 'depth': depth + } + + # Get basic attributes + try: + result['class_name'] = element.class_name() + except: + result['class_name'] = 'unknown' + + try: + result['name'] = element.window_text() + except: + result['name'] = '' + + # Get states + states = {} + state_methods = [ + 'is_enabled', 'is_visible', 'is_minimized', 'is_maximized', + 'is_focused', 'is_checked', 'is_selected' + ] + + for method_name in state_methods: + if hasattr(element, method_name): + try: + method = getattr(element, method_name) + states[method_name] = method() + except: + pass + + if states: + result['states'] = states + + # Get position and size + try: + rectangle = element.rectangle() + result['position'] = { + 'left': rectangle.left, + 'top': rectangle.top + } + result['size'] = { + 'width': rectangle.width(), + 'height': rectangle.height() + } + except: + pass + + # Recursively get child elements + result['children'] = [] + try: + children = element.children() + for i, child in enumerate(children[:max_width]): + try: + child_data = self._serialize_uia_element( + child, + depth + 1, + max_depth, + max_width, + visited + ) + if child_data: + result['children'].append(child_data) + except Exception as e: + logger.debug(f"Cannot serialize child element {i}: {e}") + continue + except Exception as e: + logger.debug(f"Cannot get child elements: {e}") + + return result + + except Exception as e: + logger.debug(f"Failed to serialize element (depth={depth}): {e}") + return None + + def list_windows(self) -> List[Dict[str, Any]]: + """ + List all windows + + Returns: + Window list + """ + if not WINDOWS_LIBS_AVAILABLE: + return [] + + try: + windows = gw.getAllWindows() + + return [ + { + 'title': win.title, + 'left': win.left, + 'top': win.top, + 'width': win.width, + 'height': win.height, + 'visible': win.visible, + 'active': win.isActive + } + for win in windows + if win.title # Only return windows with titles + ] + + except Exception as e: + logger.error(f"List windows failed: {e}") + return [] + + def set_wallpaper(self, image_path: str) -> Dict[str, Any]: + """ + Set desktop wallpaper + + Args: + image_path: Image path + + Returns: + Result dictionary + """ + try: + image_path = os.path.expanduser(image_path) + image_path = os.path.abspath(image_path) + + if not os.path.exists(image_path): + return {'status': 'error', 'message': f'Image not found: {image_path}'} + + # Use Windows API to set wallpaper + SPI_SETDESKWALLPAPER = 20 + ctypes.windll.user32.SystemParametersInfoW( + SPI_SETDESKWALLPAPER, + 0, + image_path, + 3 # SPIF_UPDATEINIFILE | SPIF_SENDCHANGE + ) + + logger.info(f"Windows wallpaper set successfully: {image_path}") + return {'status': 'success', 'message': 'Wallpaper set successfully'} + + except Exception as e: + logger.error(f"Windows set wallpaper failed: {e}") + return {'status': 'error', 'message': str(e)} + + def get_system_info(self) -> Dict[str, Any]: + """ + Get Windows system information + + Returns: + System information dictionary + """ + try: + import platform as plat + + return { + 'platform': 'Windows', + 'version': plat.version(), + 'release': plat.release(), + 'edition': plat.win32_edition() if hasattr(plat, 'win32_edition') else 'Unknown', + 'available': self.available + } + + except Exception as e: + logger.error(f"Failed to get system information: {e}") + return { + 'platform': 'Windows', + 'error': str(e) + } + + def start_recording(self, output_path: str) -> Dict[str, Any]: + try: + try: + result = subprocess.run(['ffmpeg', '-version'], + capture_output=True, + check=True, + timeout=5, + creationflags=subprocess.CREATE_NO_WINDOW) + except (subprocess.CalledProcessError, FileNotFoundError): + return { + 'status': 'error', + 'message': 'ffmpeg not installed. Download from: https://ffmpeg.org/download.html' + } + try: + user32 = ctypes.windll.user32 + width = user32.GetSystemMetrics(0) # SM_CXSCREEN + height = user32.GetSystemMetrics(1) # SM_CYSCREEN + except: + width, height = 1920, 1080 + + command = [ + 'ffmpeg', + '-y', + '-f', 'gdigrab', + '-draw_mouse', '1', + '-framerate', '30', + '-video_size', f'{width}x{height}', + '-i', 'desktop', + '-c:v', 'libx264', + '-preset', 'ultrafast', + '-pix_fmt', 'yuv420p', + '-r', '30', + output_path + ] + + process = subprocess.Popen( + command, + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + text=True, + creationflags=subprocess.CREATE_NO_WINDOW + ) + + import time + time.sleep(1) + + if process.poll() is not None: + error_output = process.stderr.read() if process.stderr else "Unknown error" + return { + 'status': 'error', + 'message': f'Failed to start recording: {error_output}' + } + + logger.info(f"Windows recording started: {output_path}") + return { + 'status': 'success', + 'message': 'Recording started', + 'process': process + } + + except Exception as e: + logger.error(f"Windows start recording failed: {e}") + return { + 'status': 'error', + 'message': str(e) + } + + def stop_recording(self, process) -> Dict[str, Any]: + try: + if not process or process.poll() is not None: + return { + 'status': 'error', + 'message': 'No recording in progress' + } + + import signal + try: + process.send_signal(signal.CTRL_C_EVENT) + except: + process.terminate() + + try: + process.wait(timeout=15) + except subprocess.TimeoutExpired: + logger.warning("ffmpeg did not respond, killing process") + process.kill() + process.wait() + + logger.info("Windows recording stopped successfully") + return { + 'status': 'success', + 'message': 'Recording stopped' + } + + except Exception as e: + logger.error(f"Windows stop recording failed: {e}") + return { + 'status': 'error', + 'message': str(e) + } + + def get_running_applications(self) -> List[Dict[str, str]]: + """ + Get list of all running applications + + Returns: + Application list + """ + if not WINDOWS_LIBS_AVAILABLE: + return [] + + try: + import psutil + + apps = [] + seen_names = set() + + for proc in psutil.process_iter(['pid', 'name', 'exe']): + try: + pinfo = proc.info + name = pinfo['name'] + exe = pinfo['exe'] + + # Skip system processes + if not exe or name in ['System', 'Registry', 'svchost.exe', 'csrss.exe']: + continue + + # Skip duplicates + if name in seen_names: + continue + + seen_names.add(name) + + apps.append({ + 'name': name, + 'pid': pinfo['pid'], + 'path': exe or '' + }) + + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + pass + + return apps + + except ImportError: + logger.warning("psutil not installed, cannot get running applications") + return [] + except Exception as e: + logger.error(f"Failed to get running applications list: {e}") + return [] + + def get_screen_size(self) -> Dict[str, int]: + """ + Get screen size + + Returns: + Screen size dictionary + """ + try: + user32 = ctypes.windll.user32 + width = user32.GetSystemMetrics(0) # SM_CXSCREEN + height = user32.GetSystemMetrics(1) # SM_CYSCREEN + return {'width': width, 'height': height} + except Exception as e: + logger.error(f"Failed to get screen size: {e}") + return {'width': 1920, 'height': 1080} # Default value + + def get_terminal_output(self) -> Optional[str]: + """ + Get terminal output (Windows Command Prompt, PowerShell, or Windows Terminal) + + Note: Due to Windows architecture, getting terminal output is complex. + This method attempts to find active console windows. + + Returns: + Terminal output content (limited functionality on Windows) + """ + try: + # Windows doesn't provide easy access to terminal content like Linux/macOS + # This is a limitation of the Windows platform + # We can try to use PowerShell to get recent command history + + # Try to get PowerShell history + try: + history_path = os.path.expanduser( + '~\\AppData\\Roaming\\Microsoft\\Windows\\PowerShell\\PSReadLine\\ConsoleHost_history.txt' + ) + if os.path.exists(history_path): + with open(history_path, 'r', encoding='utf-8', errors='ignore') as f: + # Get last 50 lines + lines = f.readlines() + recent_history = ''.join(lines[-50:]) + if recent_history: + return f"PowerShell History (last 50 commands):\n{recent_history}" + except Exception as e: + logger.debug(f"Cannot read PowerShell history: {e}") + + # Try to get Command Prompt history using doskey + try: + result = subprocess.run( + ['doskey', '/history'], + capture_output=True, + text=True, + timeout=2, + creationflags=subprocess.CREATE_NO_WINDOW + ) + if result.returncode == 0 and result.stdout: + return f"Command Prompt History:\n{result.stdout}" + except Exception as e: + logger.debug(f"Cannot get Command Prompt history: {e}") + + logger.warning("Windows terminal output is limited - only command history available") + return None + + except Exception as e: + logger.error(f"Failed to get terminal output: {e}") + return None + diff --git a/openspace/local_server/requirements.txt b/openspace/local_server/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5d67d69152734f32bea5913f716cc2cef69d6ec9 --- /dev/null +++ b/openspace/local_server/requirements.txt @@ -0,0 +1,21 @@ +# Local server dependencies (cross-platform) +flask>=3.1.0 +pyautogui>=0.9.54 +pydantic>=2.12.0 +requests>=2.32.0 + +# # macOS-specific dependencies (local server) +# pyobjc-core>=12.0; sys_platform == 'darwin' +# pyobjc-framework-cocoa>=12.0; sys_platform == 'darwin' +# pyobjc-framework-quartz>=12.0; sys_platform == 'darwin' +# atomacos>=3.2.0; sys_platform == 'darwin' + +# # Linux-specific dependencies (local server) +# python-xlib>=0.33; sys_platform == 'linux' +# pyatspi>=2.38.0; sys_platform == 'linux' +# numpy>=1.24.0; sys_platform == 'linux' + +# # Windows-specific dependencies (local server) +# pywinauto>=0.6.8; sys_platform == 'win32' +# pywin32>=306; sys_platform == 'win32' +# PyGetWindow>=0.0.9; sys_platform == 'win32' \ No newline at end of file diff --git a/openspace/local_server/run.sh b/openspace/local_server/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..cce314d58782b3046f32ee846d176bb271a180c4 --- /dev/null +++ b/openspace/local_server/run.sh @@ -0,0 +1,23 @@ +#!/bin/bash +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +PROJECT_ROOT="$( cd "$SCRIPT_DIR/../.." && pwd )" + +# Check Python +if ! command -v python3 &> /dev/null; then + echo "Error: python3 not installed" + exit 1 +fi + +# Check if dependencies are installed +if ! python3 -c "import flask" &> /dev/null; then + echo "Installing dependencies..." + pip3 install -q -r "$SCRIPT_DIR/requirements.txt" || { + echo "Failed to install dependencies" + exit 1 + } +fi + +# Set PYTHONPATH and start server +export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH" +cd "$PROJECT_ROOT" +python3 -m openspace.local_server.main \ No newline at end of file diff --git a/openspace/local_server/utils/__init__.py b/openspace/local_server/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6009b673dce5c9c3227ec1316765b124528feebc --- /dev/null +++ b/openspace/local_server/utils/__init__.py @@ -0,0 +1,4 @@ +from .accessibility import AccessibilityHelper +from .screenshot import ScreenshotHelper + +__all__ = ["AccessibilityHelper", "ScreenshotHelper"] \ No newline at end of file diff --git a/openspace/local_server/utils/accessibility.py b/openspace/local_server/utils/accessibility.py new file mode 100644 index 0000000000000000000000000000000000000000..5c4a176335b0e3fe35c7e66f6627f2cc704cfb58 --- /dev/null +++ b/openspace/local_server/utils/accessibility.py @@ -0,0 +1,147 @@ +import platform +from openspace.utils.logging import Logger +from typing import Dict, Any, Optional + +logger = Logger.get_logger(__name__) + +platform_name = platform.system() + + +class AccessibilityHelper: + def __init__(self): + self.platform = platform_name + self.adapter = None + + try: + if platform_name == "Darwin": + from ..platform_adapters.macos_adapter import MacOSAdapter + self.adapter = MacOSAdapter() + elif platform_name == "Linux": + from ..platform_adapters.linux_adapter import LinuxAdapter + self.adapter = LinuxAdapter() + elif platform_name == "Windows": + from ..platform_adapters.windows_adapter import WindowsAdapter + self.adapter = WindowsAdapter() + except ImportError as e: + logger.warning(f"Failed to import platform adapter: {e}") + + def get_tree(self, max_depth: int = 10) -> Dict[str, Any]: + if not self.adapter: + return { + 'error': f'No adapter available for {self.platform}', + 'platform': self.platform + } + + try: + return self.adapter.get_accessibility_tree(max_depth=max_depth) + except Exception as e: + logger.error(f"Failed to get accessibility tree: {e}") + return { + 'error': str(e), + 'platform': self.platform + } + + def is_available(self) -> bool: + return self.adapter is not None and hasattr(self.adapter, 'available') and self.adapter.available + + def find_element_by_name(self, tree: Dict[str, Any], name: str) -> Optional[Dict[str, Any]]: + if not tree or 'tree' not in tree: + return None + + return self._search_tree(tree['tree'], 'name', name) + + def find_element_by_role(self, tree: Dict[str, Any], role: str) -> Optional[Dict[str, Any]]: + if not tree or 'tree' not in tree: + return None + + return self._search_tree(tree['tree'], 'role', role) + + def _search_tree(self, node: Dict[str, Any], key: str, value: str) -> Optional[Dict[str, Any]]: + if not node: + return None + + # Check current node + if key in node and node[key] == value: + return node + + # Recursively search child nodes + if 'children' in node: + for child in node['children']: + result = self._search_tree(child, key, value) + if result: + return result + + return None + + def flatten_tree(self, tree: Dict[str, Any]) -> list: + if not tree or 'tree' not in tree: + return [] + + result = [] + self._flatten_node(tree['tree'], result) + return result + + def _flatten_node(self, node: Dict[str, Any], result: list): + """Recursively flatten nodes""" + if not node: + return + + # Add current node (remove children) + node_copy = {k: v for k, v in node.items() if k != 'children'} + result.append(node_copy) + + # Recursively process child nodes + if 'children' in node: + for child in node['children']: + self._flatten_node(child, result) + + def get_visible_elements(self, tree: Dict[str, Any]) -> list: + all_elements = self.flatten_tree(tree) + + visible = [] + for element in all_elements: + if self.platform == "Linux": + if 'states' in element and 'showing' in element.get('states', []): + visible.append(element) + elif self.platform == "Darwin": + if element.get('enabled', False): + visible.append(element) + elif self.platform == "Windows": + if element.get('states', {}).get('is_visible', False): + visible.append(element) + + return visible + + def get_clickable_elements(self, tree: Dict[str, Any]) -> list: + all_elements = self.flatten_tree(tree) + + clickable_roles = [ + 'button', 'push-button', 'toggle-button', 'radio-button', + 'link', 'menu-item', 'AXButton', 'AXLink', 'AXMenuItem' + ] + + clickable = [] + for element in all_elements: + role = element.get('role', '').lower() + if any(cr in role for cr in clickable_roles): + clickable.append(element) + + return clickable + + def get_statistics(self, tree: Dict[str, Any]) -> Dict[str, Any]: + all_elements = self.flatten_tree(tree) + + # Count roles + roles = {} + for element in all_elements: + role = element.get('role', 'unknown') + roles[role] = roles.get(role, 0) + 1 + + return { + 'total_elements': len(all_elements), + 'visible_elements': len(self.get_visible_elements(tree)), + 'clickable_elements': len(self.get_clickable_elements(tree)), + 'roles': roles, + 'platform': self.platform + } + diff --git a/openspace/local_server/utils/screenshot.py b/openspace/local_server/utils/screenshot.py new file mode 100644 index 0000000000000000000000000000000000000000..ff0b2263388a053235ecc9d469705015bf833449 --- /dev/null +++ b/openspace/local_server/utils/screenshot.py @@ -0,0 +1,255 @@ +import platform +import os +import logging +from typing import Optional, Tuple +from PIL import Image +import pyautogui + +logger = logging.getLogger(__name__) + +platform_name = platform.system() + + +class ScreenshotHelper: + def __init__(self): + self.platform = platform_name + self.adapter = None + + try: + if platform_name == "Darwin": + from ..platform_adapters.macos_adapter import MacOSAdapter + self.adapter = MacOSAdapter() + elif platform_name == "Linux": + from ..platform_adapters.linux_adapter import LinuxAdapter + self.adapter = LinuxAdapter() + elif platform_name == "Windows": + from ..platform_adapters.windows_adapter import WindowsAdapter + self.adapter = WindowsAdapter() + except ImportError as e: + logger.warning(f"Failed to import platform adapter: {e}") + + def capture(self, output_path: str, with_cursor: bool = True) -> bool: + try: + # Ensure directory exists + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + if with_cursor and self.adapter: + # Use platform-specific method to capture screenshot (with cursor) + return self.adapter.capture_screenshot_with_cursor(output_path) + else: + # Use pyautogui to capture screenshot (without cursor) + screenshot = pyautogui.screenshot() + screenshot.save(output_path) + logger.info(f"Screenshot successfully (without cursor): {output_path}") + return True + + except Exception as e: + logger.error(f"Screenshot failed: {e}") + return False + + def capture_region( + self, + output_path: str, + x: int, + y: int, + width: int, + height: int + ) -> bool: + """ + Capture specified screen region + + Args: + output_path: Output path + x: Starting x coordinate + y: Starting y coordinate + width: Width + height: Height + + Returns: + Whether successful + """ + try: + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + screenshot = pyautogui.screenshot(region=(x, y, width, height)) + screenshot.save(output_path) + logger.info(f"Region screenshot successfully: {output_path}") + return True + + except Exception as e: + logger.error(f"Region screenshot failed: {e}") + return False + + def get_screen_size(self) -> Tuple[int, int]: + """ + Get screen size + + Returns: + (width, height) + """ + try: + size = pyautogui.size() + return (size.width, size.height) + except Exception as e: + logger.error(f"Failed to get screen size: {e}") + return (1920, 1080) # Default value + + def get_cursor_position(self) -> Tuple[int, int]: + """ + Get cursor position + + Returns: + (x, y) + """ + try: + pos = pyautogui.position() + return (pos.x, pos.y) + except Exception as e: + logger.error(f"Failed to get cursor position: {e}") + return (0, 0) + + def capture_to_base64(self, with_cursor: bool = True) -> Optional[str]: + """ + Capture screenshot and convert to base64 + + Args: + with_cursor: Whether to include cursor + + Returns: + Base64 encoded image string + """ + import tempfile + import base64 + + try: + # Create temporary file + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp: + tmp_path = tmp.name + + # Capture screenshot + if self.capture(tmp_path, with_cursor): + # Read and encode + with open(tmp_path, 'rb') as f: + img_data = f.read() + img_base64 = base64.b64encode(img_data).decode('utf-8') + + # Delete temporary file + os.remove(tmp_path) + + return img_base64 + else: + if os.path.exists(tmp_path): + os.remove(tmp_path) + return None + + except Exception as e: + logger.error(f"Failed to convert screenshot to base64: {e}") + return None + + def compare_screenshots(self, path1: str, path2: str) -> float: + """ + Compare similarity between two screenshots + + Args: + path1: First image path + path2: Second image path + + Returns: + Similarity (0-1), 1 means identical + """ + try: + from PIL import ImageChops + import math + import operator + from functools import reduce + + img1 = Image.open(path1) + img2 = Image.open(path2) + + # Ensure same size + if img1.size != img2.size: + # Resize to same size + img2 = img2.resize(img1.size) + + # Calculate difference + diff = ImageChops.difference(img1, img2) + + # Calculate statistics + stat = diff.histogram() + sum_of_squares = reduce( + operator.add, + map(lambda h, i: h * (i ** 2), stat, range(len(stat))) + ) + + # Calculate RMS + rms = math.sqrt(sum_of_squares / float(img1.size[0] * img1.size[1])) + + # Normalize to 0-1, RMS max value is approximately 441 (for RGB) + similarity = 1 - (rms / 441.0) + + return max(0, min(1, similarity)) + + except Exception as e: + logger.error(f"Failed to compare screenshots: {e}") + return 0.0 + + def annotate_screenshot( + self, + input_path: str, + output_path: str, + annotations: list + ) -> bool: + """ + Add annotations to screenshot + + Args: + input_path: Input image path + output_path: Output image path + annotations: List of annotations, each annotation is a dict: + {'type': 'rectangle'/'text', 'x': int, 'y': int, + 'width': int, 'height': int, 'text': str, 'color': tuple} + + Returns: + Whether successful + """ + try: + from PIL import ImageDraw, ImageFont + + img = Image.open(input_path) + draw = ImageDraw.Draw(img) + + for annotation in annotations: + ann_type = annotation.get('type', 'rectangle') + color = annotation.get('color', (255, 0, 0)) + + if ann_type == 'rectangle': + x = annotation.get('x', 0) + y = annotation.get('y', 0) + width = annotation.get('width', 100) + height = annotation.get('height', 100) + + draw.rectangle( + [(x, y), (x + width, y + height)], + outline=color, + width=2 + ) + + elif ann_type == 'text': + x = annotation.get('x', 0) + y = annotation.get('y', 0) + text = annotation.get('text', '') + + try: + font = ImageFont.truetype("Arial.ttf", 20) + except: + font = ImageFont.load_default() + + draw.text((x, y), text, fill=color, font=font) + + img.save(output_path) + logger.info(f"Annotated screenshot successfully: {output_path}") + return True + + except Exception as e: + logger.error(f"Failed to annotate screenshot: {e}") + return False \ No newline at end of file diff --git a/openspace/mcp_server.py b/openspace/mcp_server.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3428ff5512f704a1d692abc37cf3be8a9fe476 --- /dev/null +++ b/openspace/mcp_server.py @@ -0,0 +1,866 @@ +"""OpenSpace MCP Server + +Exposes the following tools to MCP clients: + execute_task — Delegate a task (auto-registers skills, auto-searches, auto-evolves) + search_skills — Standalone search across local & cloud skills + fix_skill — Manually fix a broken skill (FIX only; DERIVED/CAPTURED via execute_task) + upload_skill — Upload a local skill to cloud (pre-saved metadata, bot decides visibility) + +Usage: + python -m openspace.mcp_server # stdio (default) + python -m openspace.mcp_server --transport sse # SSE on port 8080 + python -m openspace.mcp_server --port 9090 # SSE on custom port + +Environment variables: see ``openspace/host_detection/`` and ``openspace/cloud/auth.py``. +""" + +from __future__ import annotations + +import asyncio +import inspect +import json +import logging +import os +import sys +import traceback +from pathlib import Path +from typing import Any, Dict, List, Optional + + +class _MCPSafeStdout: + """Stdout wrapper: binary (.buffer) → real stdout, text (.write) → stderr.""" + + def __init__(self, real_stdout, stderr): + self._real = real_stdout + self._stderr = stderr + + @property + def buffer(self): + return self._real.buffer + + def fileno(self): + return self._real.fileno() + + def write(self, s): + return self._stderr.write(s) + + def writelines(self, lines): + return self._stderr.writelines(lines) + + def flush(self): + self._stderr.flush() + self._real.flush() + + def isatty(self): + return self._stderr.isatty() + + @property + def encoding(self): + return self._stderr.encoding + + @property + def errors(self): + return self._stderr.errors + + @property + def closed(self): + return self._stderr.closed + + def readable(self): + return False + + def writable(self): + return True + + def seekable(self): + return False + + def __getattr__(self, name): + return getattr(self._stderr, name) + +_LOG_DIR = Path(__file__).resolve().parent.parent / "logs" +_LOG_DIR.mkdir(parents=True, exist_ok=True) + +_real_stdout = sys.stdout + +# Windows pipe buffers are small. When using stdio MCP transport, +# the parent process only reads stdout for MCP messages and does NOT +# drain stderr. Heavy log/print output during execute_task fills the stderr +# pipe buffer, blocking this process on write() → deadlock → timeout. +# Redirect stderr to a log file on Windows to prevent this. +if os.name == "nt": + _stderr_file = open( + _LOG_DIR / "mcp_stderr.log", "a", encoding="utf-8", buffering=1 + ) + sys.stderr = _stderr_file + +sys.stdout = _MCPSafeStdout(_real_stdout, sys.stderr) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.FileHandler(_LOG_DIR / "mcp_server.log")], +) +logger = logging.getLogger("openspace.mcp_server") + +from mcp.server.fastmcp import FastMCP + +_fastmcp_kwargs: dict = {} +try: + if "description" in inspect.signature(FastMCP.__init__).parameters: + _fastmcp_kwargs["description"] = ( + "OpenSpace: Unite the Agents. Evolve the Mind. Rebuild the World." + ) +except (TypeError, ValueError): + pass + +mcp = FastMCP("OpenSpace", **_fastmcp_kwargs) + +_openspace_instance = None +_openspace_lock = asyncio.Lock() +_standalone_store = None + +# Internal state: tracks bot skill directories already registered this session. +_registered_skill_dirs: set = set() + +_UPLOAD_META_FILENAME = ".upload_meta.json" + + +async def _get_openspace(): + """Lazy-initialise the OpenSpace engine.""" + global _openspace_instance + if _openspace_instance is not None and _openspace_instance.is_initialized(): + return _openspace_instance + + async with _openspace_lock: + if _openspace_instance is not None and _openspace_instance.is_initialized(): + return _openspace_instance + + logger.info("Initializing OpenSpace engine ...") + from openspace.tool_layer import OpenSpace, OpenSpaceConfig + from openspace.host_detection import build_llm_kwargs, build_grounding_config_path + + env_model = os.environ.get("OPENSPACE_MODEL", "") + workspace = os.environ.get("OPENSPACE_WORKSPACE") + max_iter = int(os.environ.get("OPENSPACE_MAX_ITERATIONS", "20")) + enable_rec = os.environ.get("OPENSPACE_ENABLE_RECORDING", "true").lower() in ("true", "1", "yes") + + backend_scope_raw = os.environ.get("OPENSPACE_BACKEND_SCOPE") + backend_scope = ( + [b.strip() for b in backend_scope_raw.split(",") if b.strip()] + if backend_scope_raw else None + ) + + config_path = build_grounding_config_path() + model, llm_kwargs = build_llm_kwargs(env_model) + + _pkg_root = str(Path(__file__).resolve().parent.parent) + recording_base = workspace or _pkg_root + recording_log_dir = str(Path(recording_base) / "logs" / "recordings") + + config = OpenSpaceConfig( + llm_model=model, + llm_kwargs=llm_kwargs, + workspace_dir=workspace, + grounding_max_iterations=max_iter, + enable_recording=enable_rec, + recording_backends=["shell"] if enable_rec else None, # ["shell", "mcp", "web"] if enable_rec else None + recording_log_dir=recording_log_dir, + backend_scope=backend_scope, + grounding_config_path=config_path, + ) + + _openspace_instance = OpenSpace(config=config) + await _openspace_instance.initialize() + logger.info("OpenSpace engine ready (model=%s).", model) + + # Auto-register host bot skill directories from env (set once by human) + host_skill_dirs_raw = os.environ.get("OPENSPACE_HOST_SKILL_DIRS", "") + if host_skill_dirs_raw: + dirs = [d.strip() for d in host_skill_dirs_raw.split(",") if d.strip()] + if dirs: + await _auto_register_skill_dirs(dirs) + logger.info("Auto-registered host skill dirs from OPENSPACE_HOST_SKILL_DIRS: %s", dirs) + + return _openspace_instance + + +def _get_store(): + """Get SkillStore — reuses OpenSpace's internal instance when available.""" + global _standalone_store + if _openspace_instance and _openspace_instance.is_initialized(): + internal = getattr(_openspace_instance, "_skill_store", None) + if internal and not internal._closed: + return internal + if _standalone_store is None or _standalone_store._closed: + from openspace.skill_engine import SkillStore + _standalone_store = SkillStore() + return _standalone_store + + +def _get_cloud_client(): + """Get a OpenSpaceClient instance (raises CloudError if not configured).""" + from openspace.cloud.auth import get_openspace_auth + from openspace.cloud.client import OpenSpaceClient + auth_headers, api_base = get_openspace_auth() + return OpenSpaceClient(auth_headers, api_base) + + +def _write_upload_meta(skill_dir: Path, info: Dict[str, Any]) -> None: + """Write ``.upload_meta.json`` so ``upload_skill`` can read pre-saved metadata. + + Called after evolution (execute_task auto-evolve or fix_skill). + The bot then only needs to provide ``skill_dir`` + ``visibility`` + when uploading — everything else is pre-filled. + """ + meta = { + "origin": info.get("origin", "imported"), + "parent_skill_ids": info.get("parent_skill_ids", []), + "change_summary": info.get("change_summary", ""), + "created_by": info.get("created_by", "openspace"), + "tags": info.get("tags", []), + } + meta_path = skill_dir / _UPLOAD_META_FILENAME + try: + meta_path.write_text( + json.dumps(meta, ensure_ascii=False, indent=2) + "\n", + encoding="utf-8", + ) + logger.debug(f"Wrote upload metadata to {meta_path}") + except Exception as e: + logger.warning(f"Failed to write upload metadata: {e}") + + +def _read_upload_meta(skill_dir: Path) -> Dict[str, Any]: + """Read upload metadata with three-tier fallback. + + Resolution order: + 1. ``.upload_meta.json`` sidecar file (written right after evolution) + 2. SkillStore DB lookup by path (long-term persistence) + 3. Empty dict (caller applies defaults) + + This ensures metadata survives even if the sidecar file is deleted + or the user comes back to upload much later. + """ + # Tier 1: sidecar file + meta_path = skill_dir / _UPLOAD_META_FILENAME + if meta_path.exists(): + try: + data = json.loads(meta_path.read_text(encoding="utf-8")) + if data: + return data + except (json.JSONDecodeError, OSError) as e: + logger.warning(f"Failed to read upload metadata file: {e}") + + # Tier 2: DB lookup + try: + store = _get_store() + rec = store.load_record_by_path(str(skill_dir)) + if rec: + logger.debug(f"Upload metadata resolved from DB for {skill_dir}") + return { + "origin": rec.lineage.origin.value, + "parent_skill_ids": rec.lineage.parent_skill_ids, + "change_summary": rec.lineage.change_summary, + "created_by": rec.lineage.created_by or "", + "tags": rec.tags, + } + except Exception as e: + logger.debug(f"DB upload metadata lookup failed: {e}") + + return {} + + +async def _auto_register_skill_dirs(skill_dirs: List[str]) -> int: + """Register bot skill directories into OpenSpace's SkillRegistry + DB. + + Called automatically by ``execute_task`` on every invocation. Directories + are re-scanned each time so that skills created by the host bot since the last call are discovered immediately. + """ + global _registered_skill_dirs + + valid_dirs = [Path(d) for d in skill_dirs if Path(d).is_dir()] + if not valid_dirs: + return 0 + + openspace = await _get_openspace() + registry = openspace._skill_registry + if not registry: + logger.warning("_auto_register_skill_dirs: SkillRegistry not initialized") + return 0 + + added = registry.discover_from_dirs(valid_dirs) + + db_created = 0 + if added: + store = _get_store() + db_created = await store.sync_from_registry(added) + + is_first = any(d not in _registered_skill_dirs for d in skill_dirs) + for d in skill_dirs: + _registered_skill_dirs.add(d) + + if added: + action = "Auto-registered" if is_first else "Re-scanned & found" + logger.info( + f"{action} {len(added)} skill(s) from {len(valid_dirs)} dir(s), " + f"{db_created} new DB record(s)" + ) + return len(added) + + +async def _cloud_search_and_import(task: str, limit: int = 8) -> List[Dict[str, Any]]: + """Search cloud for skills relevant to *task* and auto-import top hits. + + This is **stage 1** of a two-stage pipeline: + Stage 1 (here): cloud BM25+embedding → pick top-N to import locally. + Stage 2 (tool_layer): local BM25 + LLM → select from ALL local skills + (including ones just imported) for injection. + + Stage 1 intentionally imports more than will be used (default: 8) so + that stage 2 has a larger pool to choose from. The two BM25 passes + are NOT redundant — stage 1 filters thousands of cloud candidates down + to a manageable import set; stage 2 makes the final task-specific choice. + """ + try: + from openspace.cloud.search import ( + SkillSearchEngine, build_cloud_candidates, + ) + from openspace.cloud.embedding import generate_embedding, resolve_embedding_api + + client = _get_cloud_client() + embedding_api_key, _ = resolve_embedding_api() + has_embedding = bool(embedding_api_key) + + items = await asyncio.to_thread( + client.fetch_metadata, include_embedding=has_embedding, limit=200, + ) + if not items: + return [] + + candidates = build_cloud_candidates(items) + if not candidates: + return [] + + query_embedding: Optional[List[float]] = None + if has_embedding: + query_embedding = await asyncio.to_thread( + generate_embedding, task, + ) + + engine = SkillSearchEngine() + results = engine.search(task, candidates, query_embedding=query_embedding, limit=limit * 2) + + cloud_hits = [ + r for r in results + if r.get("source") == "cloud" + and r.get("visibility", "public") == "public" + and r.get("skill_id") + ][:limit] + + import_results: List[Dict[str, Any]] = [] + for hit in cloud_hits: + try: + imp = await _do_import_cloud_skill(skill_id=hit["skill_id"]) + import_results.append({ + "skill_id": hit["skill_id"], + "name": hit.get("name", ""), + "import_status": imp.get("status", "error"), + "local_path": imp.get("local_path", ""), + }) + except Exception as e: + logger.warning(f"Cloud import failed for {hit['skill_id']}: {e}") + + if import_results: + logger.info(f"Cloud search imported {len(import_results)} skill(s)") + return import_results + + except Exception as e: + logger.warning(f"_cloud_search_and_import failed (non-fatal): {e}") + return [] + + +async def _do_import_cloud_skill(skill_id: str, target_dir: Optional[str] = None) -> Dict[str, Any]: + """Download a cloud skill and register it locally.""" + client = _get_cloud_client() + + if target_dir: + base_dir = Path(target_dir) + else: + host_ws = ( + os.environ.get("NANOBOT_WORKSPACE") + or os.environ.get("OPENCLAW_STATE_DIR") + ) + if host_ws: + base_dir = Path(host_ws) / "skills" + base_dir.mkdir(parents=True, exist_ok=True) + else: + openspace = await _get_openspace() + skill_cfg = openspace._grounding_config.skills if openspace._grounding_config else None + if skill_cfg and skill_cfg.skill_dirs: + base_dir = Path(skill_cfg.skill_dirs[0]) + else: + base_dir = Path(__file__).resolve().parent / "skills" + + result = await asyncio.to_thread(client.import_skill, skill_id, base_dir) + + skill_dir = Path(result.get("local_path", "")) + if skill_dir.exists(): + openspace = await _get_openspace() + registry = openspace._skill_registry + if registry: + meta = registry.register_skill_dir(skill_dir) + if meta: + store = _get_store() + await store.sync_from_registry([meta]) + result["registered"] = True + + result.setdefault("registered", False) + return result + + +def _format_task_result(result: Dict[str, Any]) -> Dict[str, Any]: + """Format an OpenSpace execution result for MCP transport.""" + tool_execs = result.get("tool_executions", []) + tool_summary = [ + { + "tool": te.get("tool_name", te.get("tool", "")), + "status": te.get("status", ""), + "error": te.get("error", "")[:200] if te.get("error") else None, + } + for te in tool_execs[:20] + ] + + output: Dict[str, Any] = { + "status": result.get("status", "unknown"), + "response": result.get("response", ""), + "execution_time": round(result.get("execution_time", 0), 2), + "iterations": result.get("iterations", 0), + "skills_used": result.get("skills_used", []), + "task_id": result.get("task_id", ""), + "tool_call_count": len(tool_execs), + "tool_summary": tool_summary, + } + if result.get("warning"): + output["warning"] = result["warning"] + + # Format evolved_skills with skill_dir and upload_ready flag + raw_evolved = result.get("evolved_skills", []) + if raw_evolved: + formatted_evolved = [] + for es in raw_evolved: + skill_path = es.get("path", "") + skill_dir = str(Path(skill_path).parent) if skill_path else "" + formatted_evolved.append({ + "skill_dir": skill_dir, + "name": es.get("name", ""), + "origin": es.get("origin", ""), + "change_summary": es.get("change_summary", ""), + "upload_ready": bool(skill_dir), + }) + output["evolved_skills"] = formatted_evolved + # Prompt the calling agent to upload evolved skills + names = [es["name"] for es in formatted_evolved if es.get("upload_ready")] + if names: + output["action_required"] = ( + f"OpenSpace auto-evolved {len(names)} skill(s): {', '.join(names)}. " + f"Follow the 'When to upload' rules in your delegate-task skill to " + f"decide visibility, then upload via upload_skill. " + f"Tell the user what you evolved and what you uploaded." + ) + + return output + + +def _json_ok(data: Any) -> str: + return json.dumps(data, ensure_ascii=False, indent=2) + + +def _json_error(error: Any, **extra) -> str: + return json.dumps({"error": str(error), **extra}, ensure_ascii=False) + + +# MCP Tools (4 tools) +@mcp.tool() +async def execute_task( + task: str, + workspace_dir: str | None = None, + max_iterations: int | None = None, + skill_dirs: list[str] | None = None, + search_scope: str = "all", +) -> str: + """Execute a task with OpenSpace's full grounding engine. + + OpenSpace will: + 1. Auto-register bot skills from skill_dirs (if provided) + 2. Search for relevant skills (scope controls local vs cloud+local) + 3. Attempt skill-guided execution → fallback to pure tools + 4. Auto-analyze → auto-evolve (FIX/DERIVED/CAPTURED) if needed + + If skills are auto-evolved, the response includes ``evolved_skills`` + with ``upload_ready: true``. Call ``upload_skill`` with just the + ``skill_dir`` + ``visibility`` to upload — metadata is pre-saved. + + Note: This call blocks until the task completes (may take minutes). + Set MCP client tool-call timeout ≥ 600 seconds. + + Args: + task: The task instruction (natural language). + workspace_dir: Working directory. Defaults to OPENSPACE_WORKSPACE env. + max_iterations: Max agent iterations (default: 20). + skill_dirs: Bot's skill directories to auto-register so OpenSpace + can select and track them. Directories are re-scanned + on every call to discover skills created since the last + invocation. + search_scope: Skill search scope before execution. + "all" (default) — local + cloud; falls back to local + if no API key is configured. + "local" — local SkillRegistry only (fast, no cloud). + """ + try: + openspace = await _get_openspace() + + # Re-scan host skill directories (from env) to pick up skills + # created by the host bot since the last call. + host_skill_dirs_raw = os.environ.get("OPENSPACE_HOST_SKILL_DIRS", "") + if host_skill_dirs_raw: + env_dirs = [d.strip() for d in host_skill_dirs_raw.split(",") if d.strip()] + if env_dirs: + await _auto_register_skill_dirs(env_dirs) + + # Auto-register bot skill directories (from call parameter) + if skill_dirs: + await _auto_register_skill_dirs(skill_dirs) + + # Cloud search + import (if requested) + imported_skills: List[Dict[str, Any]] = [] + if search_scope == "all": + imported_skills = await _cloud_search_and_import(task) + + # Execute + result = await openspace.execute( + task=task, + workspace_dir=workspace_dir, + max_iterations=max_iterations, + ) + + # Write .upload_meta.json for each evolved skill + for es in result.get("evolved_skills", []): + skill_path = es.get("path", "") + if skill_path: + _write_upload_meta(Path(skill_path).parent, es) + + formatted = _format_task_result(result) + if imported_skills: + formatted["imported_skills"] = imported_skills + return _json_ok(formatted) + + except Exception as e: + logger.error(f"execute_task failed: {e}", exc_info=True) + return _json_error(e, status="error", traceback=traceback.format_exc(limit=5)) + + +@mcp.tool() +async def search_skills( + query: str, + source: str = "all", + limit: int = 20, + auto_import: bool = True, +) -> str: + """Search skills across local registry and cloud community. + + Standalone search for browsing / discovery. Use this when the bot + wants to find available skills, then decide whether to handle the + task locally or delegate to ``execute_task``. + + **Scope difference from execute_task**: + - ``search_skills`` returns results to the bot for decision-making. + - ``execute_task``'s internal search feeds directly into execution + (the bot never sees the search results). + + Uses hybrid ranking: BM25 → embedding re-rank → lexical boost. + Embedding requires OPENAI_API_KEY; falls back to lexical-only without it. + + Args: + query: Search query text (natural language or keywords). + source: "all" (cloud + local), "local", or "cloud". Default: "all". + limit: Maximum results to return (default: 20). + auto_import: Auto-download top public cloud skills (default: True). + """ + try: + from openspace.cloud.search import hybrid_search_skills + + q = query.strip() + if not q: + return _json_ok({"results": [], "count": 0}) + + # Re-scan host skill directories so newly created skills are searchable. + local_skills = None + store = None + if source in ("all", "local"): + openspace = await _get_openspace() + + host_skill_dirs_raw = os.environ.get("OPENSPACE_HOST_SKILL_DIRS", "") + if host_skill_dirs_raw: + env_dirs = [d.strip() for d in host_skill_dirs_raw.split(",") if d.strip()] + if env_dirs: + await _auto_register_skill_dirs(env_dirs) + + registry = openspace._skill_registry + if registry: + local_skills = registry.list_skills() + store = _get_store() + + results = await hybrid_search_skills( + query=q, + local_skills=local_skills, + store=store, + source=source, + limit=limit, + ) + + _AUTO_IMPORT_MAX = 3 + import_summary: List[Dict[str, Any]] = [] + if auto_import: + cloud_results = [ + r for r in results + if r.get("source") == "cloud" + and r.get("visibility", "public") == "public" + and r.get("skill_id") + ][:_AUTO_IMPORT_MAX] + for cr in cloud_results: + try: + imp_result = await _do_import_cloud_skill(skill_id=cr["skill_id"]) + status = imp_result.get("status", "error") + import_summary.append({ + "skill_id": cr["skill_id"], + "name": cr.get("name", ""), + "import_status": status, + "local_path": imp_result.get("local_path", ""), + }) + if status in ("success", "already_exists"): + cr["auto_imported"] = True + cr["local_path"] = imp_result.get("local_path", "") + except Exception as imp_err: + logger.warning(f"auto_import failed for {cr['skill_id']}: {imp_err}") + import_summary.append({ + "skill_id": cr["skill_id"], + "import_status": "error", + "error": str(imp_err), + }) + + output: Dict[str, Any] = {"results": results, "count": len(results)} + if import_summary: + output["auto_import_summary"] = import_summary + return _json_ok(output) + + except Exception as e: + logger.error(f"search_skills failed: {e}", exc_info=True) + return _json_error(e) + + +@mcp.tool() +async def fix_skill( + skill_dir: str, + direction: str, +) -> str: + """Manually fix a broken skill. + + This is the **only** manual evolution entry point. DERIVED and + CAPTURED evolutions are triggered automatically by ``execute_task`` + (they need a task to run). Use ``fix_skill`` when: + + - A skill's instructions are wrong or outdated + - The bot knows exactly which skill is broken and what to fix + - Auto-evolution inside ``execute_task`` didn't catch the issue + + The skill does NOT need to be pre-registered in OpenSpace — + provide the skill directory path and OpenSpace will register it + automatically before fixing. + + After fixing, the new skill is saved locally and ``.upload_meta.json`` + is pre-written. Call ``upload_skill`` with just ``skill_dir`` + + ``visibility`` to upload. + + Args: + skill_dir: Path to the broken skill directory (must contain SKILL.md). + direction: What's broken and how to fix it. Be specific: + e.g. "The API endpoint changed from v1 to v2" or + "Add retry logic for HTTP 429 rate limit errors". + """ + try: + from openspace.skill_engine.types import EvolutionSuggestion, EvolutionType + from openspace.skill_engine.evolver import EvolutionContext, EvolutionTrigger + + if not direction: + return _json_error("direction is required — describe what to fix.") + + skill_path = Path(skill_dir) + skill_md = skill_path / "SKILL.md" + if not skill_md.exists(): + return _json_error(f"SKILL.md not found in {skill_dir}") + + openspace = await _get_openspace() + registry = openspace._skill_registry + if not registry: + return _json_error("SkillRegistry not initialized") + if not openspace._skill_evolver: + return _json_error("Skill evolution is not enabled") + + # Step 1: Register the skill (idempotent) + meta = registry.register_skill_dir(skill_path) + if not meta: + return _json_error(f"Failed to register skill from {skill_dir}") + + store = _get_store() + await store.sync_from_registry([meta]) + + # Step 2: Load record + content + rec = store.load_record(meta.skill_id) + if not rec: + return _json_error(f"Failed to load skill record for {meta.skill_id}") + + evolver = openspace._skill_evolver + content = evolver._load_skill_content(rec) + if not content: + return _json_error(f"Cannot load content for skill: {meta.skill_id}") + + # Step 3: Run FIX evolution + recent = store.load_analyses(skill_id=meta.skill_id, limit=5) + + ctx = EvolutionContext( + trigger=EvolutionTrigger.ANALYSIS, + suggestion=EvolutionSuggestion( + evolution_type=EvolutionType.FIX, + target_skill_ids=[meta.skill_id], + direction=direction, + ), + skill_records=[rec], + skill_contents=[content], + skill_dirs=[skill_path], + recent_analyses=recent, + available_tools=evolver._available_tools, + ) + + logger.info(f"fix_skill: {meta.skill_id} — {direction[:100]}") + new_record = await evolver.evolve(ctx) + + if not new_record: + return _json_ok({ + "status": "failed", + "error": "Evolution did not produce a new skill.", + }) + + # Step 4: Write .upload_meta.json + new_skill_dir = Path(new_record.path).parent if new_record.path else skill_path + _write_upload_meta(new_skill_dir, { + "origin": new_record.lineage.origin.value, + "parent_skill_ids": new_record.lineage.parent_skill_ids, + "change_summary": new_record.lineage.change_summary, + "created_by": new_record.lineage.created_by or "openspace", + "tags": new_record.tags, + }) + + return _json_ok({ + "status": "success", + "new_skill": { + "skill_dir": str(new_skill_dir), + "name": new_record.name, + "origin": new_record.lineage.origin.value, + "change_summary": new_record.lineage.change_summary, + "upload_ready": True, + }, + }) + + except Exception as e: + logger.error(f"fix_skill failed: {e}", exc_info=True) + return _json_error(e, status="error", traceback=traceback.format_exc(limit=5)) + + +@mcp.tool() +async def upload_skill( + skill_dir: str, + visibility: str = "public", + origin: str | None = None, + parent_skill_ids: list[str] | None = None, + tags: list[str] | None = None, + created_by: str | None = None, + change_summary: str | None = None, +) -> str: + """Upload a local skill to the cloud. + + For evolved skills (from ``execute_task`` or ``fix_skill``), most + metadata is **pre-saved** in ``.upload_meta.json``. The bot only + needs to provide: + + - ``skill_dir`` — path to the skill directory + - ``visibility`` — "public" or "private" + + All other parameters are optional overrides. If omitted, pre-saved + values are used. If no pre-saved values exist, sensible defaults + are applied. + + **origin + parent_skill_ids constraints** (enforced by cloud): + - imported / captured → parent_skill_ids must be empty + - derived → at least 1 parent + - fixed → exactly 1 parent + + Args: + skill_dir: Path to skill directory (must contain SKILL.md). + visibility: "public" or "private". This is the one thing the + bot MUST decide. + origin: Override origin. Default: from .upload_meta.json or "imported". + parent_skill_ids: Override parents. Default: from .upload_meta.json. + tags: Override tags. Default: from .upload_meta.json. + created_by: Override creator. Default: from .upload_meta.json. + change_summary: Override summary. Default: from .upload_meta.json. + """ + try: + skill_path = Path(skill_dir) + if not (skill_path / "SKILL.md").exists(): + return _json_error(f"SKILL.md not found in {skill_dir}") + + # Read pre-saved metadata (written by execute_task/fix_skill) + meta = _read_upload_meta(skill_path) + + # Merge: explicit params override pre-saved values + final_origin = origin if origin is not None else meta.get("origin", "imported") + final_parents = parent_skill_ids if parent_skill_ids is not None else meta.get("parent_skill_ids", []) + final_tags = tags if tags is not None else meta.get("tags", []) + final_created_by = created_by if created_by is not None else meta.get("created_by", "") + final_change_summary = change_summary if change_summary is not None else meta.get("change_summary", "") + + client = _get_cloud_client() + result = await asyncio.to_thread( + client.upload_skill, + skill_path, + visibility=visibility, + origin=final_origin, + parent_skill_ids=final_parents, + tags=final_tags, + created_by=final_created_by, + change_summary=final_change_summary, + ) + return _json_ok(result) + + except Exception as e: + logger.error(f"upload_skill failed: {e}", exc_info=True) + return _json_error(e, status="error", traceback=traceback.format_exc(limit=5)) + +def run_mcp_server() -> None: + """Console-script entry point for ``openspace-mcp``.""" + import argparse + + parser = argparse.ArgumentParser(description="OpenSpace MCP Server") + parser.add_argument("--transport", choices=["stdio", "sse"], default="stdio") + parser.add_argument("--port", type=int, default=8080) + args = parser.parse_args() + + if args.transport == "sse": + mcp.run(transport="sse", sse_params={"port": args.port}) + else: + mcp.run(transport="stdio") + + +if __name__ == "__main__": + run_mcp_server() diff --git a/openspace/platforms/__init__.py b/openspace/platforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85952b9c272ff0f83c35b27ed8906bc1a6e0e876 --- /dev/null +++ b/openspace/platforms/__init__.py @@ -0,0 +1,23 @@ +from .system_info import SystemInfoClient, get_system_info, get_screen_size +from .recording import RecordingClient, RecordingContextManager +from .screenshot import ScreenshotClient, AutoScreenshotWrapper +from .config import get_local_server_config, get_client_base_url + +__all__ = [ + # System Info + "SystemInfoClient", + "get_system_info", + "get_screen_size", + + # Recording + "RecordingClient", + "RecordingContextManager", + + # Screenshot + "ScreenshotClient", + "AutoScreenshotWrapper", + + # Config + "get_local_server_config", + "get_client_base_url", +] \ No newline at end of file diff --git a/openspace/platforms/config.py b/openspace/platforms/config.py new file mode 100644 index 0000000000000000000000000000000000000000..892853e059ef601750834c9bf995844680290306 --- /dev/null +++ b/openspace/platforms/config.py @@ -0,0 +1,91 @@ +import os +import json +from typing import Dict, Any +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +def get_local_server_config() -> Dict[str, Any]: + """ + Read local server configuration. + + Priority: + 1. Environment variable LOCAL_SERVER_URL (parsed into host/port) + 2. Config file local_server/config.json + 3. Defaults (127.0.0.1:5000) + + Returns: + Dict with 'host' and 'port' from server config + """ + # Check environment variable first (for OSWorld/remote VM integration) + env_url = os.getenv("LOCAL_SERVER_URL") + if env_url: + try: + # Parse URL like "http://localhost:5000" + from urllib.parse import urlparse + parsed = urlparse(env_url) + host = parsed.hostname or '127.0.0.1' + port = parsed.port or 5000 + logger.debug(f"Using LOCAL_SERVER_URL: {host}:{port}") + return { + 'host': host, + 'port': port, + 'debug': False, + } + except Exception as e: + logger.warning(f"Failed to parse LOCAL_SERVER_URL: {e}") + + # Find local_server config file + try: + # Try relative path from this file + current_dir = os.path.dirname(__file__) + config_path = os.path.join(current_dir, '../local_server/config.json') + config_path = os.path.abspath(config_path) + + if os.path.exists(config_path): + with open(config_path, 'r') as f: + config = json.load(f) + server_config = config.get('server', {}) + return { + 'host': server_config.get('host', '127.0.0.1'), + 'port': server_config.get('port', 5000), + 'debug': server_config.get('debug', False), + } + except Exception as e: + logger.debug(f"Failed to read local server config: {e}") + + # Return defaults + return { + 'host': '127.0.0.1', + 'port': 5000, + 'debug': False, + } + + +def get_client_base_url() -> str: + """ + Get base URL for connecting to local server. + + Priority: + 1. Environment variable LOCAL_SERVER_URL + 2. Read from local_server/config.json + 3. Default http://localhost:5000 + + Returns: + Base URL string + """ + # Check environment variable first + env_url = os.getenv("LOCAL_SERVER_URL") + if env_url: + return env_url + + # Read from config file + config = get_local_server_config() + host = config['host'] + port = config['port'] + + # Convert 0.0.0.0 to localhost for client + if host == '0.0.0.0': + host = 'localhost' + + return f"http://{host}:{port}" \ No newline at end of file diff --git a/openspace/platforms/recording.py b/openspace/platforms/recording.py new file mode 100644 index 0000000000000000000000000000000000000000..0701c230154f9bc78989dc83286a6ed5b6b3bc63 --- /dev/null +++ b/openspace/platforms/recording.py @@ -0,0 +1,193 @@ +import aiohttp +from typing import Optional +from openspace.utils.logging import Logger +from .config import get_client_base_url + +logger = Logger.get_logger(__name__) + + +class RecordingClient: + """ + Client for screen recording via HTTP API. + + This client directly calls the local server's recording endpoints: + - POST /start_recording + - POST /end_recording + """ + + def __init__( + self, + base_url: Optional[str] = None, + timeout: int = 30 + ): + """ + Initialize recording client. + + Args: + base_url: Base URL of the local server + (default: read from local_server/config.json or env LOCAL_SERVER_URL) + timeout: Request timeout in seconds + """ + # Get base_url: priority is explicit > env > config file + if base_url is None: + base_url = get_client_base_url() + + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_session(self) -> aiohttp.ClientSession: + """Get or create aiohttp session.""" + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.timeout) + ) + return self._session + + async def start_recording(self, auto_cleanup: bool = True) -> bool: + """ + Start screen recording. + + Args: + auto_cleanup: If True, automatically end previous recording if one is in progress + """ + try: + session = await self._get_session() + url = f"{self.base_url}/start_recording" + + async with session.post(url) as response: + if response.status == 200: + logger.info("Screen recording started") + return True + elif response.status == 400 and auto_cleanup: + # Check if error is due to recording already in progress + error_text = await response.text() + if "already in progress" in error_text.lower(): + logger.warning("Recording already in progress, stopping previous recording...") + + # Try to end the previous recording + video_bytes = await self.end_recording() + if video_bytes: + logger.info("Previous recording ended successfully, retrying start...") + else: + logger.warning("Failed to end previous recording, but will retry start anyway...") + + # Retry starting recording (without auto_cleanup to avoid infinite loop) + return await self.start_recording(auto_cleanup=False) + else: + logger.error(f"Failed to start recording: HTTP {response.status} - {error_text}") + return False + else: + error_text = await response.text() + logger.error(f"Failed to start recording: HTTP {response.status} - {error_text}") + return False + + except Exception as e: + logger.error(f"Failed to start recording: {e}") + return False + + async def end_recording(self, dest: Optional[str] = None) -> Optional[bytes]: + """ + End screen recording and optionally save to file. + """ + try: + session = await self._get_session() + url = f"{self.base_url}/end_recording" + + # Use longer timeout for end_recording (file may be large) + async with session.post(url, timeout=aiohttp.ClientTimeout(total=60)) as response: + if response.status == 200: + video_bytes = await response.read() + + # Save to file if destination provided + if dest: + try: + with open(dest, "wb") as f: + f.write(video_bytes) + logger.info(f"Recording saved to: {dest}") + except Exception as e: + logger.error(f"Failed to save recording file: {e}") + return None + + logger.info("Screen recording ended") + return video_bytes + else: + error_text = await response.text() + logger.error(f"Failed to end recording: HTTP {response.status} - {error_text}") + return None + + except Exception as e: + logger.error(f"Failed to end recording: {e}") + return None + + async def close(self): + """Close the HTTP session.""" + if self._session and not self._session.closed: + await self._session.close() + # Give aiohttp time to finish cleanup callbacks + import asyncio + await asyncio.sleep(0.25) + logger.debug("Recording client session closed") + + async def __aenter__(self): + """Context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + await self.close() + return False + + +class RecordingContextManager: + + def __init__( + self, + base_url: Optional[str] = None, + output_path: Optional[str] = None, + timeout: Optional[int] = None + ): + """ + Initialize recording context manager. + + Args: + base_url: Base URL of the local server (default: from config) + output_path: Path to save recording (default: from config) + timeout: Request timeout in seconds (default: from config) + """ + # Load output_path from config if not provided + if output_path is None: + try: + from openspace.config import get_config + config = get_config() + if config.recording.screen_recording_path: + output_path = config.recording.screen_recording_path + except Exception: + pass + + self.client = RecordingClient(base_url=base_url, timeout=timeout) + self.output_path = output_path + self.recording_started = False + + async def __aenter__(self) -> RecordingClient: + """Start recording on context entry.""" + success = await self.client.start_recording() + if success: + self.recording_started = True + logger.info("Recording context started") + else: + logger.warning("Failed to start recording in context") + + return self.client + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Stop recording on context exit.""" + if self.recording_started: + try: + await self.client.end_recording(dest=self.output_path) + logger.info("Recording context ended") + except Exception as e: + logger.error(f"Failed to end recording in context: {e}") + + await self.client.close() + return False \ No newline at end of file diff --git a/openspace/platforms/screenshot.py b/openspace/platforms/screenshot.py new file mode 100644 index 0000000000000000000000000000000000000000..54d3f723cece23acea0190df9eb2eb5d59738b10 --- /dev/null +++ b/openspace/platforms/screenshot.py @@ -0,0 +1,263 @@ +""" +Screenshot client for capturing screens via HTTP API. + +This module provides a screenshot client that captures screenshots by calling +the local_server's /screenshot endpoint. + +Always uses HTTP API (like RecordingClient): +- Local: http://127.0.0.1:5000/screenshot +- Remote: http://remote-vm:5000/screenshot +""" +import aiohttp +from typing import Optional +from openspace.utils.logging import Logger +from .config import get_client_base_url + +logger = Logger.get_logger(__name__) + + +class ScreenshotClient: + + def __init__( + self, + base_url: Optional[str] = None, + timeout: int = 10 + ): + """ + Initialize screenshot client. + + Args: + base_url: Base URL of local_server + (default: read from config/env, typically http://127.0.0.1:5000) + timeout: Request timeout (seconds) + """ + # Get base_url from config if not provided + if base_url is None: + base_url = get_client_base_url() + + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self._session = None + + logger.debug(f"ScreenshotClient initialized: {self.base_url}") + + async def _get_session(self) -> aiohttp.ClientSession: + """Get or create aiohttp session.""" + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.timeout) + ) + return self._session + + @staticmethod + def _is_valid_image_response(content_type: str, data: Optional[bytes]) -> bool: + """ + Validate image response using magic bytes. + + Args: + content_type: HTTP Content-Type header + data: Response data bytes + + Returns: + True if data is valid PNG/JPEG image + """ + if not isinstance(data, (bytes, bytearray)) or not data: + return False + + # PNG magic bytes: \x89PNG\r\n\x1a\n + if len(data) >= 8 and data[:8] == b"\x89PNG\r\n\x1a\n": + return True + + # JPEG magic bytes: \xff\xd8\xff + if len(data) >= 3 and data[:3] == b"\xff\xd8\xff": + return True + + # Fallback to content-type check + if content_type and ("image/png" in content_type or "image/jpeg" in content_type): + return True + + return False + + async def capture(self) -> Optional[bytes]: + """ + Capture screenshot via HTTP API. + + Calls: GET {base_url}/screenshot + + Returns: + PNG image bytes, or None on failure + """ + try: + session = await self._get_session() + url = f"{self.base_url}/screenshot" + + logger.debug(f"Requesting screenshot: {url}") + + async with session.get(url) as response: + if response.status == 200: + content_type = response.headers.get("Content-Type", "") + screenshot_bytes = await response.read() + + # Validate image format + if self._is_valid_image_response(content_type, screenshot_bytes): + logger.debug(f"Screenshot captured: {len(screenshot_bytes)} bytes") + return screenshot_bytes + else: + logger.error("Invalid screenshot format received") + return None + else: + error_text = await response.text() + logger.error(f"Failed to capture screenshot: HTTP {response.status} - {error_text}") + return None + + except Exception as e: + logger.error(f"Failed to capture screenshot: {e}") + return None + + async def capture_to_file(self, output_path: str) -> bool: + try: + screenshot = await self.capture() + if screenshot: + import os + os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True) + with open(output_path, 'wb') as f: + f.write(screenshot) + logger.info(f"Screenshot saved to: {output_path}") + return True + return False + except Exception as e: + logger.error(f"Failed to save screenshot to file: {e}") + return False + + async def get_screen_size(self) -> tuple[int, int]: + """ + Get screen size via HTTP API. + + Calls: GET {base_url}/screen_size + + Returns: + (width, height) + """ + try: + session = await self._get_session() + url = f"{self.base_url}/screen_size" + + async with session.get(url) as response: + if response.status == 200: + data = await response.json() + width = data.get('width', 1920) + height = data.get('height', 1080) + logger.debug(f"Screen size: {width}x{height}") + return (width, height) + else: + logger.warning("Failed to get screen size, using default") + return (1920, 1080) + + except Exception as e: + logger.error(f"Failed to get screen size: {e}") + return (1920, 1080) + + async def close(self): + """Close HTTP session.""" + if self._session and not self._session.closed: + await self._session.close() + logger.debug("Screenshot client session closed") + + async def __aenter__(self): + """Context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + await self.close() + return False + + +class AutoScreenshotWrapper: + """ + Wrapper that automatically captures screenshots after backend calls. + + This wrapper can be used to wrap any backend tool/session and automatically + capture screenshots after each operation. + + Usage: + # Wrap a backend tool + wrapped_tool = AutoScreenshotWrapper( + tool=gui_tool, + screenshot_client=screenshot_client, + on_screenshot=lambda screenshot: recorder.record_step(...) + ) + + # Use wrapped tool normally + result = await wrapped_tool.execute(...) + # Screenshot is automatically captured and handled + """ + + def __init__( + self, + tool, + screenshot_client: Optional[ScreenshotClient] = None, + on_screenshot=None, + enabled: bool = True + ): + """ + Initialize auto-screenshot wrapper. + + Args: + tool: The tool/session to wrap + screenshot_client: Screenshot client to use (created if None) + on_screenshot: Callback function(screenshot_bytes) called after each screenshot + enabled: Whether auto-screenshot is enabled + """ + self._tool = tool + self._screenshot_client = screenshot_client or ScreenshotClient() + self._on_screenshot = on_screenshot + self._enabled = enabled + + def __getattr__(self, name): + """Delegate attribute access to wrapped tool.""" + return getattr(self._tool, name) + + async def _capture_and_notify(self): + """Capture screenshot and notify callback.""" + if not self._enabled: + return + + try: + screenshot = await self._screenshot_client.capture() + if screenshot and self._on_screenshot: + await self._on_screenshot(screenshot) + except Exception as e: + logger.warning(f"Failed to auto-capture screenshot: {e}") + + async def execute(self, *args, **kwargs): + """ + Execute tool and auto-capture screenshot. + """ + # Execute original method + result = await self._tool.execute(*args, **kwargs) + + # Capture screenshot after execution + await self._capture_and_notify() + + return result + + async def _arun(self, *args, **kwargs): + """ + Run tool and auto-capture screenshot. + """ + # Execute original method + result = await self._tool._arun(*args, **kwargs) + + # Capture screenshot after execution + await self._capture_and_notify() + + return result + + def enable(self): + """Enable auto-screenshot.""" + self._enabled = True + + def disable(self): + """Disable auto-screenshot.""" + self._enabled = False \ No newline at end of file diff --git a/openspace/platforms/system_info.py b/openspace/platforms/system_info.py new file mode 100644 index 0000000000000000000000000000000000000000..6c480bcfda05306b67b2f935bf91ee2f71fd2433 --- /dev/null +++ b/openspace/platforms/system_info.py @@ -0,0 +1,172 @@ +import aiohttp +from typing import Optional, Dict, Any +from openspace.utils.logging import Logger +from .config import get_client_base_url + +logger = Logger.get_logger(__name__) + + +class SystemInfoClient: + """ + This client provides simple methods to get: + - Platform info (OS, architecture, version, etc.) + - Screen size + - Cursor position + """ + + def __init__( + self, + base_url: Optional[str] = None, + timeout: int = 10 + ): + """ + Initialize system info client. + + Args: + base_url: Base URL of the local server + (default: read from local_server/config.json or env LOCAL_SERVER_URL) + timeout: Request timeout in seconds + """ + # Get base_url: priority is explicit > env > config file + if base_url is None: + base_url = get_client_base_url() + + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self._session: Optional[aiohttp.ClientSession] = None + self._cached_info: Optional[Dict[str, Any]] = None + + async def _get_session(self) -> aiohttp.ClientSession: + """Get or create aiohttp session.""" + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.timeout) + ) + return self._session + + async def get_system_info(self, use_cache: bool = True) -> Optional[Dict[str, Any]]: + """ + Get comprehensive system information. + + Returns information including: + - system: OS name (Linux, Darwin, Windows) + - release: OS release version + - version: Detailed version string + - machine: Architecture (x86_64, arm64, etc.) + - processor: Processor type + - Additional platform-specific info + + Args: + use_cache: Whether to use cached info (default: True) + """ + # Check cache + if use_cache and self._cached_info: + logger.debug("Using cached system info") + return self._cached_info + + try: + session = await self._get_session() + url = f"{self.base_url}/platform" + + async with session.get(url) as response: + if response.status == 200: + info = await response.json() + + # Cache the result + if use_cache: + self._cached_info = info + + logger.debug(f"System info retrieved: {info.get('system')}") + return info + else: + error_text = await response.text() + logger.error(f"Failed to get system info: HTTP {response.status} - {error_text}") + return None + + except Exception as e: + logger.error(f"Failed to get system info: {e}") + return None + + async def get_screen_size(self) -> Optional[Dict[str, int]]: + """ + Get screen size. + + Returns: + Dict with 'width' and 'height', or None on failure + """ + try: + session = await self._get_session() + url = f"{self.base_url}/screen_size" + + async with session.get(url) as response: + if response.status == 200: + size = await response.json() + logger.debug(f"Screen size: {size.get('width')}x{size.get('height')}") + return { + "width": size.get("width"), + "height": size.get("height") + } + else: + error_text = await response.text() + logger.error(f"Failed to get screen size: HTTP {response.status} - {error_text}") + return None + + except Exception as e: + logger.error(f"Failed to get screen size: {e}") + return None + + async def get_cursor_position(self) -> Optional[Dict[str, int]]: + """ + Get current cursor position. + + Returns: + Dict with 'x' and 'y', or None on failure + """ + try: + session = await self._get_session() + url = f"{self.base_url}/cursor_position" + + async with session.get(url) as response: + if response.status == 200: + pos = await response.json() + return { + "x": pos.get("x"), + "y": pos.get("y") + } + else: + error_text = await response.text() + logger.error(f"Failed to get cursor position: HTTP {response.status} - {error_text}") + return None + + except Exception as e: + logger.error(f"Failed to get cursor position: {e}") + return None + + def clear_cache(self): + """Clear cached system information.""" + self._cached_info = None + logger.debug("System info cache cleared") + + async def close(self): + """Close the HTTP session.""" + if self._session and not self._session.closed: + await self._session.close() + logger.debug("System info client session closed") + + async def __aenter__(self): + """Context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + await self.close() + return False + +async def get_system_info(base_url: Optional[str] = None) -> Optional[Dict[str, Any]]: + async with SystemInfoClient(base_url=base_url) as client: + return await client.get_system_info(use_cache=False) + + +async def get_screen_size(base_url: Optional[str] = None) -> Optional[Dict[str, int]]: + async with SystemInfoClient(base_url=base_url) as client: + return await client.get_screen_size() \ No newline at end of file diff --git a/openspace/prompts/__init__.py b/openspace/prompts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36d80ee64b7de1722f8bd6e641baa7896cdbf2b1 --- /dev/null +++ b/openspace/prompts/__init__.py @@ -0,0 +1,4 @@ +from openspace.prompts.grounding_agent_prompts import GroundingAgentPrompts +from openspace.prompts.skill_engine_prompts import SkillEnginePrompts + +__all__ = ["GroundingAgentPrompts", "SkillEnginePrompts"] \ No newline at end of file diff --git a/openspace/prompts/grounding_agent_prompts.py b/openspace/prompts/grounding_agent_prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..3a68d0dfae9317fa188f475848feb788ef830438 --- /dev/null +++ b/openspace/prompts/grounding_agent_prompts.py @@ -0,0 +1,296 @@ +from typing import List, Optional, Set + + +class GroundingAgentPrompts: + + TASK_COMPLETE = "" + + @classmethod + def build_system_prompt(cls, backends: Optional[List[str]] = None) -> str: + """Build a system prompt tailored to the actually registered backends. + + Args: + backends: Active backend names (e.g. ``["shell", "mcp", "gui"]``). + ``None`` falls back to all backends for backward compatibility. + """ + scope: Set[str] = set(backends) if backends else {"gui", "shell", "mcp", "web", "system"} + + sections: List[str] = [] + + # Core + sections.append( + "You are a Grounding Agent. Execute tasks using tools.\n\n" + "# Tool Execution\n\n" + "- Select appropriate tools from descriptions and schemas\n" + "- Provide correct parameters\n" + "- Call multiple tools if needed\n" + "- Tools execute immediately, results appear in next iteration\n" + "- If you need results to decide next action, wait for next iteration" + ) + + # Tool Selection Tips (only mention backends that exist) + tips: List[str] = [] + has_mcp = "mcp" in scope + has_shell = "shell" in scope + has_gui = "gui" in scope + + if has_mcp and has_shell: + tips.append("- **MCP tools** and **Shell tools** are typically faster and more accurate when applicable") + elif has_mcp: + tips.append("- **MCP tools** are typically faster and more accurate when applicable") + elif has_shell: + tips.append("- **Shell tools** are fast and accurate for command-line and scripting tasks") + + if has_gui: + if has_mcp or has_shell: + tips.append("- **GUI tools** offer finer-grained control and can handle tasks not covered by other tools") + else: + tips.append("- **GUI tools** provide direct interaction with graphical interfaces") + tips.append("- Choose based on the task requirements and tool availability") + + if tips: + sections.append("# Tool Selection Tips\n\n" + "\n".join(tips)) + + # Visual Analysis Control + if has_gui: + sections.append( + "# Visual Analysis Control\n\n" + "GUI tools auto-analyze screenshots to extract information.\n\n" + "To skip analysis when NOT needed, add parameter:\n" + "```json\n" + '{"task_description": "...", "skip_visual_analysis": true}\n' + "```\n\n" + "**Decision Rule:**\n" + "- Task goal is OPERATIONAL (open/navigate/click/show): Skip analysis\n" + "- Task goal requires KNOWLEDGE EXTRACTION (read/extract/save data): Keep analysis\n\n" + "**Examples:**\n" + '- "Open settings page": Operational only, skip analysis\n' + '- "Open settings and record all values": Needs knowledge, keep analysis\n' + '- "Navigate to GitHub homepage": Operational only, skip analysis\n' + '- "Search Python tutorials and save top 5 titles": Needs knowledge, keep analysis\n\n' + "**Key principle:** If you need to extract information FROM the screen for " + "subsequent steps or user reporting, keep analysis (don't skip).\n" + "**Note:** Only GUI tools support this parameter. Other backend tools ignore it." + ) + + # Mid-iteration skill retrieval hint + sections.append( + "# Skill Retrieval\n\n" + "If the current approach is failing or the task requires domain-specific " + "knowledge you don't have, call `retrieve_skill` with a short description " + "of what guidance you need. It returns verified procedures when available." + ) + + # Task Completion (always present) + sections.append( + "# Task Completion\n\n" + "After each iteration, evaluate if the task is complete:\n\n" + "**If task is COMPLETE:**\n" + "- Write a response summarizing what was accomplished\n" + f"- Include the completion token `{cls.TASK_COMPLETE}` on a new line at the end of your response\n" + "- Example response format:\n" + " ```\n" + " I have successfully completed the task. The file has been created at /path/to/file.txt with the requested content.\n" + f"\n {cls.TASK_COMPLETE}\n" + " ```\n\n" + "**If task is NOT complete:**\n" + "- Continue by calling the appropriate tools\n" + f"- Do NOT output `{cls.TASK_COMPLETE}`\n" + "- Tool results will appear in the next iteration\n\n" + f"The token `{cls.TASK_COMPLETE}` signals that no further iterations are needed." + ) + + return "\n\n".join(sections) + + @staticmethod + def iteration_summary( + instruction: str, + iteration: int, + max_iterations: int + ) -> str: + """ + Build iteration summary prompt for LLMClient auto-summary. + LLM extracts information directly from tool results in conversation history. + """ + return f"""Based on the original task and the tool execution results in the conversation above, generate a structured iteration summary. + +**Original Task:** +{instruction} + +**Progress:** Iteration {iteration} of {max_iterations} + +**Generate Summary in This Format:** + +## Iteration {iteration} Progress + +Actions taken: + +Knowledge obtained (COMPLETE and SPECIFIC): +- File locations: +- Visual content: +- Data retrieved: +- URLs/Links: +- System state: + +Errors encountered: + +CRITICAL GUIDELINES: +- This summary is for preserving knowledge for subsequent iterations +- Extract ALL concrete information from tool outputs in the conversation above +- Filenames, paths, URLs - use exact values from tool outputs +- Visual content - extract actual text/data visible, not just "saw something" +- Search results - include specific data, not vague descriptions +- The next iteration cannot see current tool outputs - this summary is the ONLY source of knowledge""" + + @staticmethod + def visual_analysis( + tool_name: str, + num_screenshots: int, + task_description: str = "" + ) -> str: + """ + Build prompt for visual analysis of screenshots. + + Args: + tool_name: Tool name that generated the screenshots + num_screenshots: Number of screenshots + task_description: Original task description for context + """ + screenshot_text = "screenshot" if num_screenshots == 1 else f"{num_screenshots} screenshots" + these_text = "this screenshot" if num_screenshots == 1 else "these screenshots" + + task_context = f""" +**Original Task**: {task_description} + +Focus on extracting information RELEVANT to this task. Prioritize content that helps accomplish the goal. +""" if task_description else "" + + return f"""Extract the KNOWLEDGE and INFORMATION from {these_text}. This will be passed to the next iteration so it can continue working with the information (search, analyze, save, etc.). Without this extraction, the visual content would only be viewable by humans and unusable for subsequent operations. +{task_context} +**EXTRACT all visible knowledge content** (prioritize task-relevant information): +1. **Text content**: Articles, documentation, code, messages, descriptions - extract the actual text +2. **Data points**: Numbers, statistics, measurements, values, percentages - be specific +3. **List items**: Names, titles, entries in lists/search results/files - list them out +4. **Structured data**: Information from tables, charts, forms - describe what they contain +5. **Key information**: URLs, paths, names, IDs, dates, labels - anything useful for next steps + +**IGNORE interface elements**: +- Buttons, menus, toolbars, navigation bars +- UI design, layout, colors, styling +- Non-informational visual elements + +**Goal**: Extract usable knowledge that enables the next agent to work with this information programmatically. Be SPECIFIC and COMPLETE, but FOCUS on what's relevant to the task. + +{screenshot_text.capitalize()} from tool '{tool_name}'""" + + @staticmethod + def final_summary( + instruction: str, + iterations: int + ) -> str: + """ + Build prompt for generating final summary across all iterations. + """ + return f"""Based on the complete conversation history above (including all {iterations} iteration summaries and tool executions), generate a comprehensive final summary. + +## Final Task Summary + +Task: {instruction} + +What was accomplished: + +Key information obtained: +- Files: +- Data: +- Findings: + +Issues encountered: + +Result: <"Success" or "Incomplete"> + +Guidelines: +- Consolidate information from ALL iteration summaries +- Include concrete deliverables (file paths, data, etc.) +- Be comprehensive but concise +- Focus on what the user cares about""" + + @staticmethod + def workspace_directory(workspace_dir: str) -> str: + """ + Build workspace directory information for cross-iteration/cross-backend data sharing. + """ + import os + # Check if this is a benchmark scenario: + # 1. LiveMCPBench /root mapping + # 2. Workspace already contains files (e.g. GDPVal reference files) + is_benchmark = "/root" in workspace_dir or "LiveMCPBench/root" in workspace_dir + if not is_benchmark: + try: + has_existing_files = os.path.isdir(workspace_dir) and bool(os.listdir(workspace_dir)) + except OSError: + has_existing_files = False + is_benchmark = has_existing_files + + if is_benchmark: + # Benchmark / task mode: task files are in workspace directory + return f"""**Working Directory**: `{workspace_dir}` +- All task files (input/output) are located in this directory +- Read from and write to this directory for all file operations""" + else: + # Normal mode: workspace is for intermediate results + return f"""**Working Directory**: `{workspace_dir}` +- Persist intermediate results here; later iterations/backends can read what you saved earlier +- Note: User's personal files are NOT here - search in ~/Desktop, ~/Documents, ~/Downloads, etc.""" + + @staticmethod + def workspace_matching_files(matching_files: List[str]) -> str: + """ + Build alert for files matching task requirements. + """ + files_str = ', '.join([f"`{f}`" for f in matching_files]) + return f"""**Workspace Alert**: Files matching task requirements found: {files_str} +- Read these files to verify if they satisfy the task +- If satisfied, mark task as completed +- If not satisfied, modify or recreate as needed""" + + @staticmethod + def workspace_recent_files(total_files: int, recent_files: List[str]) -> str: + """ + Build info for recently modified files. + """ + recent_list = ', '.join([f"`{f}`" for f in recent_files[:15]]) + return f"""**Workspace Info**: {total_files} files exist, {len(recent_files)} recently modified +Recent files: {recent_list} +Consider checking recent files before creating new ones""" + + @staticmethod + def workspace_file_list(files: List[str]) -> str: + """ + Build list of all existing files. + """ + files_list = ', '.join([f"`{f}`" for f in files[:15]]) + if len(files) > 15: + files_list += f" (and {len(files) - 15} more)" + return f"**Workspace Info**: {len(files)} existing file(s): {files_list}" + + @staticmethod + def iteration_feedback( + iteration: int, + llm_summary: str, + add_guidance: bool = True + ) -> str: + """ + Build feedback message to pass iteration summary to next iteration. + """ + content = f"""## Iteration {iteration} Summary + +{llm_summary}""" + + if add_guidance: + content += f""" +--- +Now continue with iteration {iteration + 1}. You can see the full conversation history above. Based on all progress so far, decide whether to: +- Call more tools if the task is not yet complete +- Output {GroundingAgentPrompts.TASK_COMPLETE} if the task is fully accomplished""" + + return content \ No newline at end of file diff --git a/openspace/prompts/skill_engine_prompts.py b/openspace/prompts/skill_engine_prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..466744c6e3b31265f65e74c7b5ca57d5b84c32c9 --- /dev/null +++ b/openspace/prompts/skill_engine_prompts.py @@ -0,0 +1,803 @@ +"""Prompts for the skill engine subsystem.""" + +class SkillEnginePrompts: + """Central registry of prompts used by the skill engine.""" + + # Evolution self-assessment tokens + EVOLUTION_COMPLETE = "" + EVOLUTION_FAILED = "" + + @staticmethod + def evolution_fix( + *, + current_content: str, + direction: str, + failure_context: str, + tool_issue_summary: str = "", + metric_summary: str = "", + ) -> str: + """Build the prompt for a FIX evolution (in-place repair). + + Args: + current_content: Current SKILL.md content. + direction: What to fix and why (from suggestion or diagnosis). + failure_context: Formatted recent analysis context showing failures. + tool_issue_summary: Optional — tool degradation details (trigger 2). + metric_summary: Optional — skill health metrics (trigger 3). + """ + return _EVOLUTION_FIX_TEMPLATE.format( + current_content=current_content, + direction=direction, + failure_context=failure_context, + tool_issue_summary=tool_issue_summary or "(none)", + metric_summary=metric_summary or "(none)", + evolution_complete=SkillEnginePrompts.EVOLUTION_COMPLETE, + evolution_failed=SkillEnginePrompts.EVOLUTION_FAILED, + ) + + @staticmethod + def evolution_derived( + *, + parent_content: str, + direction: str, + execution_insights: str, + metric_summary: str = "", + ) -> str: + """Build the prompt for a DERIVED evolution (enhanced version). + + Args: + parent_content: Parent SKILL.md content. + direction: What to enhance and why. + execution_insights: Formatted analysis context with improvement signals. + metric_summary: Optional — skill health metrics (trigger 3). + """ + return _EVOLUTION_DERIVED_TEMPLATE.format( + parent_content=parent_content, + direction=direction, + execution_insights=execution_insights, + metric_summary=metric_summary or "(none)", + evolution_complete=SkillEnginePrompts.EVOLUTION_COMPLETE, + evolution_failed=SkillEnginePrompts.EVOLUTION_FAILED, + ) + + @staticmethod + def evolution_captured( + *, + direction: str, + category: str, + execution_highlights: str, + ) -> str: + """Build the prompt for a CAPTURED evolution (brand-new skill). + + Args: + direction: What pattern to capture. + category: Desired skill category (tool_guide / workflow / reference). + execution_highlights: Task context where the pattern was observed. + """ + return _EVOLUTION_CAPTURED_TEMPLATE.format( + direction=direction, + category=category, + execution_highlights=execution_highlights, + evolution_complete=SkillEnginePrompts.EVOLUTION_COMPLETE, + evolution_failed=SkillEnginePrompts.EVOLUTION_FAILED, + ) + + @staticmethod + def evolution_confirm( + *, + skill_id: str, + skill_content: str, + proposed_type: str, + proposed_direction: str, + trigger_context: str, + recent_analyses: str, + ) -> str: + """Build the prompt for LLM confirmation of rule-based evolution candidates. + + Used by Trigger 2 (tool degradation) and Trigger 3 (metric monitor) + to confirm whether a skill truly needs evolution before proceeding. + + Args: + skill_id: Unique skill_id of the candidate skill. + skill_content: Truncated SKILL.md content. + proposed_type: "fix" or "derived". + proposed_direction: What the rule-based system suggests. + trigger_context: Summary of the trigger (metrics or tool issue). + recent_analyses: Formatted recent execution analyses. + """ + return _EVOLUTION_CONFIRM_TEMPLATE.format( + skill_id=skill_id, + skill_content=skill_content, + proposed_type=proposed_type, + proposed_direction=proposed_direction, + trigger_context=trigger_context, + recent_analyses=recent_analyses, + ) + + @staticmethod + def execution_analysis( + *, + task_description: str, + execution_status: str, + iterations: int, + tool_list: str, + skill_section: str, + conversation_log: str, + traj_summary: str, + selected_skill_ids_json: str, + resource_info: str = "", + ) -> str: + """Build the prompt for post-execution skill quality analysis. + + Args: + task_description: Human-readable description of the task. + execution_status: Agent's self-reported status ("success" / "incomplete" / "error"). + NOT ground truth — the analysis LLM assesses actual completion independently. + iterations: Number of agent iterations used. + tool_list: List of available tool names with backend info. + skill_section: Pre-formatted markdown section describing selected skills. + Empty string when no skills were selected. + conversation_log: Formatted execution log (priority-truncated to fit context). + traj_summary: Structured tool execution timeline from traj.jsonl. + selected_skill_ids_json: JSON-encoded list of selected skill IDs. + resource_info: Recording / skill directory paths and tool-use guidance. + """ + return _EXECUTION_ANALYSIS_TEMPLATE.format( + task_description=task_description, + execution_status=execution_status, + iterations=iterations, + tool_list=tool_list, + skill_section=skill_section, + conversation_log=conversation_log, + traj_summary=traj_summary, + selected_skill_ids_json=selected_skill_ids_json, + resource_info=resource_info, + ) + +_EXECUTION_ANALYSIS_TEMPLATE = """\ +You are an expert analyst evaluating an autonomous agent's task execution. +Your job is to assess how the agent used its skills and tools, trace the +reasoning and outcome of each iteration, and surface actionable insights. + +## Task Context + +**Task**: {task_description} +**Agent self-reported status**: {execution_status} +**Iterations used**: {iterations} +**Available tools**: {tool_list} + +> This is the agent's **self-reported** status, not ground truth. +> ``success`` = agent output ```` (may be wrong/premature); +> ``incomplete`` = iteration budget exhausted; ``error`` = code exception. +> You must independently judge actual task completion below. + +{skill_section} + +## Tool Execution Timeline (from traj.jsonl) + +This is a structured summary of every tool invocation and its outcome: + +{traj_summary} + +## Agent Conversation Log + +This shows the agent's reasoning (ASSISTANT), tool calls (TOOL_CALL), +tool results (TOOL_RESULT / TOOL_ERROR), and the user's original instruction. + +**Reading guide**: +- ``[USER INSTRUCTION]`` — the original task from the user. +- ``[Iter N] ASSISTANT:`` — the agent's reasoning and decisions at iteration N. +- ``[Iter N] TOOL_CALL:`` — what tool the agent invoked and with what arguments. +- ``[Iter N] TOOL_ERROR:`` — tool returned an error (high priority for analysis). +- ``[Iter N] TOOL_RESULT:`` — tool returned successfully. + Some tool results include an embedded "Execution Summary" from inner agents + (e.g. shell_agent runs multiple internal steps before returning). + +{conversation_log} + +## Available Resources + +{resource_info} + +## Analysis Instructions + +### 1. Per-iteration trace + +For each agent iteration, identify: +- **What** the agent decided to do and **why** (from ASSISTANT content). +- **Which tool** was called and what happened (success / error / timeout). +- **Cause of next iteration**: did the agent retry due to error? Switch strategy? + Follow a skill step? Or complete the task? + +### 2. Task completion assessment + +Did the agent **actually** accomplish the user's request? +Judge from conversation evidence (tool results, final output), **not** the +self-reported status. + +- ``task_completed = true`` ONLY when the user's goal is genuinely fulfilled. +- Watch for mismatches: agent may claim ```` after giving up or + getting wrong results; conversely, it may finish the work but exhaust + iterations without outputting ````. +- Explain your reasoning in ``execution_note``. + +### 3. Skill assessment + +For each selected skill (IDs: {selected_skill_ids_json}), produce one +``skill_judgments`` entry: +- ``skill_id``: Use the **exact skill_id** from the list above (e.g. + ``weather__imp_a1b2c3d4``). Do NOT use the human-readable name alone. +- ``skill_applied``: Was the skill's information **actually used** (not just injected)? + - WORKFLOW skill: did the agent follow the prescribed steps? + - TOOL_GUIDE skill: did the agent use the tool as the guide describes? + - REFERENCE skill: did the agent rely on the knowledge for decisions? +- ``note``: Describe HOW the skill was used. If it wasn't applied, explain why. + +If no skills were selected, ``skill_judgments`` must be an empty list. + +### 4. Tool issues (separate from skill assessment) + +List **only tools that had actual problems** during this execution. +Do NOT list tools that worked correctly or were simply unused. + +**Tool key format** — use the key that matches the tool list above: +- MCP tools: ``mcp:server_name:tool_name`` +- Other tools: ``backend:tool_name`` + +For each problematic tool, include: +- The **symptom** (error, timeout, wrong output, semantic failure, etc.). +- The **likely cause** if you can infer it (network issue, tool bug, bad parameters, + misleading description, etc.). +- Whether the issue is the **tool's fault** or the **agent's misuse** of the tool. + +These issues are fed to a tool quality tracking system. If the tool returned HTTP 200 +but the data is incorrect or unusable, still flag it — your qualitative judgment +complements the raw success/failure tracking. + +### 5. Evolution suggestions + +The skill library improves through execution feedback. **If something went wrong, +fix it. If something useful was learned, capture or derive it.** Actively look for evolution +opportunities — they are how the system gets smarter over time. + +You may output **0 to N** suggestions. Each suggestion is one of three types: + +| Type | When to use | ``target_skills`` | +|------|------------|-------------------| +| ``fix`` | A selected skill had **incorrect, outdated, or incomplete** instructions that caused failure, deviation, or unnecessary friction. The skill needs repair. | ``["skill_id"]`` — exactly 1 skill, use the exact skill_id | +| ``derived`` | A selected skill worked, but the execution revealed a **better approach** — improved steps, added error handling, broader scope, or useful edge-case handling. Worth creating an enhanced version. Can also **merge** multiple skills. | ``["parent_skill_id"]`` or ``["skill_id_a", "skill_id_b"]`` for merge | +| ``captured`` | The agent solved the task **without skill guidance** (or skills were not relevant) and the approach is **reusable** — a debugging technique, a tool usage pattern, a multi-step workflow, a non-obvious workaround. | ``[]`` (empty list) | + +**One type per suggestion**: Each suggestion MUST have exactly one ``type`` — pick +``fix``, ``derived``, OR ``captured``. A single suggestion cannot be two types at once. +Different suggestions in the same analysis MAY have different types (e.g. one ``fix`` +for a broken skill and one ``captured`` for a novel pattern are both fine). + +**Guiding principles:** +- ``fix``: If the skill's instructions led the agent astray, caused errors, or missed + important steps/caveats, it should be fixed. Suboptimal instructions that cost extra + iterations also warrant a fix. +- ``derived``: If the agent found a meaningfully better way to accomplish what a skill + describes — even if the original skill "worked" — suggest deriving a new version. + The improvement should be generalizable beyond this specific task. +- ``captured``: If the agent discovered a useful pattern, workaround, or technique that + would benefit future executions, capture it. Err on the side of capturing — a slightly + redundant skill is better than lost knowledge. +- **Do NOT** capture trivial one-step operations or highly task-specific data unlikely to recur. +- **Do NOT** capture something that an existing selected skill already covers adequately. + +For each suggestion, specify: +- ``type``: ``"fix"`` | ``"derived"`` | ``"captured"`` +- ``target_skills``: list of **exact skill_id(s)** from the selected skills above — + ``["weather__imp_a1b2c3d4"]`` for fix (exactly 1), + ``["skill_id_1"]`` or ``["skill_id_1", "skill_id_2"]`` for derived (1+ for merge), + ``[]`` for captured +- ``category``: ``"tool_guide"`` | ``"workflow"`` | ``"reference"`` +- ``direction``: 1-2 sentences describing **what** to fix / derive / capture + +### When you need more information + +**In most cases the trace data above is sufficient.** If not: +1. Use ``read_file`` / ``list_dir`` to inspect recording artifacts or output files. +2. If still unclear, use ``run_shell`` or other available tools to reproduce the error. + +### Output format + +Return **exactly one** JSON object (no markdown fences, no explanation outside JSON): + +{{ + "task_completed": true, + "execution_note": "2-3 sentence overview of execution quality and outcome.", + "tool_issues": [ + "mcp:server_name:tool_name — symptom; likely cause (tool fault / agent misuse)", + "backend:tool_name — symptom; likely cause" + ], + "skill_judgments": [ + {{ + "skill_id": "weather__imp_a1b2c3d4", + "skill_applied": true, + "note": "How the skill was used, deviations, and effectiveness." + }} + ], + "evolution_suggestions": [ + {{ + "type": "fix", + "target_skills": ["weather__imp_a1b2c3d4"], + "category": "workflow", + "direction": "What to fix and why." + }}, + {{ + "type": "derived", + "target_skills": ["weather__imp_a1b2c3d4", "geocoding__imp_e5f6g7h8"], + "category": "workflow", + "direction": "Merge these skills into a unified approach." + }} + ] +}} + +**Rules**: +- ``skill_judgments`` must include exactly one entry per selected skill ID. + If no skills were selected, ``skill_judgments`` must be ``[]``. +- ``tool_issues``: ``"key — description"`` format (MCP: ``mcp:server:tool``, other: ``backend:tool``). ``[]`` if no problems. +- ``evolution_suggestions``: ``[]`` only if the execution revealed no issues to fix and no reusable patterns to capture. + For ``fix``, ``target_skills`` must be a list with exactly 1 skill name from the selected skills. + For ``derived``, ``target_skills`` must be a list with 1 or more skill names (multi = merge). + For ``captured``, ``target_skills`` must be ``[]``. + ``category``: one of ``"tool_guide"``, ``"workflow"``, ``"reference"``. +- ``execution_note``: substantive but concise (2-3 sentences). +""" + + +_EVOLUTION_FIX_TEMPLATE = """\ +You are a skill editor. Your job is to **fix** an existing skill that has +been identified as broken, outdated, or incomplete. + +A skill is a directory containing ``SKILL.md`` (the main instruction file) +and optionally auxiliary files (scripts, configs, examples, etc.). + +## Current Skill Content + +{current_content} + +## What needs fixing + +{direction} + +## Execution failure context + +These are recent task executions where this skill was involved: + +{failure_context} + +## Tool issue details + +{tool_issue_summary} + +## Skill health metrics + +{metric_summary} + +## Instructions + +1. Analyze the failure context and identify the root cause in the skill's + instructions (wrong parameters, outdated API, missing error handling, etc.). +2. Fix the affected files to address the identified issues. +3. Preserve the overall structure and YAML frontmatter format (``---`` fences) + in SKILL.md. +4. Keep ``name`` and ``description`` in frontmatter; update ``description`` + only if the skill's purpose has changed. +5. Be surgical — fix what's broken without unnecessary rewrites. + +## Output format + +Your output MUST have exactly two parts: + +**Part 1** — A summary line on the very first line: + +CHANGE_SUMMARY: + +**Part 2** — After one blank line, the actual changes in one of the formats +below. + +### Format A: Patch (PREFERRED for fixes — use this unless you need a full rewrite) + +The patch format lets you make surgical, targeted edits across one or more +files. Structure: + +*** Begin Patch +*** Update File: +@@ + +- ++ + +*** End Patch + +**How ``@@`` anchor lines work** (NOT the same as unified-diff ``@@ -n,m +n,m @@``): +- Write ``@@`` followed by a single line that already exists verbatim in the + file. The system searches forward for this line and applies the changes + immediately after locating it. +- After the ``@@`` line, prefix every line with exactly one character: + ``-`` = delete this old line, ``+`` = insert this new line, + `` `` (one space) = keep this line unchanged (context). +- You may have multiple ``@@`` sections inside one ``*** Update File`` block. + +Other operations: +- ``*** Add File: path`` — every content line prefixed with ``+``. +- ``*** Delete File: path`` — no content lines needed. + +Example — fixing an incorrect curl parameter and adding a missing step: + +CHANGE_SUMMARY: Fixed curl content-type header and added retry logic + +*** Begin Patch +*** Update File: SKILL.md +@@ 3. Send the API request: + 3. Send the API request: +- curl -X POST -H "Content-Type: text/plain" ... ++ curl -X POST -H "Content-Type: application/json" ... +@@ ## Error handling + ## Error handling ++ ++4. **Retry on transient failures**: If you receive a 429 or 5xx status, ++ wait 2 seconds and retry up to 3 times. +*** End Patch + +### Format B: Full rewrite (only when most of the content changes) + +If the fix is so extensive that a patch would be larger than the full file, +output the complete file contents instead: + +*** Begin Files +*** File: SKILL.md +(complete file content) +*** File: examples/helper.sh +(complete file content) +*** End Files + +For single-file skills you may omit the ``*** Begin/End Files`` envelope +and output the complete SKILL.md content directly. + +### Rules + +- Do NOT wrap your output in markdown code fences (no ``` blocks). +- Prefer Format A (patch) for fixes — it is more precise and less error-prone. +- Only use Format B when the patch would touch more than ~60% of the file. + +## Self-Assessment + +After generating your edit, evaluate whether it adequately addresses the +issues identified in the direction and failure context above. + +**If your edit is satisfactory** — it addresses the root cause and the +resulting skill will work correctly — include `{evolution_complete}` on +the last line of your output. + +**If you cannot produce a satisfactory edit** — for example, the skill +is actually correct and the issue is external, you lack critical +information, or the requested change is not feasible — output ONLY: + +{evolution_failed} +Reason: + +Do NOT output any edit content if you signal failure. +""" + + +_EVOLUTION_DERIVED_TEMPLATE = """\ +You are a skill editor. Your job is to **derive** an enhanced version of an +existing skill. The new skill will live in a new directory; the original +stays unchanged. + +A skill is a directory containing ``SKILL.md`` (the main instruction file) +and optionally auxiliary files (scripts, configs, examples, etc.). + +## Parent Skill Content + +{parent_content} + +## Enhancement direction + +{direction} + +## Execution insights + +These are recent task executions that informed this enhancement: + +{execution_insights} + +## Skill health metrics + +{metric_summary} + +## Instructions + +1. Create an enhanced version that addresses the improvement direction. +2. Give the new skill a **different, concise name** (in frontmatter ``name:`` field) + that reflects its specialization or enhancement. + - Name MUST be ≤50 characters, lowercase, hyphens only (e.g. ``resilient-panel-unified``). + - Do NOT just append "-enhanced" or "-merged" to the parent name. + Instead, pick a descriptive name that captures the NEW capability + (e.g. ``panel-circuit-breaker`` instead of ``panel-component-enhanced-enhanced``). +3. Update ``description`` to reflect the new capability. +4. You may restructure, add steps, improve error handling, add alternatives, + or broaden/narrow scope as appropriate. +5. Maintain the YAML frontmatter format (``---`` fences with ``name`` and + ``description`` at minimum). +6. The derived skill should be self-contained — a user should be able to + follow it without referencing the parent. +7. You may add, modify, or remove auxiliary files as needed. + +## Output format + +Your output MUST have exactly two parts: + +**Part 1** — A summary line on the very first line: + +CHANGE_SUMMARY: + +**Part 2** — After one blank line, the actual changes in one of the formats +below. + +### Choosing a format + +- **Small enhancement** (new steps, improved wording, added error handling + while keeping most content intact): use Format A (patch). +- **Major restructure** or **substantially different skill**: use Format B + (full rewrite). This is also the best choice when creating a merged skill + from multiple parents. + +### Format A: Patch + +*** Begin Patch +*** Update File: +@@ + +- ++ + +*** Add File: ++ ++ +*** End Patch + +**How ``@@`` anchor lines work** (NOT unified-diff ``@@ -n,m +n,m @@``): +- ``@@`` followed by a line that exists verbatim in the file. The system + locates this line and applies changes starting there. +- After ``@@``, prefix lines with: ``-`` remove, ``+`` add, `` `` (space) keep. +- Multiple ``@@`` sections per file are allowed. + +Example — renaming and enhancing a skill: + +CHANGE_SUMMARY: Added retry logic and broadened scope to cover batch requests + +*** Begin Patch +*** Update File: SKILL.md +@@ name: api-request-guide +-name: api-request-guide +-description: How to make single API requests ++name: api-request-guide-enhanced ++description: Robust API requests with retry logic and batch support +@@ ## Steps + ## Steps ++ ++0. **Pre-check**: Verify the API endpoint is reachable with a HEAD request. +@@ 3. Send the request + 3. Send the request ++4. **Handle failures**: On 429/5xx, back off exponentially (1s, 2s, 4s) ++ up to 3 retries before reporting an error. +*** Add File: examples/batch_request.sh ++#!/bin/bash ++# Batch API request example ++for endpoint in "$@"; do ++ curl -s "$endpoint" || echo "FAILED: $endpoint" ++done +*** End Patch + +### Format B: Full rewrite + +*** Begin Files +*** File: SKILL.md +(complete file content) +*** File: examples/helper.sh +(complete file content) +*** End Files + +For single-file skills you may omit the envelope and output the complete +SKILL.md content directly. + +### Rules + +- Do NOT wrap your output in markdown code fences (no ``` blocks). +- The new skill MUST have a different ``name`` from the parent. + +## Self-Assessment + +After generating your edit, evaluate whether the derived skill is a +meaningful improvement over the parent(s). + +**If the derived skill is satisfactory** — it provides genuine +enhancement, is self-contained, and would benefit future executions — +include `{evolution_complete}` on the last line of your output. + +**If you cannot produce a worthwhile derived skill** — for example, the +parent skill is already optimal, the enhancement direction is not +feasible, or the result would be too similar to the parent — output ONLY: + +{evolution_failed} +Reason: + +Do NOT output any edit content if you signal failure. +""" + + +_EVOLUTION_CAPTURED_TEMPLATE = """\ +You are a skill author. Your job is to **capture** a reusable pattern that +was observed during task executions into a brand-new skill. + +A skill is a directory containing ``SKILL.md`` (the main instruction file) +and optionally auxiliary files (scripts, configs, examples, etc.). + +## Pattern to capture + +{direction} + +## Desired category + +``{category}`` + +Categories: +- ``tool_guide``: How to use a specific tool effectively +- ``workflow``: End-to-end multi-step procedure +- ``reference``: Reference knowledge / best practices + +## Execution context + +These are task executions where the pattern was observed: + +{execution_highlights} + +## Instructions + +1. Distill the observed pattern into a clear, reusable skill document. +2. Choose a concise, descriptive ``name`` (lowercase, hyphens for spaces). + - Name MUST be ≤50 characters (e.g. ``safe-file-write``, ``ts-compile-check``). + - Capture the core technique, not every detail. +3. Write a brief ``description`` that captures the skill's purpose. +4. Structure the body as clear, actionable instructions that an autonomous + agent can follow. Include code examples where helpful. +5. Make the skill **generalizable** — abstract away task-specific details + while preserving the core technique. +6. Use YAML frontmatter format (``---`` fences with ``name`` and + ``description``). +7. If the pattern benefits from auxiliary files (shell scripts, config + templates, etc.), include them. + +## Output format + +Your output MUST have exactly two parts: + +**Part 1** — A summary line on the very first line: + +CHANGE_SUMMARY: + +**Part 2** — After one blank line, the complete skill content. + +Since this is a brand-new skill, always output the **full content**. + +**If the skill has multiple files**, use the multi-file full format: + +*** Begin Files +*** File: SKILL.md +--- +name: my-skill-name +description: What this skill does +--- + +# My Skill + +Instructions here... +*** File: examples/setup.sh +#!/bin/bash +echo "setup script" +*** End Files + +**If the skill is just SKILL.md** (most common), output the complete +SKILL.md content directly (no ``*** Begin/End Files`` envelope needed): + +--- +name: my-skill-name +description: What this skill does +--- + +# My Skill + +Step-by-step instructions... + +### Rules + +- Do NOT wrap your output in markdown code fences (no ``` blocks). +- The SKILL.md MUST start with YAML frontmatter (``---`` fences) containing + at least ``name`` and ``description``. + +## Self-Assessment + +After generating the skill, evaluate whether it captures a genuinely +reusable pattern. + +**If the captured skill is satisfactory** — it is generalizable, clearly +written, and would benefit future executions — include +`{evolution_complete}` on the last line of your output. + +**If you cannot produce a worthwhile skill** — for example, the pattern +is too task-specific, too trivial, or already covered by existing skills +— output ONLY: + +{evolution_failed} +Reason: + +Do NOT output any skill content if you signal failure. +""" + + +_EVOLUTION_CONFIRM_TEMPLATE = """\ +You are an expert evaluating whether a skill needs evolution. + +A rule-based monitoring system has flagged a skill as a candidate for +evolution based on health metrics or tool degradation signals. Your job +is to **confirm or reject** this recommendation by examining the skill +content and recent execution history. + +## Skill Under Review + +**ID**: {skill_id} + +**Content** (may be truncated): + +{skill_content} + +## Proposed Evolution + +**Type**: ``{proposed_type}`` +**Direction**: {proposed_direction} + +## Trigger Context + +{trigger_context} + +## Recent Execution History + +{recent_analyses} + +## Decision Criteria + +Consider these factors: + +1. **Is the signal real?** Could the poor metrics be caused by external + factors (task distribution shift, temporary tool outage) rather than + a genuine skill deficiency? + +2. **Is the skill actually problematic?** Read the skill content — are + the instructions actually wrong/outdated, or are the metrics + misleading? + +3. **Is evolution worth the cost?** Would fixing/deriving this skill + meaningfully improve future executions, or is the skill rarely used + and not worth the LLM cost? + +4. **Is the proposed direction correct?** Does the suggested fix/derive + direction address the actual root cause? + +## Output Format + +Return **exactly one** JSON object (no markdown fences): + +{{ + "proceed": true, + "reasoning": "1-2 sentence explanation of your decision.", + "adjusted_direction": "Optional: refined direction if you agree but want to adjust the approach. Omit or set to empty string if the original direction is fine." +}} + +Set ``"proceed": false`` to skip this evolution. +Set ``"proceed": true`` to confirm it should proceed. +""" diff --git a/openspace/recording/__init__.py b/openspace/recording/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe64f8302f8a5f064c043c6cc86018f21e19119b --- /dev/null +++ b/openspace/recording/__init__.py @@ -0,0 +1,54 @@ +""" + RecordingManager + ├── internal management of platform.RecordingClient + ├── internal management of platform.ScreenshotClient + ├── internal management of TrajectoryRecorder + └── internal management of ActionRecorder +""" + +from importlib import import_module + +__all__ = [ + 'RecordingManager', + 'TrajectoryRecorder', + 'ActionRecorder', + 'load_trajectory_from_jsonl', + 'load_metadata', + 'format_trajectory_for_export', + 'analyze_trajectory', + 'load_recording_session', + 'filter_trajectory', + 'extract_errors', + 'generate_summary_report', + 'load_agent_actions', + 'analyze_agent_actions', + 'format_agent_actions', +] + +_EXPORTS = { + 'RecordingManager': ('.manager', 'RecordingManager'), + 'TrajectoryRecorder': ('.recorder', 'TrajectoryRecorder'), + 'ActionRecorder': ('.action_recorder', 'ActionRecorder'), + 'load_trajectory_from_jsonl': ('.utils', 'load_trajectory_from_jsonl'), + 'load_metadata': ('.utils', 'load_metadata'), + 'format_trajectory_for_export': ('.utils', 'format_trajectory_for_export'), + 'analyze_trajectory': ('.utils', 'analyze_trajectory'), + 'load_recording_session': ('.utils', 'load_recording_session'), + 'filter_trajectory': ('.utils', 'filter_trajectory'), + 'extract_errors': ('.utils', 'extract_errors'), + 'generate_summary_report': ('.utils', 'generate_summary_report'), + 'load_agent_actions': ('.action_recorder', 'load_agent_actions'), + 'analyze_agent_actions': ('.action_recorder', 'analyze_agent_actions'), + 'format_agent_actions': ('.action_recorder', 'format_agent_actions'), +} + + +def __getattr__(name: str): + try: + module_name, attr_name = _EXPORTS[name] + except KeyError as exc: + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') from exc + module = import_module(module_name, __name__) + value = getattr(module, attr_name) + globals()[name] = value + return value diff --git a/openspace/recording/action_recorder.py b/openspace/recording/action_recorder.py new file mode 100644 index 0000000000000000000000000000000000000000..f4311a0517aada3b013c54be310d6d1676e84a2e --- /dev/null +++ b/openspace/recording/action_recorder.py @@ -0,0 +1,271 @@ +""" +Agent Action Recorder + +Records agent decision-making processes, reasoning, and outputs. +Focuses on high-level agent behaviors rather than low-level tool executions. +""" + +import datetime +import json +from typing import Any, Dict, Optional +from pathlib import Path + +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +class ActionRecorder: + """ + Records agent actions and decision-making processes. + + This recorder captures the 'thinking' layer of the agent: + - Task planning and decomposition + - Tool selection reasoning + - Evaluation decisions + """ + + def __init__(self, trajectory_dir: Path): + """ + Initialize action recorder. + + Args: + trajectory_dir: Directory to save action records + """ + self.trajectory_dir = trajectory_dir + self.actions_file = trajectory_dir / "agent_actions.jsonl" + self.step_counter = 0 + + # Ensure directory exists + self.trajectory_dir.mkdir(parents=True, exist_ok=True) + + async def record_action( + self, + agent_name: str, + action_type: str, + input_data: Optional[Dict[str, Any]] = None, + reasoning: Optional[Dict[str, Any]] = None, + output_data: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + related_tool_steps: Optional[list] = None, + correlation_id: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Record an agent action. + + Args: + agent_name: Name of the agent performing the action + action_type: Type of action (plan | execute | evaluate | monitor) + input_data: Input data the agent received (simplified) + reasoning: Agent's reasoning process (structured) + output_data: Agent's output/decision (structured) + metadata: Additional metadata (LLM model, tokens, duration, etc.) + related_tool_steps: List of tool execution step numbers related to this action + correlation_id: Optional correlation ID to link related events + """ + self.step_counter += 1 + timestamp = datetime.datetime.now().isoformat() + + # Infer agent type from agent name + agent_type = self._infer_agent_type(agent_name) + + action_info = { + "step": self.step_counter, + "timestamp": timestamp, + "agent_name": agent_name, + "agent_type": agent_type, + "action_type": action_type, + "correlation_id": correlation_id or f"action_{self.step_counter}_{timestamp}", + } + + # Add input (with smart truncation) + if input_data: + action_info["input"] = self._truncate_data(input_data, max_length=1000) + + # Add reasoning (keep structured) + if reasoning: + action_info["reasoning"] = self._truncate_data(reasoning, max_length=2000) + + # Add output (keep structured) + if output_data: + action_info["output"] = self._truncate_data(output_data, max_length=1000) + + # Add metadata + if metadata: + action_info["metadata"] = metadata + + # Add related tool steps for correlation + if related_tool_steps: + action_info["related_tool_steps"] = related_tool_steps + + # Append to JSONL file + await self._append_to_file(action_info) + + logger.debug( + f"Recorded {action_type} action from {agent_name} (step {self.step_counter})" + ) + + return action_info + + def _infer_agent_type(self, agent_name: str) -> str: + name_lower = agent_name.lower() + + if "host" in name_lower: + return "host" + elif "grounding" in name_lower: + return "grounding" + elif "eval" in name_lower: + return "eval" + elif "coordinator" in name_lower: + return "coordinator" + else: + return "unknown" + + def _truncate_data(self, data: Any, max_length: int) -> Any: + if isinstance(data, str): + if len(data) > max_length: + return data[:max_length] + "... [truncated]" + return data + + elif isinstance(data, dict): + result = {} + for key, value in data.items(): + if isinstance(value, str) and len(value) > max_length: + result[key] = value[:max_length] + "... [truncated]" + elif isinstance(value, (dict, list)): + # Recursively truncate nested structures + result[key] = self._truncate_data(value, max_length) + else: + result[key] = value + return result + + elif isinstance(data, list): + # Truncate list items + result = [] + for item in data: + if isinstance(item, str) and len(item) > max_length: + result.append(item[:max_length] + "... [truncated]") + elif isinstance(item, (dict, list)): + result.append(self._truncate_data(item, max_length)) + else: + result.append(item) + return result + + else: + return data + + async def _append_to_file(self, action_info: Dict[str, Any]): + """Append action to JSONL file.""" + with open(self.actions_file, "a", encoding="utf-8") as f: + f.write(json.dumps(action_info, ensure_ascii=False)) + f.write("\n") + + def get_step_count(self) -> int: + """Get current step count.""" + return self.step_counter + + +def load_agent_actions(trajectory_dir: str) -> list: + """ + Load agent actions from a trajectory directory. + """ + actions_file = Path(trajectory_dir) / "agent_actions.jsonl" + + if not actions_file.exists(): + logger.warning(f"Agent actions file not found: {actions_file}") + return [] + + actions = [] + try: + with open(actions_file, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + actions.append(json.loads(line)) + + logger.info(f"Loaded {len(actions)} agent actions from {actions_file}") + return actions + + except Exception as e: + logger.error(f"Failed to load agent actions from {actions_file}: {e}") + return [] + + +def analyze_agent_actions(actions: list) -> Dict[str, Any]: + """ + Analyze agent actions and generate statistics. + """ + if not actions: + return { + "total_actions": 0, + "by_agent": {}, + "by_type": {}, + } + + # Count by agent + by_agent = {} + by_type = {} + + for action in actions: + agent_name = action.get("agent_name", "unknown") + action_type = action.get("action_type", "unknown") + + by_agent[agent_name] = by_agent.get(agent_name, 0) + 1 + by_type[action_type] = by_type.get(action_type, 0) + 1 + + return { + "total_actions": len(actions), + "by_agent": by_agent, + "by_type": by_type, + } + + +def format_agent_actions(actions: list, format_type: str = "compact") -> str: + """ + Format agent actions for display. + """ + if not actions: + return "No agent actions recorded" + + if format_type == "compact": + lines = [] + for action in actions: + step = action.get("step", "?") + agent = action.get("agent_name", "?") + action_type = action.get("action_type", "?") + + # Try to extract key info from reasoning or output + key_info = "" + if action.get("reasoning"): + thought = action["reasoning"].get("thought", "") + if thought: + key_info = f": {thought[:60]}..." + + lines.append(f"Step {step}: [{agent}] {action_type}{key_info}") + + return "\n".join(lines) + + elif format_type == "detailed": + lines = [] + for action in actions: + lines.append(f"\n{'='*60}") + lines.append(f"Step {action.get('step', '?')}: {action.get('agent_name', '?')}") + lines.append(f"Type: {action.get('action_type', '?')}") + lines.append(f"Time: {action.get('timestamp', '?')}") + + if action.get("reasoning"): + lines.append("\nReasoning:") + lines.append(json.dumps(action["reasoning"], indent=2, ensure_ascii=False)) + + if action.get("output"): + lines.append("\nOutput:") + lines.append(json.dumps(action["output"], indent=2, ensure_ascii=False)) + + if action.get("metadata"): + lines.append("\nMetadata:") + lines.append(json.dumps(action["metadata"], indent=2, ensure_ascii=False)) + + return "\n".join(lines) + + else: + raise ValueError(f"Unknown format type: {format_type}") \ No newline at end of file diff --git a/openspace/recording/manager.py b/openspace/recording/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..fd6d7ed35da48a7496c84a8f7f35c4904197df9f --- /dev/null +++ b/openspace/recording/manager.py @@ -0,0 +1,1141 @@ +import datetime +import json +import ast +import types +from typing import Any, Dict, List, Optional +from pathlib import Path + +from openspace.utils.logging import Logger +from .recorder import TrajectoryRecorder +from .action_recorder import ActionRecorder + +logger = Logger.get_logger(__name__) + + +class RecordingManager: + # Global instance management (singleton pattern) + _global_instance: Optional['RecordingManager'] = None + + def __init__( + self, + enabled: bool = True, + task_id: str = "", + log_dir: str = "./logs/recordings", + backends: Optional[List[str]] = None, + enable_screenshot: bool = True, + enable_video: bool = False, + enable_conversation_log: bool = True, + auto_save_interval: int = 10, + server_url: Optional[str] = None, + agent_name: str = "GroundingAgent", + ): + """ + Initialize automatic recording manager + + Args: + enabled: whether to enable recording + task_id: task ID (for naming recording directory) + log_dir: log directory path + backends: list of backends to record (None = all) + (optional: "mcp", "gui", "shell", "system", "web") + enable_screenshot: whether to enable screenshot (through platform.ScreenshotClient) + enable_video: whether to enable video recording (through platform.RecordingClient) + enable_conversation_log: whether to save LLM conversations to conversations.jsonl (default: True) + auto_save_interval: automatic save interval (steps) + server_url: local server address (None = read from config/environment variables) + agent_name: name of the agent performing the recording (default: "GroundingAgent") + """ + self.enabled = enabled + self.task_id = task_id + self.log_dir = log_dir + self.backends = set(backends) if backends else {"mcp", "gui", "shell", "system", "web"} + self.enable_screenshot = enable_screenshot + self.enable_video = enable_video + self.enable_conversation_log = enable_conversation_log + self.auto_save_interval = auto_save_interval + self.server_url = server_url + self.agent_name = agent_name + + # internal state + self._recorder: Optional[TrajectoryRecorder] = None + self._action_recorder: Optional[ActionRecorder] = None + self._is_started = False + self._step_counter = 0 + + # registered LLM clients + self._registered_llm_clients = [] + self._original_methods = {} + + # video/screenshot clients (internal management) + self._recording_client = None + self._screenshot_client = None + + # Register as global instance + RecordingManager._global_instance = self + + @classmethod + def is_recording(cls) -> bool: + """ + Check if there is an active recording session + + Returns: + bool: True if recording is active + """ + return cls._global_instance is not None and cls._global_instance._is_started + + @classmethod + async def record_retrieved_tools( + cls, + task_instruction: str, + tools: List[Any], + search_debug_info: Optional[Dict[str, Any]] = None, + ): + """ + Record the tools retrieved for a task + + Args: + task_instruction: The task instruction used for retrieval + tools: List of retrieved tools + search_debug_info: Debug info from search (similarity scores, LLM selections) + """ + instance = cls._global_instance + if not instance or not instance._is_started or not instance._recorder: + return + + # Extract tool info + tool_info = [] + for tool in tools: + info = { + "name": getattr(tool, "name", str(tool)), + } + # Prefer runtime_info.backend + # over backend_type (may be NOT_SET for cached RemoteTools) + runtime_info = getattr(tool, "_runtime_info", None) + if runtime_info and hasattr(runtime_info, "backend"): + info["backend"] = runtime_info.backend.value if hasattr(runtime_info.backend, "value") else str(runtime_info.backend) + info["server_name"] = runtime_info.server_name + elif hasattr(tool, "backend_type"): + info["backend"] = tool.backend_type.value if hasattr(tool.backend_type, "value") else str(tool.backend_type) + tool_info.append(info) + + # Build metadata + metadata = { + "instruction": task_instruction[:500], # Truncate long instructions + "count": len(tools), + "tools": tool_info, + } + + # Add search debug info if available + if search_debug_info: + metadata["search_debug"] = { + "search_mode": search_debug_info.get("search_mode", ""), + "total_candidates": search_debug_info.get("total_candidates", 0), + "mcp_count": search_debug_info.get("mcp_count", 0), + "non_mcp_count": search_debug_info.get("non_mcp_count", 0), + "llm_filter": search_debug_info.get("llm_filter", {}), + "tool_scores": search_debug_info.get("tool_scores", []), + } + + # Save to metadata + await instance._recorder.add_metadata("retrieved_tools", metadata) + + logger.info(f"Recorded {len(tools)} retrieved tools (with search debug info: {search_debug_info is not None})") + + @classmethod + async def record_skill_selection( + cls, + selection_record: Dict[str, Any], + ): + """ + Record skill selection decision to metadata.json. + + This captures the pre-execution skill matching conversation: + - Which skills were available + - The LLM prompt and response (or keyword fallback) + - Which skills were selected + + Args: + selection_record: Structured record from SkillRegistry.select_skills_with_llm() + Keys: method, task, available_skills, prompt, llm_response, selected, error + """ + instance = cls._global_instance + if not instance or not instance._is_started or not instance._recorder: + return + + # Save to metadata alongside retrieved_tools + await instance._recorder.add_metadata("skill_selection", selection_record) + + selected = selection_record.get("selected", []) + method = selection_record.get("method", "unknown") + logger.info( + f"Recorded skill selection: {len(selected)} selected via {method} " + f"(from {len(selection_record.get('available_skills', []))} available)" + ) + + @staticmethod + def _truncate_messages( + messages: List[Dict[str, Any]], + max_content_length: int = 5000, + ) -> List[Dict[str, Any]]: + """Truncate message content to avoid huge log files.""" + result = [] + for msg in messages: + new_msg = {"role": msg.get("role", "unknown")} + content = msg.get("content", "") + + if isinstance(content, str): + if len(content) > max_content_length: + new_msg["content"] = content[:max_content_length] + f"... [truncated, total {len(content)} chars]" + else: + new_msg["content"] = content + elif isinstance(content, list): + # Handle multi-part content (e.g., with images) + new_content = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "image": + new_content.append({"type": "image", "note": "[image data omitted]"}) + elif item.get("type") == "text": + text = item.get("text", "") + if len(text) > max_content_length: + new_content.append({ + "type": "text", + "text": text[:max_content_length] + f"... [truncated, total {len(text)} chars]" + }) + else: + new_content.append(item) + else: + new_content.append(item) + else: + new_content.append(item) + new_msg["content"] = new_content + else: + new_msg["content"] = str(content)[:max_content_length] + + if "tool_calls" in msg: + new_msg["tool_calls"] = msg["tool_calls"] + + result.append(new_msg) + return result + + @classmethod + async def record_conversation_setup( + cls, + setup_messages: List[Dict[str, Any]], + tools: Optional[List] = None, + max_content_length: int = 5000, + agent_name: str = "GroundingAgent", + extra: Optional[Dict[str, Any]] = None, + ): + """ + Record initial conversation context to conversations.jsonl (called once before iterations). + + Writes a ``type: "setup"`` line containing all system messages, the user + instruction, **and** the tool schemas exposed to the LLM so the log + gives a complete picture of what the model sees. + + Args: + setup_messages: The initial messages list (system prompts + user instruction). + tools: BaseTool list passed to the LLM (optional). Each tool's + name, backend, and description are recorded. + max_content_length: Max length for message content truncation. + agent_name: Agent/phase identifier. Used to distinguish conversations + from different pipeline stages during replay. + Common values: "GroundingAgent", "ExecutionAnalyzer", + "SkillEvolver", "SkillEvolver.confirm", "SkillEvolver.retry". + extra: Optional dict of additional context (e.g. evolution_type, + trigger, target_skills) merged into the record. + """ + instance = cls._global_instance + if not instance or not instance._is_started or not instance._recorder: + return + if not getattr(instance, 'enable_conversation_log', True): + return + + record: Dict[str, Any] = { + "type": "setup", + "agent_name": agent_name, + "timestamp": datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S"), + "messages": cls._truncate_messages(setup_messages, max_content_length), + } + if extra: + record["extra"] = extra + + # Record tool definitions so the log shows what the LLM can call. + # Description includes the [Backend] tag that the LLM actually sees. + if tools: + _BACKEND_LABELS = { + "mcp": "MCP", "shell": "Shell", "gui": "GUI", + "web": "Web", "system": "System", + } + tool_defs = [] + for t in tools: + schema = getattr(t, "schema", None) + if schema: + backend_val = getattr(schema, "backend_type", None) + backend_str = ( + backend_val.value + if hasattr(backend_val, "value") + else str(backend_val) if backend_val else None + ) + entry: Dict[str, Any] = { + "name": schema.name, + "backend": backend_str, + } + if schema.description: + desc = schema.description + # Mirror the [Backend] tag that _prepare_tools_for_llmclient + # adds so the recording matches what the LLM sees. + if backend_str and backend_str not in ("not_set",): + label = _BACKEND_LABELS.get(backend_str, backend_str) + desc = f"[{label}] {desc}" + if len(desc) > 200: + desc = desc[:200] + "..." + entry["description"] = desc + else: + entry = {"name": getattr(t, "name", str(t))} + tool_defs.append(entry) + record["tools"] = tool_defs + + conv_file = instance._recorder.trajectory_dir / "conversations.jsonl" + try: + with open(conv_file, "a", encoding="utf-8") as f: + f.write(json.dumps(record, ensure_ascii=False)) + f.write("\n") + except Exception as e: + logger.debug(f"Failed to write conversation setup: {e}") + + @classmethod + async def record_iteration_context( + cls, + iteration: int, + delta_messages: List[Dict[str, Any]], + response_metadata: Dict[str, Any], + max_content_length: int = 5000, + agent_name: str = "GroundingAgent", + extra: Optional[Dict[str, Any]] = None, + ): + """ + Record a single iteration's delta messages to conversations.jsonl. + + Only the messages produced during this iteration are stored (assistant + response, tool results, inter-iteration guidance), avoiding repetition + of system prompts and initial user instruction. The initial context is + stored once via ``record_conversation_setup``. The full conversation + can be reconstructed by concatenating the setup with all deltas in order. + + Args: + iteration: Iteration number (1-based). + delta_messages: Messages added during this iteration (assistant + tool results). + response_metadata: Lightweight metadata about the LLM response + (has_tool_calls, tool_calls_count). + max_content_length: Max length for message content truncation. + agent_name: Agent/phase identifier (must match the corresponding + ``record_conversation_setup`` call). + extra: Optional dict of additional context merged into the record. + """ + instance = cls._global_instance + if not instance or not instance._is_started or not instance._recorder: + return + if not getattr(instance, 'enable_conversation_log', True): + return + + record = { + "type": "iteration", + "agent_name": agent_name, + "iteration": iteration, + "timestamp": datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S"), + "response_metadata": response_metadata, + "delta_messages": cls._truncate_messages(delta_messages, max_content_length), + } + if extra: + record["extra"] = extra + + # Append to conversations.jsonl (real-time) + conv_file = instance._recorder.trajectory_dir / "conversations.jsonl" + try: + with open(conv_file, "a", encoding="utf-8") as f: + f.write(json.dumps(record, ensure_ascii=False)) + f.write("\n") + except Exception as e: + logger.debug(f"Failed to write conversation log: {e}") + + @classmethod + async def record_tool_execution( + cls, + tool_name: str, + backend: str, + parameters: Dict[str, Any], + result: Any, + server_name: Optional[str] = None, + is_success: bool = True, + metadata: Optional[Dict[str, Any]] = None, + ): + """ + Record tool execution (internal method, called by BaseTool automatically) + + Args: + tool_name: Name of the tool + backend: Backend type (gui, shell, mcp, etc.) + parameters: Tool parameters + result: Tool execution result (content or error message) + server_name: Server name for MCP backend + is_success: Whether the tool execution was successful (default: True for backward compatibility) + metadata: Tool result metadata (e.g. intermediate_steps for GUI) + """ + if not cls._global_instance or not cls._global_instance._is_started: + return + + instance = cls._global_instance + + # Infer backend if not_set or not in allowed backends + if backend == "not_set" or backend not in instance.backends: + inferred = cls._infer_backend_from_tool_name(tool_name) + if inferred and inferred in instance.backends: + backend = inferred + elif backend not in instance.backends: + logger.debug( + f"Backend '{backend}' not in recording backends {instance.backends}, " + f"skipping recording for tool '{tool_name}'" + ) + return + + # Create mock tool_call and result objects for compatibility with existing _record_* methods + class MockFunctionCall: + def __init__(self, name, arguments): + self.name = name + self.arguments = arguments + + class MockToolCall: + def __init__(self, name, arguments): + self.function = MockFunctionCall(name, arguments) + + class MockResult: + def __init__(self, content, is_success=True, metadata=None): + self.content = content + self.is_success = is_success + self.is_error = not is_success + self.error = content if not is_success else None + self.metadata = metadata or {} + + tool_call = MockToolCall(tool_name, parameters) + mock_result = MockResult(result, is_success=is_success, metadata=metadata) + + try: + if backend == "mcp": + server = server_name or "unknown" + await instance._record_mcp(tool_call, mock_result, server) + elif backend == "gui": + await instance._record_gui(tool_call, mock_result) + elif backend == "shell": + await instance._record_shell(tool_call, mock_result) + elif backend == "system": + await instance._record_system(tool_call, mock_result) + elif backend == "web": + await instance._record_web(tool_call, mock_result) + else: + logger.warning(f"No recording handler for backend '{backend}', tool '{tool_name}'") + return + + instance._step_counter += 1 + except Exception as e: + logger.warning(f"Failed to record tool execution for {tool_name}: {e}") + + @staticmethod + def _parse_arguments(arg_data): + """Safely parse tool_call.function.arguments which may be JSON string. + + Handles: + 1. Proper JSON strings with true/false/null + 2. Python literal strings (produced by OpenAI) using ast.literal_eval + 3. Already-dict objects (returned by SDK) + """ + if not isinstance(arg_data, str): + return arg_data or {} + + # First, try JSON + try: + return json.loads(arg_data) + except json.JSONDecodeError: + pass + + # Fallback to Python literal + try: + return ast.literal_eval(arg_data) + except Exception: + logger.debug("Failed to parse arguments, returning raw string") + return {"raw": arg_data} + + async def start(self, task_id: Optional[str] = None): + """Start automatic recording + Args: + task_id: If provided, override the current task_id for this recording session. This allows + external callers (e.g. Coordinator) to specify a meaningful task identifier without + having to recreate the RecordingManager instance. + """ + # Allow dynamic update of task_id before recording actually starts + if task_id: + self.task_id = task_id + if not self.enabled or self._is_started: + return + + try: + # check server availability (only when video or screenshot is enabled) + if self.enable_video or self.enable_screenshot: + await self._check_server_availability() + + self._recorder = TrajectoryRecorder( + task_name=self.task_id, + log_dir=self.log_dir, + enable_screenshot=self.enable_screenshot, + enable_video=self.enable_video, + server_url=self.server_url, + ) + + # create action recorder for agent decision tracking + self._action_recorder = ActionRecorder( + trajectory_dir=Path(self._recorder.get_trajectory_dir()) + ) + + + # create video client (internal management) + if self.enable_video: + from openspace.platforms import RecordingClient + self._recording_client = RecordingClient(base_url=self.server_url) + success = await self._recording_client.start_recording() + if success: + logger.info("Video recording started") + else: + logger.warning("Video recording failed to start") + + # create screenshot client (internal management) + if self.enable_screenshot: + from openspace.platforms import ScreenshotClient + self._screenshot_client = ScreenshotClient(base_url=self.server_url) + logger.debug("Screenshot client ready") + + # save initial metadata + await self._recorder.add_metadata("task_id", self.task_id) + await self._recorder.add_metadata("backends", list(self.backends)) + await self._recorder.add_metadata("start_time", datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")) + + # Capture and save initial screenshot if enabled + if self.enable_screenshot and self._screenshot_client: + try: + init_shot = await self._screenshot_client.capture() + if init_shot: + await self._recorder.save_init_screenshot(init_shot) + logger.debug("Initial screenshot saved") + except Exception as e: + logger.debug(f"Failed to capture initial screenshot: {e}") + + self._is_started = True + logger.info(f"Recording started: {self._recorder.get_trajectory_dir()}") + + except Exception as e: + logger.error(f"Recording failed to start: {e}") + raise + + async def _check_server_availability(self): + """Check if local server is available""" + try: + from openspace.platforms import SystemInfoClient + + # Use context manager to ensure aiohttp session is closed, avoiding warning of unclosed session + async with SystemInfoClient(base_url=self.server_url) as client: + info = await client.get_system_info() + + if info: + logger.info(f"Server connected ({info.get('platform', 'unknown')})") + else: + logger.warning("Server not responding, video/screenshot functionality unavailable") + + except Exception: + logger.warning("Cannot connect to server, video/screenshot functionality unavailable") + + async def save_execution_outcome( + self, + status: str, + iterations: int, + execution_time: float = 0, + ) -> None: + """Persist task-level execution outcome into metadata.json. + + Should be called **before** ``stop()`` so the data is included in the + finalized recording. The saved dict has the structure:: + + {"status": "success"|"incomplete"|"error", + "iterations": int, + "execution_time": float} + """ + if self._recorder: + await self._recorder.add_metadata("execution_outcome", { + "status": status, + "iterations": iterations, + "execution_time": round(execution_time, 2), + }) + + async def stop(self): + """Stop automatic recording""" + if not self.enabled or not self._is_started: + return + + try: + # stop video recording and save + if self._recording_client: + try: + video_path = None + if self._recorder: + video_path = str(Path(self._recorder.get_trajectory_dir()) / "screen_recording.mp4") + + video_bytes = await self._recording_client.end_recording(dest=video_path) + if video_bytes and video_path: + video_size_mb = len(video_bytes) / (1024 * 1024) + logger.info(f"Video recording saved: {video_path} ({video_size_mb:.2f} MB)") + except Exception as e: + logger.warning(f"Video recording failed to save: {e}") + + # close RecordingClient session, avoid unclosed session warning + try: + if self._recording_client: + await self._recording_client.close() + except Exception as e: + logger.debug(f"Failed to close RecordingClient session: {e}") + + # close screenshot client + if self._screenshot_client: + try: + await self._screenshot_client.close() + except Exception as e: + logger.debug(f"Screenshot client failed to close: {e}") + finally: + self._screenshot_client = None + + # finalize trajectory recording + if self._recorder: + # save final metadata + await self._recorder.add_metadata("end_time", datetime.datetime.now().isoformat()) + await self._recorder.add_metadata("total_steps", self._step_counter) + + # generate summary + await self.generate_summary() + + # finalize recording + await self._recorder.finalize() + + logger.info(f"Recording completed: {self._recorder.get_trajectory_dir()}") + + # Restore original methods for registered LLM clients + for client in self._registered_llm_clients: + client_id = id(client) + if client_id in self._original_methods: + try: + original_method = self._original_methods[client_id] + client.complete = original_method + except Exception as e: + logger.debug(f"Failed to restore original method for LLM client: {e}") + self._registered_llm_clients.clear() + self._original_methods.clear() + + self._is_started = False + self._recorder = None + self._action_recorder = None + + except Exception as e: + logger.error(f"Recording failed to stop: {e}") + + def register_to_llm(self, llm_client): + """Register LLM client: wrap complete() to record tool results (Path B, aligned with AnyTool).""" + if not self.enabled: + return + if id(llm_client) in self._original_methods: + return + original_complete = llm_client.complete + self._original_methods[id(llm_client)] = original_complete + + async def wrapped_complete(self_client, *args, **kwargs): + response = await original_complete(*args, **kwargs) + if response.get("tool_results"): + await self._auto_record_tool_results(response["tool_results"]) + return response + + llm_client.complete = types.MethodType(wrapped_complete, llm_client) + self._registered_llm_clients.append(llm_client) + + @staticmethod + def _infer_backend_from_tool_name(tool_name: str) -> Optional[str]: + """Infer backend from tool name when tool_results lack backend.""" + if not tool_name or not isinstance(tool_name, str): + return None + name = tool_name.strip() + if "__" in name: + name = name.split("__", 1)[-1] + shell_tools = {"shell_agent", "read_file", "write_file", "list_dir", "run_shell"} + if name in shell_tools: + return "shell" + if name in ("gui_agent",) or "gui" in name.lower(): + return "gui" + if "mcp" in name.lower() or ("." in name and "__" not in name): + return "mcp" + if name in ("deep_research_agent", "deep_research"): + return "web" + return None + + async def _auto_record_tool_results(self, tool_results: List[Dict]): + """Record tool execution results from LLM complete() (Path B, aligned with AnyTool).""" + if not self._recorder or not self._is_started: + return + for tool_result in tool_results: + tool_call = tool_result.get("tool_call") + result = tool_result.get("result") + backend = tool_result.get("backend") + server_name = tool_result.get("server_name") + + if not tool_call or not result: + continue + if not backend: + _name = getattr(getattr(tool_call, "function", None), "name", None) or str(tool_result.get("tool_call", "")) + backend = self._infer_backend_from_tool_name(_name) + if not backend: + logger.warning(f"Tool result missing 'backend', cannot infer for '{_name}', skipping") + continue + + result_metadata = result.metadata if hasattr(result, 'metadata') else None + await RecordingManager.record_tool_execution( + tool_name=tool_call.function.name, + backend=backend, + parameters=self._parse_arguments(tool_call.function.arguments), + result=result.content if hasattr(result, 'content') else str(result), + server_name=server_name, + is_success=result.is_success if hasattr(result, 'is_success') else True, + metadata=result_metadata, + ) + + async def _record_mcp(self, tool_call, result, server: str): + tool_name = tool_call.function.name + parameters = self._parse_arguments(tool_call.function.arguments) + + command = f"{server}.{tool_name}" + result_str = str(result.content) if result.is_success else str(result.error) + result_brief = result_str[:200] + "..." if len(result_str) > 200 else result_str + + is_actual_success = result.is_success and not result_str.startswith("ERROR:") + + step_info = await self._recorder.record_step( + backend="mcp", + tool=tool_name, + command=command, + result={ + "status": "success" if is_actual_success else "error", + "output": result_brief, + }, + parameters=parameters, + extra={ + "server": server, + }, + auto_screenshot=self.enable_screenshot + ) + + # Add agent_name to step_info + step_info["agent_name"] = self.agent_name + + async def _record_gui(self, tool_call, result): + tool_name = tool_call.function.name + parameters = self._parse_arguments(tool_call.function.arguments) + + # Extract actual pyautogui command (from action_history) + command = "gui_agent" + if result.is_success and hasattr(result, 'metadata') and result.metadata: + action_history = result.metadata.get("action_history", []) + if action_history: + # Get last successful execution action + for action in reversed(action_history): + planned_action = action.get("planned_action", {}) + execution_result = action.get("execution_result", {}) + + if planned_action.get("action_type") == "PYAUTOGUI_COMMAND": + cmd = planned_action.get("command", "") + if cmd and execution_result.get("status") == "success": + command = cmd + break + elif execution_result.get("status") == "success": + action_type = planned_action.get("action_type", "") + if action_type and action_type not in ["WAIT", "DONE", "FAIL"]: + params = planned_action.get("parameters", {}) + if params: + param_str = ", ".join([f"{k}={v}" for k, v in list(params.items())[:2]]) + command = f"{action_type}({param_str})" + else: + command = action_type + break + + result_str = str(result.content) if result.is_success else str(result.error) + + is_actual_success = result.is_success + if result.is_success: + first_200_chars = result_str[:200] if result_str else "" + critical_failure_patterns = ["Task failed", "CRITICAL ERROR:", "FATAL:"] + has_critical_failure = any(pattern in first_200_chars for pattern in critical_failure_patterns) + is_actual_success = not has_critical_failure + + # Extract intermediate_steps from metadata for embedding in traj.jsonl + extra = {} + if hasattr(result, 'metadata') and result.metadata: + intermediate_steps = result.metadata.get("intermediate_steps") + if intermediate_steps: + extra["intermediate_steps"] = intermediate_steps + + step_info = await self._recorder.record_step( + backend="gui", + tool="gui_agent", + command=command, + result={ + "status": "success" if is_actual_success else "error", + "output": result_str, + }, + parameters=parameters, + auto_screenshot=self.enable_screenshot, + extra=extra if extra else None, + ) + + step_info["agent_name"] = self.agent_name + + async def _record_shell(self, tool_call, result): + tool_name = tool_call.function.name + parameters = self._parse_arguments(tool_call.function.arguments) + + task = parameters.get("task", tool_name) + exit_code = 0 if result.is_success else 1 + + stdout = str(result.content) if result.is_success else "" + stderr = str(result.error) if result.is_error else "" + + command = task + if hasattr(result, 'metadata') and result.metadata: + code_history = result.metadata.get("code_history", []) + if code_history: + # Try to find the last successful execution + found_success = False + for code_info in reversed(code_history): + if code_info.get("status") == "success": + lang = code_info.get("lang", "bash") + code = code_info.get("code", "") + # String format code block: ```lang\ncode\n``` + command = f"```{lang}\n{code}\n```" + found_success = True + break + + # If no successful execution found, use last code block + if not found_success and code_history: + last_code = code_history[-1] + lang = last_code.get("lang", "bash") + code = last_code.get("code", "") + command = f"```{lang}\n{code}\n```" + + stdout_brief = stdout[:200] + "..." if len(stdout) > 200 else stdout + stderr_brief = stderr[:200] + "..." if len(stderr) > 200 else stderr + + is_actual_success = result.is_success + if result.is_success: + first_500_chars = stdout[:500] if stdout else "" + critical_failure_patterns = [ + "Task failed after", + "[TASK_FAILED:", + "EXECUTION ERROR", + "timed out", + ] + has_critical_failure = any(pattern in first_500_chars for pattern in critical_failure_patterns) + is_actual_success = not has_critical_failure + + step_info = await self._recorder.record_step( + backend="shell", + tool="shell_agent", + command=command, + result={ + "status": "success" if is_actual_success else "error", + "exit_code": exit_code, + "stdout": stdout_brief, + "stderr": stderr_brief, + }, + auto_screenshot=self.enable_screenshot + ) + + step_info["agent_name"] = self.agent_name + + async def _record_system(self, tool_call, result): + tool_name = tool_call.function.name + parameters = self._parse_arguments(tool_call.function.arguments) + + command = tool_name + if parameters: + key_params = [] + for key in ['path', 'file', 'directory', 'name', 'provider', 'backend']: + if key in parameters and parameters[key]: + key_params.append(f"{parameters[key]}") + if key_params: + command = f"{tool_name}({', '.join(key_params[:2])})" + + result_str = str(result.content) if result.is_success else str(result.error) + result_brief = result_str[:200] + "..." if len(result_str) > 200 else result_str + + is_actual_success = result.is_success + if result.is_success and result_str: + is_actual_success = not result_str.startswith("ERROR:") + + step_info = await self._recorder.record_step( + backend="system", + tool=tool_name, + command=command, + result={ + "status": "success" if is_actual_success else "error", + "output": result_brief, + }, + auto_screenshot=self.enable_screenshot + ) + + step_info["agent_name"] = self.agent_name + + async def _record_web(self, tool_call, result): + tool_name = tool_call.function.name + parameters = self._parse_arguments(tool_call.function.arguments) + + query = parameters.get("query", "") + command = query if query else "deep_research" + + result_str = str(result.content) if result.is_success else str(result.error) + + is_actual_success = result.is_success + if result.is_success and result_str: + is_actual_success = not result_str.startswith("ERROR:") + + step_info = await self._recorder.record_step( + backend="web", + tool="deep_research_agent", + command=command, + result={ + "status": "success" if is_actual_success else "error", + "output": result_str, # Full output preserved for training/replay + }, + auto_screenshot=self.enable_screenshot + ) + + # Add agent_name to step_info + step_info["agent_name"] = self.agent_name + + async def add_metadata(self, key: str, value: Any): + if self._recorder: + await self._recorder.add_metadata(key, value) + + async def save_plan(self, plan: Dict[str, Any], agent_name: str = "GroundingAgent"): + """ + Save agent plan to recording directory. + This integrates planning information with execution trajectory. + + Args: + plan: The plan data (usually containing task_updates or plan steps) + agent_name: Name of the agent creating the plan + """ + if not self._recorder or not self._is_started: + logger.warning("Cannot save plan: recording not started") + return + + try: + plan_dir = Path(self._recorder.get_trajectory_dir()) / "plans" + plan_dir.mkdir(exist_ok=True) + + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + plan_data = { + "version": timestamp, + "created_at": datetime.datetime.now().isoformat(), + "created_by": agent_name, + "plan": plan + } + + # Save versioned plan + plan_file = plan_dir / f"plan_{timestamp}.json" + with open(plan_file, 'w', encoding='utf-8') as f: + json.dump(plan_data, f, indent=2, ensure_ascii=False) + + # Save current plan (latest) + current_plan_file = plan_dir / "current_plan.json" + with open(current_plan_file, 'w', encoding='utf-8') as f: + json.dump(plan_data, f, indent=2, ensure_ascii=False) + + logger.debug(f"Saved plan to recording: {plan_file.name}") + except Exception as e: + logger.error(f"Failed to save plan: {e}") + + async def log_decision( + self, + agent_name: str, + decision: str, + context: Optional[Dict[str, Any]] = None + ): + """ + Log agent decision with optional context. + This provides insight into agent reasoning process. + + Args: + agent_name: Name of the agent making the decision + decision: Description of the decision + context: Additional context information + """ + if not self._recorder or not self._is_started: + logger.warning("Cannot log decision: recording not started") + return + + try: + traj_dir = Path(self._recorder.get_trajectory_dir()) + log_file = traj_dir / "decisions.log" + + timestamp = datetime.datetime.now().isoformat() + log_entry = f"[{timestamp}] {agent_name}: {decision}" + if context: + log_entry += f"\n Context: {json.dumps(context, ensure_ascii=False)}" + log_entry += "\n" + + with open(log_file, 'a', encoding='utf-8') as f: + f.write(log_entry) + + logger.debug(f"Logged decision from {agent_name}") + except Exception as e: + logger.error(f"Failed to log decision: {e}") + + async def record_agent_action( + self, + agent_name: str, + action_type: str, + input_data: Optional[Dict[str, Any]] = None, + reasoning: Optional[Dict[str, Any]] = None, + output_data: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + related_tool_steps: Optional[list] = None, + correlation_id: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: + """ + Record an agent's action and decision-making process. + + Args: + agent_name: Name of the agent performing the action + action_type: Type of action (plan | execute | evaluate | monitor) + input_data: Input data the agent received (simplified) + reasoning: Agent's reasoning process (structured) + output_data: Agent's output/decision (structured) + metadata: Additional metadata (LLM model, tokens, duration, etc.) + related_tool_steps: List of tool execution step numbers related to this action + correlation_id: Optional correlation ID to link related events + + Returns: + The recorded action info, or None if recording not started + """ + if not self._action_recorder or not self._is_started: + logger.debug("Cannot record agent action: recording not started") + return None + + try: + action_info = await self._action_recorder.record_action( + agent_name=agent_name, + action_type=action_type, + input_data=input_data, + reasoning=reasoning, + output_data=output_data, + metadata=metadata, + related_tool_steps=related_tool_steps, + correlation_id=correlation_id, + ) + + logger.debug(f"Recorded agent action: {agent_name} - {action_type}") + return action_info + + except Exception as e: + logger.error(f"Failed to record agent action: {e}") + return None + + async def generate_summary(self) -> Dict[str, Any]: + """ + Generate a comprehensive summary of the recording session. + """ + if not self._recorder or not self._is_started: + logger.warning("Cannot generate summary: recording not started") + return {} + + try: + from .action_recorder import load_agent_actions, analyze_agent_actions + from .utils import load_trajectory_from_jsonl, analyze_trajectory + + traj_dir = self._recorder.get_trajectory_dir() + + # Load all recorded data + trajectory = load_trajectory_from_jsonl(f"{traj_dir}/traj.jsonl") + agent_actions = load_agent_actions(traj_dir) + + # Analyze data + traj_stats = analyze_trajectory(trajectory) + action_stats = analyze_agent_actions(agent_actions) + + # Build summary + summary = { + "task_id": self.task_id, + "start_time": self._recorder.metadata.get("start_time", ""), + "end_time": self._recorder.metadata.get("end_time", ""), + "trajectory": { + "total_steps": traj_stats.get("total_steps", 0), + "success_count": traj_stats.get("success_count", 0), + "success_rate": traj_stats.get("success_rate", 0), + "by_backend": traj_stats.get("backends", {}), + "by_tool": traj_stats.get("tools", {}), + }, + "agent_actions": { + "total_actions": action_stats.get("total_actions", 0), + "by_agent": action_stats.get("by_agent", {}), + "by_type": action_stats.get("by_type", {}), + } + } + + # Save summary to file + summary_file = Path(traj_dir) / "summary.json" + with open(summary_file, 'w', encoding='utf-8') as f: + json.dump(summary, f, indent=2, ensure_ascii=False) + + logger.info(f"Generated summary: {summary_file}") + return summary + + except Exception as e: + logger.error(f"Failed to generate summary: {e}") + return {} + + async def __aenter__(self): + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.stop() + return False + + @property + def recording_status(self) -> bool: + return self._is_started + + @property + def trajectory_dir(self) -> Optional[str]: + if self._recorder: + return str(self._recorder.get_trajectory_dir()) + return None + + @property + def recording_client(self): + return self._recording_client + + @property + def screenshot_client(self): + return self._screenshot_client + + @property + def step_count(self) -> int: + """Get current step count""" + return self._step_counter + + +__all__ = [ + 'RecordingManager', +] \ No newline at end of file diff --git a/openspace/recording/recorder.py b/openspace/recording/recorder.py new file mode 100644 index 0000000000000000000000000000000000000000..aaee792c52135c82857260273d6b8738554e7714 --- /dev/null +++ b/openspace/recording/recorder.py @@ -0,0 +1,424 @@ +import datetime +import json +from typing import Any, Dict, List, Optional +from pathlib import Path + +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +class TrajectoryRecorder: + def __init__( + self, + task_name: str = "", + log_dir: str = "./logs/trajectories", + enable_screenshot: bool = True, + enable_video: bool = False, + server_url: Optional[str] = None, + ): + """ + Initialize trajectory recorder + + Args: + task_name: task name (optional, will be saved in metadata) + log_dir: log directory + enable_screenshot: whether to save screenshots (through platforms.ScreenshotClient) + enable_video: whether to enable video recording (through platform.RecordingClient) + server_url: local_server address (None = read from config/environment variables) + """ + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + + # Simplify naming rule: add prefix if task_name is provided, otherwise use timestamp only + if task_name: + folder_name = f"{task_name}_{timestamp}" + else: + folder_name = timestamp + + self.trajectory_dir = Path(log_dir) / folder_name + self.trajectory_dir.mkdir(parents=True, exist_ok=True) + + # Create screenshots directory + if enable_screenshot: + self.screenshots_dir = self.trajectory_dir / "screenshots" + self.screenshots_dir.mkdir(exist_ok=True) + else: + self.screenshots_dir = None + + # Config + self.task_name = task_name + self.enable_screenshot = enable_screenshot + self.enable_video = enable_video + self.server_url = server_url + + # Trajectory data + self.steps: List[Dict] = [] + self.step_counter = 0 + + # Metadata + self.metadata = { + "task_name": task_name, + "start_time": datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S"), + "enable_screenshot": enable_screenshot, + "enable_video": enable_video, + } + + # Video recorder (lazy initialization) + self._video_recorder = None + + # Save initial metadata + self._save_metadata() + + async def record_step( + self, + backend: str, + tool: str, + command: str, + result: Optional[Dict[str, Any]] = None, + parameters: Optional[Dict[str, Any]] = None, + screenshot: Optional[bytes] = None, + extra: Optional[Dict[str, Any]] = None, + auto_screenshot: bool = False, + ) -> Dict[str, Any]: + """ + Record one step operation + + Args: + backend: backend type (gui/shell/mcp/web/system) + tool: tool name (name of BaseTool) + command: human-readable core command + result: execution result + parameters: tool parameters + screenshot: screenshot bytes (if provided) + extra: extra information (e.g. server field for MCP) + auto_screenshot: whether to automatically capture screenshot (through platforms.ScreenshotClient) + """ + self.step_counter += 1 + step_num = self.step_counter + timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") + + step_info = { + "step": step_num, + "timestamp": timestamp, + "backend": backend, + } + + # MCP needs to record server (between backend and tool) + if extra and "server" in extra: + step_info["server"] = extra.pop("server") + + # General fields + step_info["tool"] = tool # BaseTool name + step_info["command"] = command # human-readable core command + + # parameters unified write to top level + if parameters: + step_info["parameters"] = parameters + elif extra and "parameters" in extra: + step_info["parameters"] = extra.pop("parameters") + + # Execution result remains original + step_info["result"] = result or {} + + # Other extra information (e.g. coordinates/url) only added when needed + if extra: + step_info.update(extra) + + # Automatic screenshot (if enabled and no screenshot provided) + if auto_screenshot and screenshot is None and self.enable_screenshot: + screenshot = await self._capture_screenshot() + + # Save screenshot + if screenshot and self.enable_screenshot and self.screenshots_dir: + screenshot_filename = f"step_{step_num:03d}.png" + screenshot_path = self.screenshots_dir / screenshot_filename + with open(screenshot_path, "wb") as f: + f.write(screenshot) + step_info["screenshot"] = f"screenshots/{screenshot_filename}" + + # Add to trajectory + self.steps.append(step_info) + + # Save to traj.jsonl in real time + await self._append_to_traj_file(step_info) + + return step_info + + async def _capture_screenshot(self) -> Optional[bytes]: + """Capture screenshot automatically through platforms.ScreenshotClient""" + try: + from openspace.platforms import ScreenshotClient + + # Lazy initialization screenshot client + if not hasattr(self, '_screenshot_client'): + try: + self._screenshot_client = ScreenshotClient(base_url=self.server_url) + except Exception: + self._screenshot_client = None + return None + + if self._screenshot_client is None: + return None + + return await self._screenshot_client.capture() + + except Exception: + return None + + async def save_init_screenshot(self, screenshot: bytes, filename: str = "init.png"): + """Save initial screenshot to screenshots dir and update metadata.""" + if not (self.enable_screenshot and self.screenshots_dir and screenshot): + return + try: + filepath = self.screenshots_dir / filename + with open(filepath, "wb") as f: + f.write(screenshot) + # Update metadata + self.metadata["init_screenshot"] = f"screenshots/{filename}" + self._save_metadata() + except Exception as e: + logger.debug(f"Failed to save initial screenshot: {e}") + + async def _append_to_traj_file(self, step_info: Dict[str, Any]): + """Add step to traj.jsonl file""" + traj_file = self.trajectory_dir / "traj.jsonl" + try: + line = json.dumps(step_info, ensure_ascii=False, default=str) + with open(traj_file, "a", encoding="utf-8") as f: + f.write(line) + f.write("\n") + except Exception as e: + logger.warning(f"Failed to append step {step_info.get('step', '?')} to traj.jsonl: {e}") + + def _save_metadata(self): + """Save metadata to metadata.json""" + metadata_file = self.trajectory_dir / "metadata.json" + with open(metadata_file, "w", encoding="utf-8") as f: + json.dump(self.metadata, f, indent=2, ensure_ascii=False) + + async def start_video_recording(self): + """Start video recording (through platform.RecordingClient)""" + if not self.enable_video: + return + + try: + from openspace.recording.video import VideoRecorder + + video_path = self.trajectory_dir / "recording.mp4" + self._video_recorder = VideoRecorder(str(video_path), base_url=self.server_url) + + success = await self._video_recorder.start() + if not success: + self._video_recorder = None + + except Exception as e: + logger.warning(f"Video recording failed to start: {e}") + self._video_recorder = None + + async def stop_video_recording(self): + """Stop video recording""" + if self._video_recorder: + try: + await self._video_recorder.stop() + except Exception: + pass + finally: + self._video_recorder = None + + async def add_metadata(self, key: str, value: Any): + """Add metadata""" + self.metadata[key] = value + self._save_metadata() + + async def finalize(self): + """Finalize recording, save final information""" + self.metadata["end_time"] = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") + self.metadata["total_steps"] = self.step_counter + + # Backend statistics + backend_counts = {} + for step in self.steps: + backend = step.get("backend", "unknown") + backend_counts[backend] = backend_counts.get(backend, 0) + 1 + self.metadata["backend_counts"] = backend_counts + + self._save_metadata() + + # Close internal ScreenshotClient, avoid unclosed session warning + await self._cleanup_screenshot_client() + + # Stop video recording + await self.stop_video_recording() + + logger.info(f"Recording completed: {self.trajectory_dir} (steps: {self.step_counter})") + + async def _cleanup_screenshot_client(self): + """Cleanup screenshot client resources""" + if hasattr(self, '_screenshot_client') and self._screenshot_client: + try: + await self._screenshot_client.close() + except Exception as e: + logger.debug(f"Failed to close screenshot client: {e}") + finally: + self._screenshot_client = None + + def __del__(self): + """Ensure resources are cleaned up even if finalize() is not called""" + # Note: This is a safety net. Best practice is to call finalize() explicitly. + if hasattr(self, '_video_recorder') and self._video_recorder: + logger.warning( + f"TrajectoryRecorder for {self.trajectory_dir} was not finalized properly. " + "Consider calling finalize() or using async context manager." + ) + + def get_trajectory_dir(self) -> str: + """Get trajectory directory path""" + return str(self.trajectory_dir) + + async def __aenter__(self): + """Async context manager entry""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit - ensures finalize() is called""" + await self.finalize() + return False + +async def record_gui_step( + recorder: TrajectoryRecorder, + command: str, + task_description: str, + result: Dict[str, Any] = None, + screenshot: Optional[bytes] = None, + max_steps: int = 10, + tool: str = "gui_agent", +) -> Dict[str, Any]: + """ + Record GUI step + + Args: + recorder: recorder instance + command: actual executed pyautogui command (e.g. "pyautogui.moveTo(960, 540)") + task_description: task description + result: execution result + screenshot: screenshot + max_steps: maximum number of steps + tool: tool name + """ + parameters = { + "task_description": task_description, + "max_steps": max_steps, + } + + return await recorder.record_step( + backend="gui", + tool=tool, + command=command, + result=result, + parameters=parameters, + screenshot=screenshot, + ) + + +async def record_shell_step( + recorder: TrajectoryRecorder, + command: str, + exit_code: int, + stdout: Optional[str] = None, + stderr: Optional[str] = None, + screenshot: Optional[bytes] = None, + tool: str = "shell_agent", +) -> Dict[str, Any]: + """ + Record Shell step + + Args: + recorder: recorder instance + command: command executed + exit_code: exit code + stdout: standard output (simplified version, not saved completely) + stderr: standard error (simplified version) + screenshot: screenshot + tool: tool name + """ + stdout_brief = stdout[:200] + "..." if stdout and len(stdout) > 200 else stdout + stderr_brief = stderr[:200] + "..." if stderr and len(stderr) > 200 else stderr + + result = { + "status": "success" if exit_code == 0 else "error", + "exit_code": exit_code, + "stdout": stdout_brief, + "stderr": stderr_brief, + } + + return await recorder.record_step( + backend="shell", + tool=tool, + command=command, + result=result, + screenshot=screenshot, + ) + +async def record_mcp_step( + recorder: TrajectoryRecorder, + server: str, + tool_name: str, + parameters: Dict[str, Any], + result: Any, + screenshot: Optional[bytes] = None, +) -> Dict[str, Any]: + """ + Record MCP step + + Args: + recorder: recorder instance + server: MCP server name + tool_name: tool name + parameters: tool parameters + result: execution result + screenshot: screenshot + """ + command = f"{server}.{tool_name}" + + result_str = str(result) + result_brief = result_str[:200] + "..." if len(result_str) > 200 else result_str + + return await recorder.record_step( + backend="mcp", + tool=tool_name, + command=command, + result={"status": "success", "output": result_brief}, + parameters=parameters, + screenshot=screenshot, + extra={ + "server": server, + } + ) + + +async def record_web_step( + recorder: TrajectoryRecorder, + query: str, + result: Dict[str, Any], + screenshot: Optional[bytes] = None, + tool: str = "deep_research_agent", +) -> Dict[str, Any]: + """ + Record Web step (deep research) + + Args: + recorder: recorder instance + query: search query + result: execution result + screenshot: screenshot + tool: tool name + """ + command = query # directly use query as command + + return await recorder.record_step( + backend="web", + tool=tool, + command=command, + result=result, + screenshot=screenshot, + ) \ No newline at end of file diff --git a/openspace/recording/utils.py b/openspace/recording/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8da08a68426ad3f3020078462a9f0bd2504863f4 --- /dev/null +++ b/openspace/recording/utils.py @@ -0,0 +1,386 @@ +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +def load_trajectory_from_jsonl(jsonl_path: str) -> List[Dict[str, Any]]: + trajectory = [] + + # Check if file exists first + if not os.path.exists(jsonl_path): + logger.debug(f"No trajectory file found at {jsonl_path} (this is normal for knowledge-only tasks)") + return [] + + try: + with open(jsonl_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + step = json.loads(line) + trajectory.append(step) + + logger.info(f"Loaded {len(trajectory)} steps from {jsonl_path}") + return trajectory + + except Exception as e: + logger.error(f"Failed to load trajectory from {jsonl_path}: {e}") + return [] + + +def load_metadata(trajectory_dir: str) -> Optional[Dict[str, Any]]: + metadata_path = os.path.join(trajectory_dir, "metadata.json") + + try: + with open(metadata_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + return metadata + except Exception as e: + logger.warning(f"Failed to load metadata from {metadata_path}: {e}") + return None + + +def format_trajectory_for_export( + trajectory: List[Dict[str, Any]], + format_type: str = "compact" +) -> str: + if format_type == "compact": + return _format_compact(trajectory) + elif format_type == "detailed": + return _format_detailed(trajectory) + elif format_type == "markdown": + return _format_markdown(trajectory) + else: + raise ValueError(f"Unknown format type: {format_type}") + + +def _format_compact(trajectory: List[Dict[str, Any]]) -> str: + """Compact format: one line per step.""" + lines = [] + for step in trajectory: + step_num = step.get("step", "?") + backend = step.get("backend", "?") + server = step.get("server") + tool = step.get("tool", "?") + result_status = "success" if step.get("result", {}).get("status") == "success" else "error" + + # Include server name for MCP backend + backend_str = f"{backend}@{server}" if server else backend + lines.append(f"Step {step_num}: [{backend_str}] {tool} -> {result_status}") + + return "\n".join(lines) + + +def _format_detailed(trajectory: List[Dict[str, Any]]) -> str: + """Detailed format: multiple lines per step with parameters.""" + lines = [] + for step in trajectory: + step_num = step.get("step", "?") + timestamp = step.get("timestamp", "?") + backend = step.get("backend", "?") + server = step.get("server") + tool = step.get("tool", "?") + command = step.get("command", "?") + parameters = step.get("parameters", {}) + result = step.get("result", {}) + + from openspace.utils.display import Box, BoxStyle + + box = Box(width=66, style=BoxStyle.ROUNDED, color='bl') + lines.append("") + lines.append(box.top_line(0)) + lines.append(box.text_line(f"Step {step_num} ({timestamp})", align='center', indent=0, text_color='c')) + lines.append(box.separator_line(0)) + lines.append(box.text_line(f"Backend: {backend}", indent=0)) + if server: + lines.append(box.text_line(f"Server: {server}", indent=0)) + lines.append(box.text_line(f"Tool: {tool}", indent=0)) + lines.append(box.text_line(f"Command: {command}", indent=0)) + lines.append(box.separator_line(0)) + # Parameters and result can be multi-line + param_str = json.dumps(parameters, indent=2) + for param_line in param_str.split('\n'): + lines.append(box.text_line(param_line, indent=0)) + lines.append(box.separator_line(0)) + result_str = json.dumps(result, indent=2) + for result_line in result_str.split('\n'): + lines.append(box.text_line(result_line, indent=0)) + lines.append(box.bottom_line(0)) + + return "\n".join(lines) + + +def _format_markdown(trajectory: List[Dict[str, Any]]) -> str: + """Markdown format: table format.""" + lines = [ + "# Trajectory", + "", + "| Step | Backend | Server | Tool | Status | Screenshot |", + "|------|---------|--------|------|--------|------------|" + ] + + for step in trajectory: + step_num = step.get("step", "?") + backend = step.get("backend", "?") + server = step.get("server", "-") + tool = step.get("tool", "?") + result_status = "✓" if step.get("result", {}).get("status") == "success" else "✗" + screenshot = "📷" if step.get("screenshot") else "" + + lines.append(f"| {step_num} | {backend} | {server} | {tool} | {result_status} | {screenshot} |") + + return "\n".join(lines) + + +def analyze_trajectory(trajectory: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Analyze trajectory and return statistics. + """ + if not trajectory: + return { + "total_steps": 0, + "success_rate": 0.0, + "backends": {}, + "action_types": {} + } + + total_steps = len(trajectory) + success_count = 0 + backends = {} + action_types = {} + + for step in trajectory: + # Count successes + if step.get("result", {}).get("status") == "success": + success_count += 1 + + # Count backends + backend = step.get("backend", "unknown") + backends[backend] = backends.get(backend, 0) + 1 + + # Count tool types + tool = step.get("tool", "unknown") + action_types[tool] = action_types.get(tool, 0) + 1 + + return { + "total_steps": total_steps, + "success_count": success_count, + "success_rate": success_count / total_steps if total_steps > 0 else 0.0, + "backends": backends, + "tools": action_types + } + + +def load_recording_session(recording_dir: str) -> Dict[str, Any]: + """ + Load complete recording session including trajectory, metadata, plans, and snapshots. + + Args: + recording_dir: Path to recording directory + + Returns: + Dictionary containing all session data: + { + "trajectory": List[Dict], + "metadata": Dict, + "plans": List[Dict], + "decisions": List[str], + "statistics": Dict + } + """ + recording_path = Path(recording_dir) + + if not recording_path.exists(): + logger.error(f"Recording directory not found: {recording_dir}") + return {} + + session = { + "trajectory": [], + "metadata": None, + "plans": [], + "decisions": [], + "statistics": {} + } + + # Load trajectory + traj_file = recording_path / "traj.jsonl" + if traj_file.exists(): + session["trajectory"] = load_trajectory_from_jsonl(str(traj_file)) + session["statistics"] = analyze_trajectory(session["trajectory"]) + + # Load metadata + metadata_file = recording_path / "metadata.json" + if metadata_file.exists(): + session["metadata"] = load_metadata(str(recording_path)) + + # Load plans + plans_dir = recording_path / "plans" + if plans_dir.exists(): + for plan_file in sorted(plans_dir.glob("plan_*.json")): + try: + with open(plan_file, 'r', encoding='utf-8') as f: + session["plans"].append(json.load(f)) + except Exception as e: + logger.warning(f"Failed to load plan {plan_file}: {e}") + + # Load decisions log + decisions_file = recording_path / "decisions.log" + if decisions_file.exists(): + try: + with open(decisions_file, 'r', encoding='utf-8') as f: + session["decisions"] = f.readlines() + except Exception as e: + logger.warning(f"Failed to load decisions: {e}") + + return session + + +def filter_trajectory( + trajectory: List[Dict[str, Any]], + backend: Optional[str] = None, + tool: Optional[str] = None, + status: Optional[str] = None, + time_range: Optional[Tuple[str, str]] = None +) -> List[Dict[str, Any]]: + filtered = trajectory + + if backend: + filtered = [s for s in filtered if s.get("backend") == backend] + + if tool: + filtered = [s for s in filtered if s.get("tool") == tool] + + if status: + filtered = [s for s in filtered if s.get("result", {}).get("status") == status] + + if time_range: + start_time, end_time = time_range + filtered = [ + s for s in filtered + if start_time <= s.get("timestamp", "") <= end_time + ] + + return filtered + + +def extract_errors(trajectory: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + return [ + step for step in trajectory + if step.get("result", {}).get("status") == "error" + ] + + +def generate_summary_report(recording_dir: str, output_file: Optional[str] = None) -> str: + session = load_recording_session(recording_dir) + + if not session: + return "Error: Could not load recording session" + + lines = [] + lines.append("# Recording Session Summary\n") + + # Metadata section + if session["metadata"]: + lines.append("## Metadata") + metadata = session["metadata"] + lines.append(f"- **Task ID**: {metadata.get('task_id', 'N/A')}") + lines.append(f"- **Start Time**: {metadata.get('start_time', 'N/A')}") + lines.append(f"- **End Time**: {metadata.get('end_time', 'N/A')}") + lines.append(f"- **Total Steps**: {metadata.get('total_steps', 0)}") + lines.append(f"- **Backends**: {', '.join(metadata.get('backends', []))}") + lines.append("") + + # Statistics section + if session["statistics"]: + lines.append("## Statistics") + stats = session["statistics"] + lines.append(f"- **Total Steps**: {stats.get('total_steps', 0)}") + lines.append(f"- **Success Count**: {stats.get('success_count', 0)}") + lines.append(f"- **Success Rate**: {stats.get('success_rate', 0):.2%}") + lines.append("") + + lines.append("### Backend Distribution") + for backend, count in stats.get('backends', {}).items(): + lines.append(f"- {backend}: {count}") + lines.append("") + + lines.append("### Tool Distribution") + for tool, count in sorted(stats.get('tools', {}).items(), key=lambda x: x[1], reverse=True): + lines.append(f"- {tool}: {count}") + lines.append("") + + # Plans section + if session["plans"]: + lines.append(f"## Plans ({len(session['plans'])} total)") + for i, plan in enumerate(session["plans"], 1): + lines.append(f"### Plan {i}") + lines.append(f"- Created: {plan.get('created_at', 'N/A')}") + lines.append(f"- Created by: {plan.get('created_by', 'N/A')}") + plan_data = plan.get('plan', {}) + if 'task_updates' in plan_data: + lines.append(f"- Tasks: {len(plan_data['task_updates'])}") + lines.append("") + + # Errors section + if session["trajectory"]: + errors = extract_errors(session["trajectory"]) + if errors: + lines.append(f"## Errors ({len(errors)} total)") + for error in errors[:5]: # Show first 5 errors + lines.append(f"- Step {error.get('step')}: {error.get('backend')} - {error.get('tool')}") + error_msg = error.get('result', {}).get('output', 'No error message') + lines.append(f" ```\n {error_msg[:200]}\n ```") + if len(errors) > 5: + lines.append(f" ... and {len(errors) - 5} more errors") + lines.append("") + + # Decisions section + if session["decisions"]: + lines.append(f"## Decisions ({len(session['decisions'])} total)") + for decision in session["decisions"][:10]: # Show first 10 decisions + lines.append(f" {decision.strip()}") + if len(session["decisions"]) > 10: + lines.append(f" ... and {len(session['decisions']) - 10} more decisions") + lines.append("") + + report = "\n".join(lines) + + # Save to file if requested + if output_file: + try: + with open(output_file, 'w', encoding='utf-8') as f: + f.write(report) + logger.info(f"Report saved to {output_file}") + except Exception as e: + logger.error(f"Failed to save report: {e}") + + return report + + +def compare_recordings(recording_dir1: str, recording_dir2: str) -> Dict[str, Any]: + session1 = load_recording_session(recording_dir1) + session2 = load_recording_session(recording_dir2) + + stats1 = session1.get("statistics", {}) + stats2 = session2.get("statistics", {}) + + return { + "session1": { + "path": recording_dir1, + "total_steps": stats1.get("total_steps", 0), + "success_rate": stats1.get("success_rate", 0), + "backends": stats1.get("backends", {}) + }, + "session2": { + "path": recording_dir2, + "total_steps": stats2.get("total_steps", 0), + "success_rate": stats2.get("success_rate", 0), + "backends": stats2.get("backends", {}) + }, + "differences": { + "step_diff": stats2.get("total_steps", 0) - stats1.get("total_steps", 0), + "success_rate_diff": stats2.get("success_rate", 0) - stats1.get("success_rate", 0) + } + } \ No newline at end of file diff --git a/openspace/recording/video.py b/openspace/recording/video.py new file mode 100644 index 0000000000000000000000000000000000000000..65e7c1cd749e666cb475a4aeccec89b69b78b9f0 --- /dev/null +++ b/openspace/recording/video.py @@ -0,0 +1,88 @@ +""" +Video Recorder + +Communicates with local_server through platforms.RecordingClient +Supports local and remote recording (through configuration LOCAL_SERVER_URL) +""" + +from pathlib import Path +from typing import Optional + +from openspace.utils.logging import Logger +from openspace.platforms import RecordingClient + +logger = Logger.get_logger(__name__) + + +class VideoRecorder: + def __init__( + self, + output_path: str, + base_url: Optional[str] = None, + ): + """ + Initialize video recorder + + Args: + output_path: output video path + base_url: local_server address (None = read from config/environment variables) + """ + self.output_path = Path(output_path) + self.base_url = base_url + self.is_recording = False + self._client: Optional[RecordingClient] = None + + async def start(self): + """Start recording screen""" + if self.is_recording: + return False + + try: + if self._client is None: + self._client = RecordingClient(base_url=self.base_url) + + success = await self._client.start_recording() + + if success: + self.is_recording = True + logger.info(f"Video recording started") + return True + else: + logger.warning("Video recording failed to start") + return False + + except Exception as e: + logger.warning(f"Video recording failed to start: {e}") + return False + + async def stop(self): + """Stop recording screen and save to local""" + if not self.is_recording: + return False + + try: + if self._client: + video_bytes = await self._client.end_recording(dest=str(self.output_path)) + + if video_bytes: + video_size_mb = len(video_bytes) / (1024 * 1024) + self.is_recording = False + logger.info(f"Video recording stopped ({video_size_mb:.2f} MB)") + return True + else: + logger.warning("Video recording failed to stop") + return False + + except Exception as e: + logger.warning(f"Video recording failed to stop: {e}") + return False + finally: + if self._client: + try: + await self._client.close() + except Exception: + pass + self._client = None + + +__all__ = ['VideoRecorder'] \ No newline at end of file diff --git a/openspace/recording/viewer.py b/openspace/recording/viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..70b93c3693192c564316dd55fc9cd7126aea3468 --- /dev/null +++ b/openspace/recording/viewer.py @@ -0,0 +1,334 @@ +""" +Recording Viewer +Convenient tools for viewing and analyzing recording sessions. +""" + +import json +from pathlib import Path +from typing import Optional, Dict, Any, List + +from openspace.utils.logging import Logger +from .utils import load_recording_session, generate_summary_report +from .action_recorder import load_agent_actions, analyze_agent_actions, format_agent_actions + +logger = Logger.get_logger(__name__) + + +class RecordingViewer: + """ + Viewer for analyzing recording sessions. + + Provides convenient methods to: + - Load and display recordings + - Analyze agent behaviors + - Generate reports + """ + + def __init__(self, recording_dir: str): + """ + Initialize viewer with a recording directory. + + Args: + recording_dir: Path to recording directory + """ + self.recording_dir = Path(recording_dir) + + if not self.recording_dir.exists(): + raise ValueError(f"Recording directory not found: {recording_dir}") + + # Load session data + self.session = load_recording_session(str(self.recording_dir)) + + logger.info(f"Loaded recording from {recording_dir}") + + def show_summary(self) -> str: + """ + Display a summary of the recording. + + Returns: + Formatted summary string + """ + if not self.session.get("metadata"): + return "No metadata available" + + metadata = self.session["metadata"] + stats = self.session.get("statistics", {}) + + lines = [] + lines.append("=" * 70) + lines.append("RECORDING SUMMARY") + lines.append("=" * 70) + lines.append(f"Task ID: {metadata.get('task_id', 'N/A')}") + lines.append(f"Start: {metadata.get('start_time', 'N/A')}") + lines.append(f"End: {metadata.get('end_time', 'N/A')}") + lines.append(f"Total Steps: {metadata.get('total_steps', 0)}") + lines.append("") + + lines.append("Statistics:") + lines.append(f" - Success Rate: {stats.get('success_rate', 0):.2%}") + lines.append(f" - Success Count: {stats.get('success_count', 0)}/{stats.get('total_steps', 0)}") + lines.append("") + + if stats.get("backends"): + lines.append("Backend Usage:") + for backend, count in sorted(stats["backends"].items(), key=lambda x: x[1], reverse=True): + lines.append(f" - {backend}: {count}") + + lines.append("=" * 70) + + return "\n".join(lines) + + def show_agent_actions(self, format_type: str = "compact", agent_name: Optional[str] = None) -> str: + actions = load_agent_actions(str(self.recording_dir)) + + if agent_name: + actions = [a for a in actions if a.get("agent_name") == agent_name] + + if not actions: + return f"No agent actions found{' for ' + agent_name if agent_name else ''}" + + # Add header + header = f"\nAGENT ACTIONS ({len(actions)} total)" + if agent_name: + header += f" - {agent_name}" + header += "\n" + "=" * 70 + + # Format actions + formatted = format_agent_actions(actions, format_type) + + return header + "\n" + formatted + + def analyze_agents(self) -> str: + actions = load_agent_actions(str(self.recording_dir)) + stats = analyze_agent_actions(actions) + + lines = [] + lines.append("\nAGENT ANALYSIS") + lines.append("=" * 70) + lines.append(f"Total Actions: {stats.get('total_actions', 0)}") + lines.append("") + + lines.append("By Agent:") + for agent, count in sorted(stats.get('by_agent', {}).items(), key=lambda x: x[1], reverse=True): + percentage = (count / stats['total_actions'] * 100) if stats['total_actions'] > 0 else 0 + lines.append(f" - {agent}: {count} ({percentage:.1f}%)") + lines.append("") + + lines.append("By Action Type:") + for action_type, count in sorted(stats.get('by_type', {}).items(), key=lambda x: x[1], reverse=True): + percentage = (count / stats['total_actions'] * 100) if stats['total_actions'] > 0 else 0 + lines.append(f" - {action_type}: {count} ({percentage:.1f}%)") + + return "\n".join(lines) + + def generate_full_report(self, output_file: Optional[str] = None) -> str: + return generate_summary_report(str(self.recording_dir), output_file) + + def export_to_json(self, output_file: str): + with open(output_file, 'w', encoding='utf-8') as f: + json.dump(self.session, f, indent=2, ensure_ascii=False) + + logger.info(f"Exported session to {output_file}") + + def show_timeline(self, max_events: int = 50) -> str: + # Load all events + actions = load_agent_actions(str(self.recording_dir)) + trajectory = self.session.get("trajectory", []) + + # Combine all events with unified format + timeline = [] + + # Add agent actions + for action in actions: + timeline.append({ + "timestamp": action.get("timestamp", ""), + "type": "agent_action", + "agent_name": action.get("agent_name", ""), + "agent_type": action.get("agent_type", "unknown"), + "action_type": action.get("action_type", ""), + "step": action.get("step"), + "correlation_id": action.get("correlation_id", ""), + "description": f"[{action.get('agent_type', '?').upper()}] {action.get('action_type', '?')}", + "related_tool_steps": action.get("related_tool_steps", []), + }) + + # Add tool executions + for traj_step in trajectory: + timeline.append({ + "timestamp": traj_step.get("timestamp", ""), + "type": "tool_execution", + "backend": traj_step.get("backend", ""), + "tool": traj_step.get("tool", ""), + "step": traj_step.get("step"), + "agent_name": traj_step.get("agent_name", ""), + "description": f"[TOOL:{traj_step.get('backend', '?').upper()}] {traj_step.get('tool', '?')}", + "status": traj_step.get("result", {}).get("status", ""), + }) + + # Sort by timestamp + timeline.sort(key=lambda x: x.get("timestamp", "")) + + # Format output + lines = [] + lines.append("\nUNIFIED TIMELINE") + lines.append("=" * 100) + lines.append(f"Total events: {len(timeline)} (showing first {max_events})") + lines.append("") + + for i, item in enumerate(timeline[:max_events]): + timestamp = item.get("timestamp", "N/A") + time_str = timestamp.split("T")[1][:8] if "T" in timestamp else timestamp[-8:] + + # Format line with type indicator + type_marker = { + "agent_action": "🤖", + "tool_execution": "🔧" + }.get(item.get("type"), "•") + + desc = item.get("description", "") + agent = item.get("agent_name", "") + agent_type = item.get("agent_type", "") + + line = f"{time_str} {type_marker} {desc}" + + # Add agent info if available + if agent and agent_type: + line += f" (by {agent}/{agent_type})" + elif agent: + line += f" (by {agent})" + + lines.append(line) + + # Show correlations + correlations = [] + if item.get("related_tool_steps"): + correlations.append(f"→ tool steps: {item['related_tool_steps']}") + if item.get("related_action_step"): + correlations.append(f"→ action step: {item['related_action_step']}") + + if correlations: + for corr in correlations: + lines.append(f" {corr}") + + if len(timeline) > max_events: + lines.append(f"\n... and {len(timeline) - max_events} more events") + + return "\n".join(lines) + + def show_agent_flow(self, agent_name: Optional[str] = None) -> str: + """ + Show the flow of a specific agent's actions and related events. + """ + actions = load_agent_actions(str(self.recording_dir)) + + if agent_name: + actions = [a for a in actions if a.get("agent_name") == agent_name] + + lines = [] + lines.append(f"\nAGENT FLOW{' - ' + agent_name if agent_name else ''}") + lines.append("=" * 100) + + # Sort by timestamp + actions.sort(key=lambda x: x.get("timestamp", "")) + + for action in actions: + timestamp = action.get("timestamp", "N/A").split("T")[1][:8] if "T" in action.get("timestamp", "") else "N/A" + + agent_type = action.get("agent_type", "?").upper() + action_type = action.get("action_type", "?") + step = action.get("step", "?") + lines.append(f"{timestamp} [{agent_type}] Action #{step}: {action_type}") + + # Show reasoning if available + if action.get("reasoning"): + thought = action["reasoning"].get("thought", "") + if thought: + lines.append(f" 💭 {thought[:80]}...") + + # Show output + if action.get("output"): + output = action["output"] + if isinstance(output, dict): + for key in ["message", "status", "evaluation"]: + if key in output: + lines.append(f" 📤 {key}: {str(output[key])[:60]}") + + lines.append("") + + return "\n".join(lines) + + +def view_recording(recording_dir: str): + """ + Quick interactive viewer for a recording. + """ + try: + viewer = RecordingViewer(recording_dir) + + print(viewer.show_summary()) + print("\n") + + print(viewer.analyze_agents()) + print("\n") + + print("Agent Actions (compact):") + print(viewer.show_agent_actions(format_type="compact")) + + except Exception as e: + logger.error(f"Failed to view recording: {e}") + print(f"Error: {e}") + + +def compare_recordings(recording_dir1: str, recording_dir2: str) -> str: + """ + Compare two recordings side by side. + """ + try: + viewer1 = RecordingViewer(recording_dir1) + viewer2 = RecordingViewer(recording_dir2) + + lines = [] + lines.append("=" * 70) + lines.append("RECORDING COMPARISON") + lines.append("=" * 70) + lines.append("") + + # Compare metadata + meta1 = viewer1.session.get("metadata", {}) + meta2 = viewer2.session.get("metadata", {}) + + lines.append("Recording 1:") + lines.append(f" Task: {meta1.get('task_id', 'N/A')}") + lines.append(f" Steps: {meta1.get('total_steps', 0)}") + lines.append("") + + lines.append("Recording 2:") + lines.append(f" Task: {meta2.get('task_id', 'N/A')}") + lines.append(f" Steps: {meta2.get('total_steps', 0)}") + lines.append("") + + # Compare statistics + stats1 = viewer1.session.get("statistics", {}) + stats2 = viewer2.session.get("statistics", {}) + + lines.append("Differences:") + lines.append(f" Steps: {meta2.get('total_steps', 0) - meta1.get('total_steps', 0):+d}") + lines.append(f" Success Rate: {stats2.get('success_rate', 0) - stats1.get('success_rate', 0):+.2%}") + + return "\n".join(lines) + + except Exception as e: + logger.error(f"Failed to compare recordings: {e}") + return f"Error: {e}" + + +# CLI interface +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print("Usage: python -m openspace.recording.viewer ") + sys.exit(1) + + recording_dir = sys.argv[1] + view_recording(recording_dir) \ No newline at end of file diff --git a/openspace/skill_engine/__init__.py b/openspace/skill_engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..60f022bc6226111dee9b1715eb73abf20873d200 --- /dev/null +++ b/openspace/skill_engine/__init__.py @@ -0,0 +1,54 @@ +from importlib import import_module + +__all__ = [ + 'SkillRegistry', + 'SkillMeta', + 'SKILL_ID_FILENAME', + 'write_skill_id', + 'EvolutionSuggestion', + 'EvolutionType', + 'ExecutionAnalysis', + 'SkillCategory', + 'SkillJudgment', + 'SkillOrigin', + 'SkillLineage', + 'SkillRecord', + 'SkillVisibility', + 'SkillStore', + 'ExecutionAnalyzer', + 'SkillEvolver', + 'EvolutionTrigger', + 'EvolutionContext', +] + +_EXPORTS = { + 'SkillRegistry': ('.registry', 'SkillRegistry'), + 'SkillMeta': ('.registry', 'SkillMeta'), + 'SKILL_ID_FILENAME': ('.registry', 'SKILL_ID_FILENAME'), + 'write_skill_id': ('.registry', 'write_skill_id'), + 'EvolutionSuggestion': ('.types', 'EvolutionSuggestion'), + 'EvolutionType': ('.types', 'EvolutionType'), + 'ExecutionAnalysis': ('.types', 'ExecutionAnalysis'), + 'SkillCategory': ('.types', 'SkillCategory'), + 'SkillJudgment': ('.types', 'SkillJudgment'), + 'SkillOrigin': ('.types', 'SkillOrigin'), + 'SkillLineage': ('.types', 'SkillLineage'), + 'SkillRecord': ('.types', 'SkillRecord'), + 'SkillVisibility': ('.types', 'SkillVisibility'), + 'SkillStore': ('.store', 'SkillStore'), + 'ExecutionAnalyzer': ('.analyzer', 'ExecutionAnalyzer'), + 'SkillEvolver': ('.evolver', 'SkillEvolver'), + 'EvolutionTrigger': ('.evolver', 'EvolutionTrigger'), + 'EvolutionContext': ('.evolver', 'EvolutionContext'), +} + + +def __getattr__(name: str): + try: + module_name, attr_name = _EXPORTS[name] + except KeyError as exc: + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') from exc + module = import_module(module_name, __name__) + value = getattr(module, attr_name) + globals()[name] = value + return value diff --git a/openspace/skill_engine/analyzer.py b/openspace/skill_engine/analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..fa5b74fabd57c97bbb5755f7799265e1698b36e2 --- /dev/null +++ b/openspace/skill_engine/analyzer.py @@ -0,0 +1,933 @@ +"""ExecutionAnalyzer — post-execution analysis and skill quality tracking. + +Responsibilities: + 1. After each task execution, load recording artifacts. + 2. Build an LLM prompt and obtain an ``ExecutionAnalysis``. + 3. Persist the analysis and update ``SkillRecord`` counters via ``SkillStore``. + 4. Surface evolution candidates for downstream processing. + +Integration: + Instantiated once during ``OpenSpace.initialize()``. + ``analyze_execution()`` is called in the ``finally`` block of ``OpenSpace.execute()``. +""" + +from __future__ import annotations + +import copy +import json +import re +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +from openspace.grounding.core.tool import BaseTool + +from .types import ( + EvolutionSuggestion, + EvolutionType, + ExecutionAnalysis, + SkillCategory, + SkillJudgment, +) +from .store import SkillStore +from openspace.prompts import SkillEnginePrompts +from openspace.utils.logging import Logger +from .conversation_formatter import format_conversations + +if TYPE_CHECKING: + from openspace.llm import LLMClient + from openspace.grounding.core.quality import ToolQualityManager + from .registry import SkillRegistry + +logger = Logger.get_logger(__name__) + + +# Maximum characters of conversation log to include in the analysis prompt. +_MAX_CONVERSATION_CHARS = 80_000 + +# Per-section truncation limits +_TOOL_ERROR_MAX_CHARS = 1000 # Errors: keep key info, no full stack traces +_TOOL_SUCCESS_MAX_CHARS = 800 # Success results +_TOOL_ARGS_MAX_CHARS = 500 # Tool call arguments +_TOOL_SUMMARY_MAX_CHARS = 1500 # Embedded execution summaries from inner agents + +# Skill & analysis-agent constants +_SKILL_CONTENT_MAX_CHARS = 8000 # Max chars per skill SKILL.md in prompt +_ANALYSIS_MAX_ITERATIONS = 5 # Max tool-calling rounds for analysis agent + + +def _correct_skill_ids( + ids: List[str], known_ids: set, +) -> List[str]: + """Best-effort correction of LLM-hallucinated skill IDs. + + LLMs frequently garble the hex suffix of generated IDs (e.g. swap + ``cb`` → ``bc``). For each *id* not in *known_ids*, find the closest + known ID sharing the same name prefix (before ``__``) and within + edit-distance ≤ 3. If a unique match is found, silently replace it. + """ + if not known_ids: + return ids + + corrected: List[str] = [] + for raw_id in ids: + if raw_id in known_ids: + corrected.append(raw_id) + continue + + # Extract name prefix (everything before the first "__") + prefix = raw_id.split("__")[0] if "__" in raw_id else "" + + # Candidates: known IDs sharing the same name prefix + candidates = [ + k for k in known_ids + if prefix and k.split("__")[0] == prefix + ] + + best, best_dist = None, 4 # threshold: edit distance ≤ 3 + for cand in candidates: + d = _edit_distance(raw_id, cand) + if d < best_dist: + best, best_dist = cand, d + + if best is not None: + logger.info( + f"Corrected LLM skill ID: {raw_id!r} → {best!r} " + f"(edit_distance={best_dist})" + ) + corrected.append(best) + else: + corrected.append(raw_id) # keep as-is; evolver will warn + + return corrected + + +def _edit_distance(a: str, b: str) -> int: + """Levenshtein edit distance (compact DP, O(min(m,n)) space).""" + if len(a) < len(b): + a, b = b, a + if not b: + return len(a) + prev = list(range(len(b) + 1)) + for i, ca in enumerate(a, 1): + curr = [i] + [0] * len(b) + for j, cb in enumerate(b, 1): + curr[j] = min( + prev[j] + 1, + curr[j - 1] + 1, + prev[j - 1] + (0 if ca == cb else 1), + ) + prev = curr + return prev[-1] + + +class ExecutionAnalyzer: + """Analyzes task execution results and tracks skill quality. + + Args: + store: Persistence layer for skill records and analyses. + llm_client: LLM client used for the analysis call. + model: Override model for analysis. If None, uses ``llm_client``'s default model. + enabled: Set False to skip analysis entirely. + """ + + def __init__( + self, + store: SkillStore, + llm_client: "LLMClient", + model: Optional[str] = None, + enabled: bool = True, + skill_registry: Optional["SkillRegistry"] = None, + quality_manager: Optional["ToolQualityManager"] = None, + ) -> None: + self._store = store + self._llm_client = llm_client + self._model = model + self.enabled = enabled + self._skill_registry = skill_registry + self._quality_manager = quality_manager + + async def analyze_execution( + self, + task_id: str, + recording_dir: str, + execution_result: Dict[str, Any], + available_tools: Optional[List[BaseTool]] = None, + ) -> Optional[ExecutionAnalysis]: + """Run LLM analysis on a completed task and persist the result. + + Args: + task_id: Unique identifier for the task. + recording_dir: Path to the recording directory containing metadata.json, + conversations.jsonl, etc. + execution_result: The return value of ``OpenSpace.execute()`` — contains status, + iterations, skills_used, etc. + available_tools: BaseTool instances from the execution (shell tools, + MCP tools, etc.). Passed through to the analysis agent loop so + it can reproduce errors or verify results when trace data is + ambiguous. A lightweight ``run_shell`` tool is always appended. + """ + if not self.enabled: + return None + + rec_path = Path(recording_dir) + if not rec_path.is_dir(): + logger.warning( + f"Recording directory not found, skipping analysis: {recording_dir}" + ) + return None + + # Check for duplicate — one analysis per task + existing = self._store.load_analyses_for_task(task_id) + if existing is not None: + logger.debug(f"Analysis already exists for task {task_id}, skipping") + return existing + + try: + from gdpval_bench.token_tracker import set_call_source, reset_call_source + _src_tok = set_call_source("analyzer") + except ImportError: + _src_tok = None + + try: + # 1. Load recording artifacts + context = self._load_recording_context(rec_path, execution_result) + if context is None: + return None + + # 2. Build prompt + prompt = self._build_analysis_prompt(context) + + # 3. Run analysis (agent loop with optional tool use) + raw_json = await self._run_analysis_loop( + prompt, available_tools=available_tools or [], + ) + if raw_json is None: + return None + + # 4. Parse into ExecutionAnalysis + analysis = self._parse_analysis(task_id, raw_json, context) + if analysis is None: + return None + + # 5. Persist + await self._store.record_analysis(analysis) + evo_types = [s.evolution_type.value for s in analysis.evolution_suggestions] + logger.info( + f"Execution analysis saved for task {task_id}: " + f"completed={analysis.task_completed}, " + f"skills_judged={len(analysis.skill_judgments)}, " + f"evolution_suggestions={evo_types or 'none'}" + ) + + # 6. Feed tool issues to quality manager (if available). + # Build tool-status map from raw traj records for dedup. + traj_tool_status = self._build_tool_status_map( + context.get("traj_records", []) + ) + await self._record_tool_quality_feedback(analysis, traj_tool_status) + + return analysis + + except Exception as e: + logger.error(f"Execution analysis failed for task {task_id}: {e}") + return None + finally: + if _src_tok is not None: + reset_call_source(_src_tok) + + async def get_evolution_candidates( + self, limit: int = 20 + ) -> List[ExecutionAnalysis]: + """Return recent analyses flagged as evolution candidates.""" + return self._store.load_evolution_candidates(limit=limit) + + @staticmethod + def _build_tool_status_map( + traj_records: List[Dict[str, Any]], + ) -> Dict[str, bool]: + """Build {tool_key: has_any_success} from raw traj records. + + Used for deduplication: if all calls for a tool already failed + (rule-based caught them), there's no need for the LLM to add + another failure record. + """ + tool_has_success: Dict[str, bool] = {} + for entry in traj_records: + backend = entry.get("backend", "unknown") + tool = entry.get("tool", "unknown") + server = entry.get("server", "") + status = (entry.get("result") or {}).get("status", "unknown") + + # Build canonical key matching the prompt format + key = f"{backend}:{server}:{tool}" if server else f"{backend}:{tool}" + + if key not in tool_has_success: + tool_has_success[key] = False + if status != "error": + tool_has_success[key] = True + return tool_has_success + + async def _record_tool_quality_feedback( + self, + analysis: ExecutionAnalysis, + traj_tool_status: Dict[str, bool], + ) -> None: + """Feed LLM-identified tool issues to the ToolQualityManager. + + **Deduplication**: The rule-based system already records each tool + call as success/failure. The LLM adds value only when it catches + *semantic* failures the rule-based system missed. + + ``traj_tool_status`` maps ``tool_key → has_any_success_call``. + If all calls already failed → skip (rule-based caught it). + If any call was "success" but LLM says problematic → inject correction. + If tool not in traj → trust LLM (internal/system call). + """ + if not self._quality_manager or not analysis.tool_issues: + return + try: + filtered_issues: list[str] = [] + for issue in analysis.tool_issues: + # Extract key from "key — description" + if "—" in issue: + key_part = issue.split("—", 1)[0].strip() + elif " - " in issue: + key_part = issue.split(" - ", 1)[0].strip() + else: + key_part = issue.strip() + + if key_part in traj_tool_status and not traj_tool_status[key_part]: + logger.debug( + f"Skipping LLM issue for {key_part}: " + f"rule-based already recorded all calls as errors" + ) + continue + filtered_issues.append(issue) + + if not filtered_issues: + return + + updated = await self._quality_manager.record_llm_tool_issues( + tool_issues=filtered_issues, + task_id=analysis.task_id, + ) + if updated: + logger.debug( + f"Fed {updated} LLM tool issue(s) to ToolQualityManager " + f"(filtered from {len(analysis.tool_issues)} total) " + f"for task {analysis.task_id}" + ) + except Exception as e: + # Quality feedback is best-effort; never break analysis flow + logger.debug(f"Tool quality feedback failed: {e}") + + def _load_recording_context( + self, + rec_path: Path, + execution_result: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + """Load and structure all recording artifacts needed for analysis. + + Returns a dict with keys used by ``_build_analysis_prompt()``, + or None if critical files are missing. + """ + # metadata.json (always present) + metadata_file = rec_path / "metadata.json" + if not metadata_file.exists(): + logger.warning(f"metadata.json not found in {rec_path}") + return None + try: + metadata = json.loads(metadata_file.read_text(encoding="utf-8")) + except Exception as e: + logger.warning(f"Failed to read metadata.json: {e}") + return None + + # conversations.jsonl (primary analysis source) + conv_file = rec_path / "conversations.jsonl" + conversations: List[Dict[str, Any]] = [] + if conv_file.exists(): + try: + for line in conv_file.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if line: + conversations.append(json.loads(line)) + except Exception as e: + logger.warning(f"Failed to read conversations.jsonl: {e}") + + if not conversations: + logger.warning(f"No conversations found in {rec_path}, skipping analysis") + return None + + # traj.jsonl (structured tool execution records) + traj_records = self._load_traj_data(rec_path) + + # Extract key fields from metadata + task_description = metadata.get( + "task_description", + (metadata.get("skill_selection") or {}).get("task", ""), + ) + if not task_description: + task_description = execution_result.get("instruction", "") + + skill_selection = metadata.get("skill_selection", {}) + selected_skills = skill_selection.get("selected", []) + + retrieved_tools = metadata.get("retrieved_tools", {}) + tool_defs = retrieved_tools.get("tools", []) + tool_names = [t.get("name", "") for t in tool_defs] + + # Extract skill content from conversations setup message + # selected_skills contains skill_ids (stored in metadata by tool_layer) + skill_contents: Dict[str, str] = {} + for conv in conversations: + if conv.get("type") == "setup": + for msg in conv.get("messages", []): + content = msg.get("content", "") + if isinstance(content, str) and "# Active Skills" in content: + skill_contents = self._extract_skill_contents( + content, selected_skills + ) + break + + # Execution status — prefer runtime result, fall back to persisted metadata + status = execution_result.get("status", "") + iterations = execution_result.get("iterations", 0) + if not status: + outcome = metadata.get("execution_outcome", {}) + status = outcome.get("status", "unknown") + iterations = iterations or outcome.get("iterations", 0) + + # Derive actually-used tools from traj.jsonl + # traj_records tells us exactly which tools were invoked; retrieved_tools + # is the broader set that was *available* to the agent. + used_tool_keys: set = set() + for entry in traj_records: + backend = entry.get("backend", "") + tool = entry.get("tool", "") + server = entry.get("server", "") + if tool: + used_tool_keys.add(f"{backend}:{tool}") + if server: + used_tool_keys.add(f"{backend}:{server}:{tool}") + + return { + "task_id": metadata.get("task_id", ""), + "task_description": task_description, + "selected_skills": selected_skills, + "skill_selection": skill_selection, + "skill_contents": skill_contents, + "tool_names": tool_names, + "tool_defs": tool_defs, + "used_tool_keys": used_tool_keys, + "conversations": conversations, + "traj_records": traj_records, + "execution_status": status, + "iterations": iterations, + "recording_dir": str(rec_path), + } + + @staticmethod + def _load_traj_data(rec_path: Path) -> List[Dict[str, Any]]: + """Load traj.jsonl and return structured tool execution records. + + Each record contains: step, timestamp, backend, tool, command, + result (status, output/stderr), parameters, extra. + """ + traj_file = rec_path / "traj.jsonl" + records: List[Dict[str, Any]] = [] + if not traj_file.exists(): + return records + try: + for line in traj_file.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if line: + records.append(json.loads(line)) + except Exception as e: + logger.warning(f"Failed to read traj.jsonl: {e}") + return records + + @staticmethod + def _extract_skill_contents( + injection_text: str, + selected_skill_ids: List[str], + ) -> Dict[str, str]: + """Parse the injected skill context to extract per-skill content. + + The injection text uses ``### Skill: {skill_id}`` headers, so + we split by that pattern and match against the provided skill_ids. + """ + contents: Dict[str, str] = {} + id_set = set(selected_skill_ids) + parts = re.split(r"###\s+Skill:\s+", injection_text) + for part in parts[1:]: # skip preamble + lines = part.split("\n", 1) + sid = lines[0].strip() + body = lines[1] if len(lines) > 1 else "" + if sid in id_set: + contents[sid] = body[:5000] + return contents + + def _load_skill_contents_from_disk( + self, skill_ids: List[str], + ) -> Dict[str, Dict[str, str]]: + """Load skill SKILL.md from disk via SkillRegistry. + + Returns dict mapping ``skill_id`` → ``{"content", "dir", "description", "name"}``. + Falls back gracefully if registry is unavailable. + """ + result: Dict[str, Dict[str, str]] = {} + if not self._skill_registry or not skill_ids: + return result + for sid in skill_ids: + meta = self._skill_registry.get_skill(sid) + if not meta: + continue + content = self._skill_registry.load_skill_content(sid) + if not content: + continue + skill_dir = str(meta.path.parent) + if len(content) > _SKILL_CONTENT_MAX_CHARS: + content = ( + content[:_SKILL_CONTENT_MAX_CHARS] + + f"\n\n... [truncated at {_SKILL_CONTENT_MAX_CHARS} chars — " + f"use read_file(\"{meta.path}\") to see full content]" + ) + result[sid] = { + "content": content, + "dir": skill_dir, + "description": meta.description, + "name": meta.name, + } + return result + + def _build_analysis_prompt(self, context: Dict[str, Any]) -> str: + """Build the LLM prompt for execution analysis. + + ``context["selected_skills"]`` contains true ``skill_id`` values. + """ + # Format conversation log (priority-based truncation) + conv_text = self._format_conversations(context["conversations"]) + + # Format traj.jsonl tool execution summary + traj_section = self._format_traj_summary(context["traj_records"]) + + # Skill section — keyed by skill_id throughout + selected_skill_ids: List[str] = context["selected_skills"] + skill_data = self._load_skill_contents_from_disk(selected_skill_ids) + + if not skill_data and selected_skill_ids: + # Fallback: use content extracted from conversation injection text + for sid in selected_skill_ids: + content = context["skill_contents"].get(sid) + if content: + skill_data[sid] = { + "content": content, + "dir": "(unknown)", + "description": "", + "name": sid, + } + + skill_section = "" + if skill_data: + parts = [] + for sid, info in skill_data.items(): + desc_line = ( + f"\n**Description**: {info['description']}" + if info.get("description") else "" + ) + display_name = info.get("name", sid) + parts.append( + f"### {sid}\n" + f"**Name**: {display_name}\n" + f"**Directory**: `{info['dir']}`{desc_line}\n\n" + f"{info['content']}" + ) + skill_section = "## Selected Skills\n\n" + "\n\n---\n\n".join(parts) + # If no skills selected → skill_section stays "" (omitted from prompt) + + # Tool list + tool_list = self._format_tool_list( + context.get("tool_defs", []), + context.get("used_tool_keys", set()), + ) + + # Resource info (recording dir + skill dirs) + rec_dir = context.get("recording_dir", "") + resource_lines: List[str] = [] + if rec_dir: + resource_lines.append(f"**Recording directory**: `{rec_dir}`") + rec_path = Path(rec_dir) + if rec_path.is_dir(): + files = [f.name for f in sorted(rec_path.iterdir()) if f.is_file()] + if files: + resource_lines.append(f" Files: {', '.join(files)}") + + skill_dirs = { + sid: info["dir"] + for sid, info in skill_data.items() + if info.get("dir") and info["dir"] != "(unknown)" + } + if skill_dirs: + resource_lines.append("**Skill directories**:") + for sid, d in skill_dirs.items(): + resource_lines.append(f" - {sid}: `{d}`") + + resource_lines.append( + "\nYou have `read_file`, `list_dir`, and `run_shell` tools for deeper " + "investigation.\n**In most cases the trace above is sufficient** — only " + "use tools when evidence is ambiguous or you need to verify specific details." + ) + resource_info = "\n".join(resource_lines) + + return SkillEnginePrompts.execution_analysis( + task_description=context["task_description"], + execution_status=context["execution_status"], + iterations=context["iterations"], + tool_list=tool_list, + skill_section=skill_section, + conversation_log=conv_text, + traj_summary=traj_section, + selected_skill_ids_json=json.dumps(selected_skill_ids), + resource_info=resource_info, + ) + + @staticmethod + def _format_tool_list( + tool_defs: List[Dict[str, Any]], + used_tool_keys: set = None, + ) -> str: + """Format tool definitions with usage annotation. + + Tools that appear in ``used_tool_keys`` (derived from traj.jsonl) + are marked as "Actually used". This lets the analysis LLM focus + on what actually happened without being distracted by unused tools. + + Args: + tool_defs: Tool definitions from ``metadata.retrieved_tools.tools``. + Backend should be correctly recorded (mcp, shell, etc.) now + that the recording layer prefers ``runtime_info.backend``. + used_tool_keys: Set of ``"backend:tool_name"`` or ``"backend:server:tool_name"`` + strings derived from traj.jsonl. + """ + if not tool_defs: + return "none" + if used_tool_keys is None: + used_tool_keys = set() + + used_parts = [] + available_parts = [] + for t in tool_defs: + name = t.get("name", "?") + backend = t.get("backend", "?") + server = t.get("server_name") + label = f"{name} ({backend}/{server})" if server else f"{name} ({backend})" + + # Match by backend:tool or backend:server:tool + key = f"{backend}:{name}" + key_with_server = f"{backend}:{server}:{name}" if server else "" + if key in used_tool_keys or key_with_server in used_tool_keys: + used_parts.append(label) + else: + available_parts.append(label) + + sections = [] + if used_parts: + sections.append(f"Actually used: {', '.join(used_parts)}") + if available_parts: + sections.append(f"Available but unused: {', '.join(available_parts)}") + return "\n".join(sections) if sections else "none" + + @staticmethod + def _format_traj_summary(traj_records: List[Dict[str, Any]]) -> str: + """Format traj.jsonl records into a concise tool execution timeline. + + This provides the LLM with a structured view of every tool invocation + and its outcome, complementing the conversation log which shows the + agent's reasoning. + """ + if not traj_records: + return "(no traj.jsonl data available)" + + lines = [f"Total tool invocations: {len(traj_records)}"] + error_count = sum( + 1 for r in traj_records + if r.get("result", {}).get("status") == "error" + ) + if error_count: + lines.append(f"Errors: {error_count}/{len(traj_records)}") + + lines.append("") # blank line before timeline + + for entry in traj_records: + step = entry.get("step", "?") + backend = entry.get("backend", "?") + tool = entry.get("tool", "?") + server = entry.get("server", "") + result = entry.get("result", {}) + status = result.get("status", "?") + + # Build compact one-line summary + command = entry.get("command", "") + if isinstance(command, str) and len(command) > 150: + command = command[:150] + "..." + + # Include server for MCP tools so key is unambiguous + if server: + tool_label = f"{backend}:{server}:{tool}" + else: + tool_label = f"{backend}:{tool}" + line = f" Step {step} [{tool_label}] → {status}" + + # Add error details for failed steps + if status == "error": + stderr = result.get("stderr", result.get("output", "")) + if isinstance(stderr, str) and stderr: + # Extract first meaningful line of error + error_first_line = stderr.strip().split("\n")[0][:200] + line += f" | {error_first_line}" + + # Add brief command context + if command and not command.startswith("```"): + line += f" | cmd: {command[:100]}" + + lines.append(line) + + return "\n".join(lines) + + @staticmethod + def _format_conversations(conversations: List[Dict[str, Any]]) -> str: + """Format conversations.jsonl into a readable text block for the LLM. + + Delegates to :func:`conversation_formatter.format_conversations`. + """ + return format_conversations(conversations, _MAX_CONVERSATION_CHARS) + + async def _run_analysis_loop( + self, + prompt: str, + available_tools: Optional[List[BaseTool]] = None, + ) -> Optional[Dict[str, Any]]: + """Run analysis as an agent loop with optional tool use. + + Most analyses complete in a single pass (LLM outputs JSON directly). + When the trace is ambiguous, the LLM may call the execution's own + tools (``read_file``, ``list_dir``, ``run_shell``, ``shell_agent``, + MCP tools, etc.) for deeper investigation or error reproduction. + + Reuses ``LLMClient.complete()`` for retry, rate-limiting, tool + serialization, and tool execution. + + Conversations are recorded to ``conversations.jsonl`` via + ``RecordingManager`` (agent_name="ExecutionAnalyzer") so the full + analysis dialogue is preserved alongside the grounding trace. + """ + from openspace.recording import RecordingManager + + model = self._model or self._llm_client.model + analysis_tools: List[BaseTool] = list(available_tools or []) + + messages: List[Dict[str, Any]] = [ + {"role": "user", "content": prompt}, + ] + + # Record initial conversation setup + await RecordingManager.record_conversation_setup( + setup_messages=copy.deepcopy(messages), + tools=analysis_tools if analysis_tools else None, + agent_name="ExecutionAnalyzer", + ) + + for iteration in range(_ANALYSIS_MAX_ITERATIONS): + is_last = iteration == _ANALYSIS_MAX_ITERATIONS - 1 + + # Snapshot message count before any additions + LLM call + msg_count_before = len(messages) + + # On the final iteration, force JSON output (no tools). + if is_last: + messages.append({ + "role": "system", + "content": ( + "This is your FINAL round — no more tool calls allowed. " + "You MUST output the JSON analysis object now based on " + "all information gathered so far." + ), + }) + + try: + result = await self._llm_client.complete( + messages=messages, + tools=analysis_tools if not is_last else None, + execute_tools=True, + model=model, + ) + except Exception as e: + logger.error(f"Analysis LLM call failed (iter {iteration}): {e}") + return None + + content = result["message"].get("content", "") + has_tool_calls = result["has_tool_calls"] + + # Record iteration delta + updated_messages = result["messages"] + delta = updated_messages[msg_count_before:] + await RecordingManager.record_iteration_context( + iteration=iteration + 1, + delta_messages=copy.deepcopy(delta), + response_metadata={ + "has_tool_calls": has_tool_calls, + "tool_calls_count": len(result.get("tool_results", [])), + "is_final": not has_tool_calls, + }, + agent_name="ExecutionAnalyzer", + ) + + if not has_tool_calls: + # No tool calls → final response, parse JSON + return self._extract_json(content) + + # Tools were called and executed by complete() — continue with + # the updated messages (includes assistant + tool result messages). + messages = updated_messages + logger.debug( + f"Analysis agent used tools " + f"(iter {iteration + 1}/{_ANALYSIS_MAX_ITERATIONS})" + ) + + # Should not reach here (last iteration disables tools), but just in case + logger.warning( + f"Analysis agent reached max iterations ({_ANALYSIS_MAX_ITERATIONS})" + ) + for m in reversed(messages): + if m.get("role") == "assistant" and m.get("content"): + return self._extract_json(m["content"]) + return None + + @staticmethod + def _extract_json(text: str) -> Optional[Dict[str, Any]]: + """Extract a JSON object from LLM response text. + + Handles markdown code fences and bare JSON. + """ + # Try code block first + code_match = re.search( + r"```(?:json)?\s*\n?(.*?)\n?```", text, re.DOTALL + ) + if code_match: + text = code_match.group(1).strip() + else: + # Try bare JSON object + json_match = re.search(r"\{.*\}", text, re.DOTALL) + if json_match: + text = json_match.group() + + try: + data = json.loads(text) + if isinstance(data, dict): + return data + logger.warning(f"LLM returned non-dict JSON: {type(data)}") + return None + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse LLM analysis JSON: {e}") + logger.debug(f"Raw LLM output (first 500 chars): {text[:500]}") + return None + + @staticmethod + def _parse_analysis( + task_id: str, + data: Dict[str, Any], + context: Dict[str, Any], + ) -> Optional[ExecutionAnalysis]: + """Convert the raw LLM JSON output into an ExecutionAnalysis. + + Also attaches observed tool execution records from ``traj.jsonl`` + so the analysis contains both LLM judgments and factual data. + """ + try: + now = datetime.now() + + # Collect all known skill IDs from context for fuzzy correction. + # LLMs often garble hex suffixes when reproducing skill IDs. + known_skill_ids: set = set() + for sid in context.get("selected_skills", []): + known_skill_ids.add(sid) + # Also include skill IDs from the skill_selection metadata + skill_sel = context.get("skill_selection") or {} + for sid in skill_sel.get("available_skills", []): + known_skill_ids.add(sid) + + # Parse skill judgments (LLM-generated) + judgments: List[SkillJudgment] = [] + for jd in data.get("skill_judgments", []): + raw_sid = jd.get("skill_id", "") + corrected = _correct_skill_ids([raw_sid], known_skill_ids) + judgments.append( + SkillJudgment( + skill_id=corrected[0] if corrected else raw_sid, + skill_applied=bool(jd.get("skill_applied", False)), + note=jd.get("note", ""), + ) + ) + + # Parse evolution_suggestions (new format: list of typed suggestions) + suggestions: List[EvolutionSuggestion] = [] + for raw_sug in data.get("evolution_suggestions", []): + try: + evo_type = EvolutionType(raw_sug.get("type", "")) + except ValueError: + logger.debug(f"Unknown evolution type: {raw_sug.get('type')}") + continue + + cat = None + if raw_sug.get("category"): + try: + cat = SkillCategory(raw_sug["category"]) + except ValueError: + logger.debug(f"Unknown category: {raw_sug.get('category')}") + + # Support both "target_skills" (list) and legacy "target_skill" (str) + raw_targets = raw_sug.get("target_skills") + if isinstance(raw_targets, list): + targets = [t for t in raw_targets if t] + else: + legacy = raw_sug.get("target_skill", "") + targets = [legacy] if legacy else [] + + # Correct LLM-hallucinated skill IDs against known IDs. + # LLMs frequently swap/drop characters in hex suffixes + # (e.g. "61f694bc" instead of "61f694cb"). + targets = _correct_skill_ids(targets, known_skill_ids) + + suggestions.append(EvolutionSuggestion( + evolution_type=evo_type, + target_skill_ids=targets, + category=cat, + direction=raw_sug.get("direction", ""), + )) + + analysis = ExecutionAnalysis( + task_id=task_id, + timestamp=now, + task_completed=bool(data.get("task_completed", False)), + execution_note=data.get("execution_note", ""), + tool_issues=data.get("tool_issues", []), + skill_judgments=judgments, + evolution_suggestions=suggestions, + analyzed_by=data.get("analyzed_by", ""), + analyzed_at=now, + ) + return analysis + + except Exception as e: + logger.error(f"Failed to parse analysis response: {e}") + return None + + # Convenience queries (delegated to store) + def get_store(self) -> SkillStore: + """Access the underlying SkillStore for direct queries.""" + return self._store + + def close(self) -> None: + """Close the store connection.""" + self._store.close() diff --git a/openspace/skill_engine/conversation_formatter.py b/openspace/skill_engine/conversation_formatter.py new file mode 100644 index 0000000000000000000000000000000000000000..3f11f212d57c0de4fe981910c59f3255206b93c1 --- /dev/null +++ b/openspace/skill_engine/conversation_formatter.py @@ -0,0 +1,335 @@ +"""Conversation log formatting for execution analysis. + +Converts ``conversations.jsonl`` entries into a priority-based text block +suitable for LLM analysis prompts. All functions are pure (stateless). + +Priority levels (lower = more important): + 0 — CRITICAL : User instruction (never truncated) + 1 — CRITICAL : Final iteration assistant response (never truncated) + 2 — HIGH : Tool calls (name + args) AND tool errors — kept together + 3 — HIGH : Non-final assistant reasoning; tool results with embedded summary + 4 — MEDIUM : Tool success results (try to preserve) + 5 — LOW : System guidance messages between iterations + SKIP : Skill injection text, verbose system prompts (not included; + skill & tool info are provided separately in the prompt) +""" + +from __future__ import annotations + +import re +from typing import Any, Dict, List, Optional + +# Per-section truncation limits (kept in sync with analyzer constants) +TOOL_ERROR_MAX_CHARS = 1000 +TOOL_SUCCESS_MAX_CHARS = 800 +TOOL_ARGS_MAX_CHARS = 500 +TOOL_SUMMARY_MAX_CHARS = 1500 + + +def format_conversations( + conversations: List[Dict[str, Any]], + budget: int, +) -> str: + """Format ``conversations.jsonl`` entries into a readable text block. + + Uses priority-based truncation instead of simple tail-truncation. + + When total exceeds *budget*: + 1. Include all priority ≤ 3 (CRITICAL + HIGH) segments in full. + 2. Add MEDIUM + LOW segments until budget is exhausted, truncating + if possible. + 3. If even HIGH content exceeds budget, keep priority 0-1 in full, + budget-allocate priority 2, and summarize priority 3. + """ + # Count total iterations for priority assignment + total_iters = sum( + 1 for c in conversations if c.get("type") == "iteration" + ) + + # Phase 1: Collect all segments in chronological order with priority + segments: List[Dict[str, Any]] = [] + + for conv in conversations: + conv_type = conv.get("type", "") + if conv_type == "setup": + _collect_setup_segments(conv, segments) + elif conv_type == "iteration": + _collect_iteration_segments(conv, total_iters, segments) + + # Phase 2: Assemble with budget management + return _assemble_with_budget(segments, budget) + +def _collect_setup_segments( + conv: Dict[str, Any], + segments: List[Dict[str, Any]], +) -> None: + """Extract segments from a ``type: "setup"`` conversation entry. + + Only the user instruction is extracted. System prompts (including skill + injection text and tool descriptions) are skipped — they are provided in + dedicated sections of the analysis prompt. + """ + for msg in conv.get("messages", []): + role = msg.get("role", "") + content = msg.get("content", "") + if not isinstance(content, str): + content = str(content) + + if role == "user": + segments.append({ + "priority": 0, # CRITICAL — always keep + "text": f"[USER INSTRUCTION]\n{content}", + "iteration": 0, + "role": "user", + "truncatable_to": None, + }) + +def _collect_iteration_segments( + conv: Dict[str, Any], + total_iters: int, + segments: List[Dict[str, Any]], +) -> None: + """Extract segments from a ``type: "iteration"`` conversation entry. + + Key design decisions: + - Tool calls and tool errors share the SAME high priority (2) + - Tool success results get MEDIUM priority (4) + - Shell agent results with embedded "Execution Summary" get HIGH (3). + """ + iteration = conv.get("iteration", "?") + is_last = (iteration == total_iters) if isinstance(iteration, int) else False + + # Process delta_messages in order + for msg in conv.get("delta_messages", []): + role = msg.get("role", "") + content = msg.get("content", "") + if not isinstance(content, str): + content = str(content) + + if role == "assistant": + # Assistant reasoning + if content: + priority = 1 if is_last else 3 + segments.append({ + "priority": priority, + "text": f"[Iter {iteration}] ASSISTANT: {content}", + "iteration": iteration, + "role": "assistant", + "truncatable_to": None, + }) + + # Tool calls + for tc in msg.get("tool_calls", []): + fn = tc.get("function", {}) + fn_name = fn.get("name", "?") + fn_args = fn.get("arguments", "") + if isinstance(fn_args, str) and len(fn_args) > TOOL_ARGS_MAX_CHARS: + fn_args = fn_args[:TOOL_ARGS_MAX_CHARS] + "..." + segments.append({ + "priority": 2, # HIGH — paired with tool results/errors + "text": f"[Iter {iteration}] TOOL_CALL: {fn_name}({fn_args})", + "iteration": iteration, + "role": "tool_call", + "truncatable_to": None, + }) + + elif role == "tool": + # Tool result + is_error = _is_error_result(content) + + if is_error: + truncated = content[:TOOL_ERROR_MAX_CHARS] + if len(content) > TOOL_ERROR_MAX_CHARS: + truncated += f"... [truncated, total {len(content)} chars]" + segments.append({ + "priority": 2, # HIGH — errors are critical, same tier as tool calls + "text": f"[Iter {iteration}] TOOL_ERROR: {truncated}", + "iteration": iteration, + "role": "tool_error", + "truncatable_to": None, + }) + else: + # Check if result contains a self-generated summary + # (e.g. shell_agent produces "Execution Summary (N steps):") + summary = _extract_embedded_summary(content) + if summary: + # Show the embedded summary (high value, compact) + segments.append({ + "priority": 3, # HIGH — self-generated summaries are informative + "text": f"[Iter {iteration}] TOOL_RESULT (with summary):\n{summary}", + "iteration": iteration, + "role": "tool_result", + "truncatable_to": 500, + }) + else: + truncated = content[:TOOL_SUCCESS_MAX_CHARS] + if len(content) > TOOL_SUCCESS_MAX_CHARS: + truncated += f"... [truncated, total {len(content)} chars]" + segments.append({ + "priority": 4, # MEDIUM — try to preserve success results + "text": f"[Iter {iteration}] TOOL_RESULT: {truncated}", + "iteration": iteration, + "role": "tool_result", + "truncatable_to": 300, + }) + + elif role == "system": + # System guidance between iterations (e.g. "Iteration N complete...") + if content: + segments.append({ + "priority": 5, # LOW — guidance messages + "text": f"[Iter {iteration}] SYSTEM: {content}", + "iteration": iteration, + "role": "system", + "truncatable_to": 150, + }) + +def _assemble_with_budget( + segments: List[Dict[str, Any]], + budget: int, +) -> str: + """Assemble segments into final text respecting the character budget. + + Strategy: + 1. Include all segments with priority ≤ 3 (CRITICAL + HIGH) in full. + 2. Add MEDIUM + LOW segments in chronological order until budget is hit. + 3. If even HIGH-priority content exceeds budget, progressively truncate + older iterations while preserving user instruction and final iteration. + """ + # Calculate essential (priority ≤ 3) size + essential = [s for s in segments if s["priority"] <= 3] + essential_chars = sum(len(s["text"]) for s in essential) + + remaining_budget = budget - essential_chars + + if remaining_budget < 0: + # Essential content alone exceeds budget — need to reduce + # Keep priority 0-1 (user instruction + final iteration) in full + # Truncate priority 2-3 (tool calls/errors + older assistant content) + return _assemble_essential_only(segments, budget) + + # Build output in chronological order + output_parts: List[str] = [] + used_chars = 0 + skipped_count = 0 + + for seg in segments: + text = seg["text"] + priority = seg["priority"] + + if priority <= 3: + # Essential — always include + output_parts.append(text) + used_chars += len(text) + 1 + elif used_chars + len(text) + 1 <= budget: + # Within budget — include + output_parts.append(text) + used_chars += len(text) + 1 + else: + # Over budget — try truncation + truncatable_to = seg.get("truncatable_to") + if truncatable_to and len(text) > truncatable_to: + truncated = text[:truncatable_to] + "... [budget-truncated]" + if used_chars + len(truncated) + 1 <= budget: + output_parts.append(truncated) + used_chars += len(truncated) + 1 + continue + skipped_count += 1 + + if skipped_count > 0: + output_parts.append( + f"\n[... {skipped_count} lower-priority segment(s) omitted due to length ...]" + ) + + return "\n\n".join(output_parts) + + +def _assemble_essential_only( + segments: List[Dict[str, Any]], + budget: int, +) -> str: + """Fallback: even essential content exceeds budget. + + Keep: + - User instruction (priority 0) — never truncated + - Final iteration (priority 1) — never truncated + - Tool calls + tool errors (priority 2) — budget-allocated, truncated if needed + - Non-final assistant reasoning (priority 3) — heavily summarized + """ + output_parts: List[str] = [] + used_chars = 0 + + # Pass 1: priority 0 and 1 (user instruction + final iteration) + for seg in segments: + if seg["priority"] <= 1: + output_parts.append(seg["text"]) + used_chars += len(seg["text"]) + 1 + + remaining = budget - used_chars + + # Pass 2: priority 2 (tool calls + tool errors) — budget-allocated + tool_segments = [s for s in segments if s["priority"] == 2] + if tool_segments: + per_segment_budget = max(400, remaining // (len(tool_segments) + 1)) + for seg in tool_segments: + text = seg["text"] + if len(text) > per_segment_budget: + text = text[:per_segment_budget] + "... [budget-truncated]" + if used_chars + len(text) + 1 <= budget: + output_parts.append(text) + used_chars += len(text) + 1 + + # Pass 3: priority 3 (non-final assistant reasoning) — one-line summaries + assistants = [s for s in segments if s["priority"] == 3] + if assistants and used_chars < budget: + output_parts.append("\n--- Older iteration summaries ---") + for seg in assistants: + first_line = seg["text"].split("\n", 1)[0][:200] + if used_chars + len(first_line) + 1 > budget: + output_parts.append("[... remaining iterations omitted ...]") + break + output_parts.append(first_line) + used_chars += len(first_line) + 1 + + return "\n\n".join(output_parts) + +def _is_error_result(content: str) -> bool: + """Detect if a tool result represents an error.""" + if not content: + return False + # Check common error patterns in the first 200 chars + head = content[:200].lower() + return ( + content.startswith("[ERROR]") + or content.startswith("ERROR") + or "error" in head[:50] + or "task failed" in head + or "connection refused" in head + or "timed out" in head + or "traceback" in head + ) + + +def _extract_embedded_summary(content: str) -> Optional[str]: + """Extract self-generated summary from tool result content. + + Shell agent results often contain an ``Execution Summary (N steps):`` + block that provides a compact view of what happened internally. + This is more informative than the raw output. + """ + # Look for "Execution Summary (N steps):" pattern + match = re.search( + r"(Execution Summary \(\d+ steps?\):.*?)(?:={10,}|$)", + content, + re.DOTALL, + ) + if match: + summary = match.group(1).strip() + # Also capture any "Summary:" line after the steps + summary_match = re.search(r"\nSummary:\s*(.+)", content) + if summary_match: + summary += f"\nConclusion: {summary_match.group(1).strip()}" + return summary[:TOOL_SUMMARY_MAX_CHARS] + + return None + diff --git a/openspace/skill_engine/evolver.py b/openspace/skill_engine/evolver.py new file mode 100644 index 0000000000000000000000000000000000000000..ad9afdc3919c14c2a50c1acb33b54e173e8bdb8a --- /dev/null +++ b/openspace/skill_engine/evolver.py @@ -0,0 +1,1520 @@ +"""SkillEvolver — execute skill evolution actions. + +Three evolution types: + FIX — repair broken/outdated instructions (in-place, same name) + DERIVED — create enhanced version from existing skill (new directory) + CAPTURED — capture novel reusable pattern from execution (brand new skill) + +Three trigger sources: + 1. Post-analysis — analyzer found evolution suggestions for a specific task + 2. Tool degradation — ToolQualityManager detected problematic tools + 3. Metric monitor — periodic scan of skill health indicators + +All triggers produce an EvolutionContext → evolve() → LLM agent loop → +apply-retry cycle → validation → store persistence. +""" + +from __future__ import annotations + +import asyncio +import copy +import json +import re +import shutil +import uuid +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING + +from .types import ( + EvolutionSuggestion, + EvolutionType, + ExecutionAnalysis, + SkillCategory, + SkillLineage, + SkillOrigin, + SkillRecord, +) +from .patch import ( + PatchType, + SkillEditResult, + collect_skill_snapshot, + create_skill, + fix_skill, + derive_skill, + SKILL_FILENAME, +) +from .skill_utils import ( + extract_change_summary as _extract_change_summary, + get_frontmatter_field as _extract_frontmatter_field, + set_frontmatter_field as _set_frontmatter_field, + strip_markdown_fences as _strip_markdown_fences, + truncate as _truncate, + validate_skill_dir as _validate_skill_dir, +) +from .registry import write_skill_id +from .store import SkillStore +from openspace.prompts import SkillEnginePrompts +from openspace.utils.logging import Logger + +if TYPE_CHECKING: + from .registry import SkillRegistry + from openspace.llm import LLMClient + from openspace.grounding.core.tool import BaseTool + from openspace.grounding.core.quality.types import ToolQualityRecord + +logger = Logger.get_logger(__name__) + +EVOLUTION_COMPLETE = SkillEnginePrompts.EVOLUTION_COMPLETE +EVOLUTION_FAILED = SkillEnginePrompts.EVOLUTION_FAILED + +_SKILL_CONTENT_MAX_CHARS = 12_000 # Max chars of SKILL.md in evolution prompt +_MAX_SKILL_NAME_LENGTH = 50 # Max chars for a skill name (directory name) + + +def _sanitize_skill_name(name: str) -> str: + """Enforce naming rules for skill names (used as directory names). + + - Lowercase, hyphens only (no underscores or special chars) + - Truncate to ``_MAX_SKILL_NAME_LENGTH`` at a word boundary + - Remove trailing hyphens + """ + # Normalize: lowercase, replace underscores and spaces with hyphens + clean = re.sub(r"[^a-z0-9\-]", "-", name.lower().strip()) + # Collapse multiple hyphens + clean = re.sub(r"-{2,}", "-", clean).strip("-") + + if len(clean) <= _MAX_SKILL_NAME_LENGTH: + return clean + + # Truncate at a hyphen boundary to avoid cutting words + truncated = clean[:_MAX_SKILL_NAME_LENGTH] + last_hyphen = truncated.rfind("-") + if last_hyphen > _MAX_SKILL_NAME_LENGTH // 2: + truncated = truncated[:last_hyphen] + return truncated.strip("-") + +_ANALYSIS_CONTEXT_MAX = 5 # Max recent analyses to include in prompt +_ANALYSIS_NOTE_MAX_CHARS = 500 # Per-analysis note truncation + +# Agent loop / retry constants +_MAX_EVOLUTION_ITERATIONS = 5 # Max tool-calling rounds for evolution agent +_MAX_EVOLUTION_ATTEMPTS = 3 # Max apply-retry attempts per evolution + +# Rule-based thresholds for candidate screening (relaxed — LLM confirms) +_FALLBACK_THRESHOLD = 0.4 # Relaxed from 0.5 for wider screening +_LOW_COMPLETION_THRESHOLD = 0.35 # Relaxed from 0.3 +_HIGH_APPLIED_FOR_FIX = 0.4 # Relaxed from 0.5 +_MODERATE_EFFECTIVE_THRESHOLD = 0.55 # Relaxed from 0.5 +_MIN_APPLIED_FOR_DERIVED = 0.25 # Relaxed from 0.3 + + +class EvolutionTrigger(str, Enum): + """What initiated this evolution.""" + ANALYSIS = "analysis" # Post-execution analysis suggestion + TOOL_DEGRADATION = "tool_degradation" # Tool quality degradation detected + METRIC_MONITOR = "metric_monitor" # Periodic skill health check + + +@dataclass +class EvolutionContext: + """Unified context for all evolution triggers. + + For trigger 1 (ANALYSIS): source_task_id is set, recent_analyses may be + just the single triggering analysis. + For triggers 2/3: source_task_id is None, recent_analyses are loaded + from the skill's historical records. + """ + trigger: EvolutionTrigger + suggestion: EvolutionSuggestion + + # Parent skill context + skill_records: List[SkillRecord] = field(default_factory=list) + skill_contents: List[str] = field(default_factory=list) + skill_dirs: List[Path] = field(default_factory=list) + + # Task context + source_task_id: Optional[str] = None + recent_analyses: List[ExecutionAnalysis] = field(default_factory=list) + + # Trigger-specific context + tool_issue_summary: str = "" # For TOOL_DEGRADATION + metric_summary: str = "" # For METRIC_MONITOR + + # Available tools for agent loop (read_file, web_search, shell, MCP, etc.) + available_tools: List["BaseTool"] = field(default_factory=list) + + +class SkillEvolver: + """Execute skill evolution actions. + + Single entry point: ``evolve()`` takes an EvolutionContext, runs an + LLM agent loop (with optional tool use), applies the edit with retry, + validates the result, and persists the new SkillRecord via ``SkillStore``. + + Concurrency: + ``max_concurrent`` controls the semaphore that throttles parallel + evolutions across all trigger types. File I/O is synchronous and + naturally serialized by the event loop; only LLM calls run in + parallel. + + Anti-loop (Trigger 2 — tool degradation): + ``_addressed_degradations`` is a ``Dict[str, Set[str]]`` mapping + ``tool_key → {skill_id, …}`` for skills that have already been + evolved to handle a specific tool's degradation. At the start of + each ``process_tool_degradation`` call, tools that are no longer + in the problematic list are pruned — so if a tool **recovers and + then degrades again**, all its dependent skills are re-evaluated. + + Anti-loop (Trigger 3 — metric check): + Newly-evolved skills have ``total_selections=0``, requiring + ``min_selections`` (default 5) fresh data points before being + re-evaluated. This is data-driven and needs no time-based guard. + + Background: + Trigger 2 and 3 are always launched as ``asyncio.Task``s via + ``schedule_background()`` so they never block the main flow. + """ + + def __init__( + self, + store: SkillStore, + registry: "SkillRegistry", + llm_client: "LLMClient", + model: Optional[str] = None, + available_tools: Optional[List["BaseTool"]] = None, + *, + max_concurrent: int = 3, + ) -> None: + self._store = store + self._registry = registry + self._llm_client = llm_client + self._model = model + self._available_tools: List["BaseTool"] = available_tools or [] + + # Concurrency: semaphore limits parallel LLM sessions + self._max_concurrent = max(1, max_concurrent) + self._semaphore = asyncio.Semaphore(self._max_concurrent) + + # Anti-loop for Trigger 2: tracks which skills have already been + # evolved for each degraded tool. Keyed by tool_key. + # Pruned when a tool leaves the problematic list (= recovered). + self._addressed_degradations: Dict[str, Set[str]] = {} + + # Track background tasks so they can be awaited on shutdown. + self._background_tasks: Set[asyncio.Task] = set() + + def set_available_tools(self, tools: List["BaseTool"]) -> None: + """Update the tools available for evolution agent loops.""" + self._available_tools = list(tools) + + async def wait_background(self) -> None: + """Await all outstanding background evolution tasks. + + Call this during shutdown / cleanup to ensure nothing is lost. + """ + if self._background_tasks: + logger.info( + f"Waiting for {len(self._background_tasks)} background " + f"evolution task(s) to finish..." + ) + await asyncio.gather(*self._background_tasks, return_exceptions=True) + self._background_tasks.clear() + + async def evolve(self, ctx: EvolutionContext) -> Optional[SkillRecord]: + """Execute one evolution action. Returns new SkillRecord or None. + + The global semaphore is NOT acquired here — it is managed at the + trigger-method level so the concurrency limit covers the whole batch. + """ + try: + from gdpval_bench.token_tracker import set_call_source, reset_call_source + _src_tok = set_call_source("evolver") + except ImportError: + _src_tok = None + + evo_type = ctx.suggestion.evolution_type + try: + if evo_type == EvolutionType.FIX: + return await self._evolve_fix(ctx) + elif evo_type == EvolutionType.DERIVED: + return await self._evolve_derived(ctx) + elif evo_type == EvolutionType.CAPTURED: + return await self._evolve_captured(ctx) + else: + logger.warning(f"Unknown evolution type: {evo_type}") + return None + except Exception as e: + targets = "+".join(ctx.suggestion.target_skill_ids) or "(new)" + logger.error(f"Evolution failed [{evo_type.value}] target={targets}: {e}") + return None + finally: + if _src_tok is not None: + reset_call_source(_src_tok) + + # Trigger 1: post-analysis + async def process_analysis( + self, analysis: ExecutionAnalysis, + ) -> List[SkillRecord]: + """Process all evolution suggestions from a completed analysis. + + Called immediately after ``ExecutionAnalyzer.analyze_execution()``. + Each suggestion becomes one evolution action, executed in parallel + (throttled by semaphore). + """ + if not analysis.candidate_for_evolution: + return [] + + # Build contexts first (cheap, no LLM calls) + contexts: List[EvolutionContext] = [] + for suggestion in analysis.evolution_suggestions: + ctx = self._build_context_from_analysis(analysis, suggestion) + if ctx is not None: + contexts.append(ctx) + + if not contexts: + return [] + + results = await self._execute_contexts(contexts, "analysis") + + if results: + names = [r.name for r in results] + logger.info( + f"[Trigger:analysis] Evolved {len(results)} skill(s): {names} " + f"from task {analysis.task_id}" + ) + return results + + # Trigger 2: tool quality degradation + async def process_tool_degradation( + self, problematic_tools: List["ToolQualityRecord"], + ) -> List[SkillRecord]: + """Fix skills that depend on degraded tools. + + Two-phase: rule-based candidate screening → LLM confirmation. + + Anti-loop (state-driven): + ``_addressed_degradations[tool_key]`` records skill names that + have already been evolved for that tool's degradation. They are + skipped on subsequent calls as long as the tool stays degraded. + + At the start of each call, tools that **recovered** (no longer + in ``problematic_tools``) are pruned from the dict — so if the + tool degrades again later, all dependent skills are re-evaluated. + """ + if not problematic_tools: + return [] + + # Prune recovered tools: if a tool_key used to be tracked but is + # no longer in the current problematic list, it recovered — clear + # its addressed set so future re-degradation gets a fresh pass. + current_tool_keys = {t.tool_key for t in problematic_tools} + recovered = [k for k in self._addressed_degradations if k not in current_tool_keys] + for k in recovered: + logger.debug(f"[Trigger:tool_degradation] Tool '{k}' recovered, clearing addressed set") + del self._addressed_degradations[k] + + # Phase 1: screen & confirm candidates + confirmed_contexts: List[EvolutionContext] = [] + seen_skills: set = set() # de-dup by skill_id within this call + + for tool_rec in problematic_tools: + addressed = self._addressed_degradations.get(tool_rec.tool_key, set()) + + skill_ids = self._store.find_skills_by_tool(tool_rec.tool_key) + for skill_id in skill_ids: + skill_record = self._store.load_record(skill_id) + if not skill_record or not skill_record.is_active: + continue + + # De-duplicate by skill_id within this call + if skill_record.skill_id in seen_skills: + continue + seen_skills.add(skill_record.skill_id) + + # Anti-loop: already evolved for this tool's degradation + if skill_record.skill_id in addressed: + logger.debug( + f"[Trigger:tool_degradation] Skipping '{skill_record.skill_id}' " + f"(already addressed for tool '{tool_rec.tool_key}')" + ) + continue + + recent = self._store.load_analyses(skill_id=skill_record.skill_id, limit=_ANALYSIS_CONTEXT_MAX) + content = self._load_skill_content(skill_record) + if not content: + continue + + issue_summary = ( + f"Tool `{tool_rec.tool_key}` degraded — " + f"recent success rate: {tool_rec.recent_success_rate:.0%}, " + f"total calls: {tool_rec.total_calls}, " + f"LLM flagged: {tool_rec.llm_flagged_count} time(s)." + ) + + direction = ( + f"Tool `{tool_rec.tool_key}` has degraded " + f"(success_rate={tool_rec.recent_success_rate:.0%}). " + f"Update skill instructions to handle this tool's " + f"failures gracefully or suggest alternatives." + ) + + # LLM confirmation: ask whether this skill truly needs fixing + confirmed = await self._llm_confirm_evolution( + skill_record=skill_record, + skill_content=content, + proposed_type=EvolutionType.FIX, + proposed_direction=direction, + trigger_context=f"Tool degradation: {issue_summary}", + recent_analyses=recent, + ) + if not confirmed: + logger.debug( + f"[Trigger:tool_degradation] LLM rejected evolution " + f"for skill '{skill_record.skill_id}' (tool={tool_rec.tool_key})" + ) + # Even if LLM rejected, mark as addressed to avoid + # repeated LLM confirmation calls on every cycle. + self._addressed_degradations.setdefault( + tool_rec.tool_key, set() + ).add(skill_record.skill_id) + continue + + skill_dir = Path(skill_record.path).parent if skill_record.path else None + confirmed_contexts.append(EvolutionContext( + trigger=EvolutionTrigger.TOOL_DEGRADATION, + suggestion=EvolutionSuggestion( + evolution_type=EvolutionType.FIX, + target_skill_ids=[skill_record.skill_id], + direction=direction, + ), + skill_records=[skill_record], + skill_contents=[content], + skill_dirs=[skill_dir] if skill_dir else [], + recent_analyses=recent, + tool_issue_summary=issue_summary, + available_tools=self._available_tools, + )) + + # Mark as addressed regardless of whether evolution succeeds + # (if it fails, Trigger 1/3 can pick it up on new data) + self._addressed_degradations.setdefault( + tool_rec.tool_key, set() + ).add(skill_record.skill_id) + + if not confirmed_contexts: + return [] + + # Phase 2: execute confirmed evolutions in parallel + results = await self._execute_contexts(confirmed_contexts, "tool_degradation") + return results + + # Trigger 3: periodic metric check + async def process_metric_check( + self, min_selections: int = 5, + ) -> List[SkillRecord]: + """Scan active skills and evolve those with poor health metrics. + + Two-phase: rule-based candidate screening (relaxed thresholds) → + LLM confirmation. Called periodically (e.g., every N executions). + Only considers skills with enough data (``min_selections``). + + Anti-loop (data-driven): newly-evolved skills start with + ``total_selections=0``, so they naturally need ``min_selections`` + fresh executions before being re-evaluated. No time-based + cooldown is needed. + """ + # Phase 1: screen & confirm candidates + confirmed_contexts: List[EvolutionContext] = [] + all_active = self._store.load_active() + + for skill_id, record in all_active.items(): + if record.total_selections < min_selections: + continue + + evo_type, direction = self._diagnose_skill_health(record) + if evo_type is None: + continue + + content = self._load_skill_content(record) + if not content: + continue + + recent = self._store.load_analyses(skill_id=record.skill_id, limit=_ANALYSIS_CONTEXT_MAX) + metric_summary = ( + f"selections={record.total_selections}, " + f"applied_rate={record.applied_rate:.0%}, " + f"completion_rate={record.completion_rate:.0%}, " + f"effective_rate={record.effective_rate:.0%}, " + f"fallback_rate={record.fallback_rate:.0%}" + ) + + # LLM confirmation: ask whether this skill truly needs evolution + confirmed = await self._llm_confirm_evolution( + skill_record=record, + skill_content=content, + proposed_type=evo_type, + proposed_direction=direction, + trigger_context=f"Metric check: {metric_summary}", + recent_analyses=recent, + ) + if not confirmed: + logger.debug( + f"[Trigger:metric_monitor] LLM rejected evolution " + f"for skill '{record.name}' ({evo_type.value})" + ) + continue + + skill_dir = Path(record.path).parent if record.path else None + confirmed_contexts.append(EvolutionContext( + trigger=EvolutionTrigger.METRIC_MONITOR, + suggestion=EvolutionSuggestion( + evolution_type=evo_type, + target_skill_ids=[record.skill_id], + direction=direction, + ), + skill_records=[record], + skill_contents=[content], + skill_dirs=[skill_dir] if skill_dir else [], + recent_analyses=recent, + metric_summary=metric_summary, + available_tools=self._available_tools, + )) + + if not confirmed_contexts: + return [] + + # Phase 2: execute confirmed evolutions in parallel + results = await self._execute_contexts(confirmed_contexts, "metric_monitor") + return results + + async def _execute_contexts( + self, + contexts: List[EvolutionContext], + trigger_label: str, + ) -> List[SkillRecord]: + """Execute a list of evolution contexts in parallel (throttled). + + Used by all three triggers after building/confirming contexts. + """ + async def _throttled(c: EvolutionContext) -> Optional[SkillRecord]: + async with self._semaphore: + return await self.evolve(c) + + raw = await asyncio.gather( + *[_throttled(c) for c in contexts], + return_exceptions=True, + ) + results: List[SkillRecord] = [] + for r in raw: + if isinstance(r, BaseException): + logger.error(f"[Trigger:{trigger_label}] Evolution task raised: {r}") + elif r is not None: + results.append(r) + + if results: + names = [r.name for r in results] + logger.info( + f"[Trigger:{trigger_label}] Evolved {len(results)} skill(s): {names}" + ) + return results + + def schedule_background( + self, + coro, + *, + label: str = "background_evolution", + ) -> Optional[asyncio.Task]: + """Launch a coroutine as a background ``asyncio.Task``. + + Used by the caller (``OpenSpace._maybe_evolve_quality``) when + ``background_triggers`` is True. The task is tracked so it can + be awaited on shutdown via ``wait_background()``. + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + logger.warning(f"No running event loop — cannot schedule {label}") + return None + + task = loop.create_task(coro, name=label) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + task.add_done_callback(self._log_background_result) + return task + + @staticmethod + def _log_background_result(task: asyncio.Task) -> None: + """Log the outcome of a background evolution task.""" + if task.cancelled(): + logger.debug(f"Background task '{task.get_name()}' was cancelled") + return + exc = task.exception() + if exc: + logger.error( + f"Background task '{task.get_name()}' failed: {exc}", + exc_info=exc, + ) + + # LLM confirmation for Trigger 2/3 + async def _llm_confirm_evolution( + self, + *, + skill_record: SkillRecord, + skill_content: str, + proposed_type: EvolutionType, + proposed_direction: str, + trigger_context: str, + recent_analyses: List[ExecutionAnalysis], + ) -> bool: + """Ask LLM to confirm whether a rule-based evolution candidate + truly needs evolution. + + Returns True if LLM agrees, False otherwise. + This prevents false positives from rigid threshold-based rules. + + The confirmation prompt and response are recorded to + ``conversations.jsonl`` under agent_name="SkillEvolver.confirm". + """ + try: + from gdpval_bench.token_tracker import set_call_source, reset_call_source + _src_tok = set_call_source("evolver") + except ImportError: + _src_tok = None + + from openspace.recording import RecordingManager + + analysis_ctx = self._format_analysis_context(recent_analyses) + + prompt = SkillEnginePrompts.evolution_confirm( + skill_id=skill_record.skill_id, + skill_content=_truncate(skill_content, _SKILL_CONTENT_MAX_CHARS // 2), + proposed_type=proposed_type.value, + proposed_direction=proposed_direction, + trigger_context=trigger_context, + recent_analyses=analysis_ctx, + ) + + confirm_messages = [{"role": "user", "content": prompt}] + + # Record confirmation setup + await RecordingManager.record_conversation_setup( + setup_messages=copy.deepcopy(confirm_messages), + agent_name="SkillEvolver.confirm", + extra={ + "skill_id": skill_record.skill_id, + "proposed_type": proposed_type.value, + "trigger_context": trigger_context[:200], + }, + ) + + model = self._model or self._llm_client.model + try: + result = await self._llm_client.complete( + messages=confirm_messages, + model=model, + ) + content = result["message"].get("content", "").strip().lower() + confirmed = self._parse_confirmation(content) + + # Record confirmation response + await RecordingManager.record_iteration_context( + iteration=1, + delta_messages=[{"role": "assistant", "content": content}], + response_metadata={ + "has_tool_calls": False, + "confirmed": confirmed, + }, + agent_name="SkillEvolver.confirm", + ) + + return confirmed + except Exception as e: + logger.warning(f"LLM confirmation failed, defaulting to skip: {e}") + return False + finally: + if _src_tok is not None: + reset_call_source(_src_tok) + + @staticmethod + def _parse_confirmation(response: str) -> bool: + """Parse LLM confirmation response (expects JSON with 'proceed' field).""" + # Try JSON parse first + try: + # Strip markdown fences + cleaned = response.strip() + if cleaned.startswith("```"): + cleaned = re.sub(r"^```(?:json)?\s*\n?", "", cleaned) + cleaned = re.sub(r"\n?```\s*$", "", cleaned) + data = json.loads(cleaned) + if isinstance(data, dict): + return bool(data.get("proceed", False)) + except (json.JSONDecodeError, ValueError): + pass + # Fallback: look for keywords. + # - yes/no use strict word boundaries to avoid false positives + # (e.g. "know" matching "no"). + # - confirm/reject/skip use stem-style matching so that common + # LLM variants like "confirmed", "rejected", "skipping" still + # parse correctly. + _wb = re.search # shorthand + if any(w in response for w in ("\"proceed\": true", "proceed: true")) \ + or _wb(r"\byes\b", response) \ + or _wb(r"\bconfirm\w*\b", response): + return True + if any(w in response for w in ("\"proceed\": false", "proceed: false")) \ + or _wb(r"\bno\b", response) \ + or _wb(r"\breject\w*\b", response) \ + or _wb(r"\bskip\w*\b", response): + return False + # Default: skip — ambiguous response should not trigger costly evolution + logger.debug("LLM confirmation response was ambiguous, defaulting to skip") + return False + + async def _evolve_fix(self, ctx: EvolutionContext) -> Optional[SkillRecord]: + """In-place fix: same name, same directory, new version record. + + Uses agent loop for information gathering + apply-retry cycle. + """ + if not ctx.skill_records or not ctx.skill_contents or not ctx.skill_dirs: + logger.warning("FIX requires exactly 1 parent (skill_records/contents/dirs)") + return None + + parent = ctx.skill_records[0] + parent_content = ctx.skill_contents[0] + parent_dir = ctx.skill_dirs[0] + + # Build prompt with full directory content for multi-file skills + dir_content = self._format_skill_dir_content(parent_dir) + prompt = SkillEnginePrompts.evolution_fix( + current_content=_truncate(dir_content or parent_content, _SKILL_CONTENT_MAX_CHARS), + direction=ctx.suggestion.direction, + failure_context=self._format_analysis_context(ctx.recent_analyses), + tool_issue_summary=ctx.tool_issue_summary, + metric_summary=ctx.metric_summary, + ) + + # Agent loop: LLM can gather information via tools before generating edits + new_content = await self._run_evolution_loop(prompt, ctx) + if not new_content: + return None + + # Extract change_summary from LLM output (first line if prefixed) + new_content, change_summary = _extract_change_summary(new_content) + + # Apply-retry cycle + edit_result = await self._apply_with_retry( + apply_fn=lambda content: fix_skill(parent_dir, content, PatchType.AUTO), + initial_content=new_content, + skill_dir=parent_dir, + ctx=ctx, + prompt=prompt, + ) + if edit_result is None or not edit_result.ok: + return None + + # Re-read name/description from the updated SKILL.md on disk — + # the LLM may have refined the description (or even name) during the fix. + updated_skill_md = edit_result.content_snapshot.get(SKILL_FILENAME, "") + fixed_name = _extract_frontmatter_field(updated_skill_md, "name") or parent.name + fixed_desc = _extract_frontmatter_field(updated_skill_md, "description") or parent.description + + new_id = f"{fixed_name}__v{parent.lineage.generation + 1}_{uuid.uuid4().hex[:8]}" + model = self._model or self._llm_client.model + + new_record = SkillRecord( + skill_id=new_id, + name=fixed_name, + description=fixed_desc, + path=parent.path, + category=parent.category, + tags=list(parent.tags), + visibility=parent.visibility, + creator_id=parent.creator_id, + lineage=SkillLineage( + origin=SkillOrigin.FIXED, + generation=parent.lineage.generation + 1, + parent_skill_ids=[parent.skill_id], + source_task_id=ctx.source_task_id, + change_summary=change_summary or ctx.suggestion.direction, + content_diff=edit_result.content_diff, + content_snapshot=edit_result.content_snapshot, + created_by=model, + ), + tool_dependencies=list(parent.tool_dependencies), + critical_tools=list(parent.critical_tools), + ) + + await self._store.evolve_skill(new_record, [parent.skill_id]) + + # Stamp the new skill_id into the sidecar file so next discover() + write_skill_id(parent_dir, new_id) + + from .registry import SkillMeta + new_meta = SkillMeta( + skill_id=new_id, + name=fixed_name, + description=fixed_desc, + path=Path(parent.path), + ) + self._registry.update_skill(parent.skill_id, new_meta) + + logger.info( + f"FIX: {parent.name} gen{parent.lineage.generation} → " + f"gen{new_record.lineage.generation} [{new_id}]" + ) + return new_record + + async def _evolve_derived(self, ctx: EvolutionContext) -> Optional[SkillRecord]: + """Create enhanced version in a new directory. + + Supports single-parent (enhance) and multi-parent (merge/fuse). + Uses agent loop for information gathering + apply-retry cycle. + """ + if not ctx.skill_records or not ctx.skill_contents or not ctx.skill_dirs: + logger.warning("DERIVED requires at least one parent skill_record + content + dir") + return None + + first_parent = ctx.skill_records[0] # For fallback defaults only + is_merge = len(ctx.skill_records) > 1 + + # Build prompt — include all parent contents for multi-parent merge + if is_merge: + parent_sections = [] + for i, (rec, sd) in enumerate(zip(ctx.skill_records, ctx.skill_dirs)): + dir_content = self._format_skill_dir_content(sd) + label = f"Parent {i + 1}: {rec.name}" + parent_sections.append( + f"## {label}\n{_truncate(dir_content or ctx.skill_contents[i], _SKILL_CONTENT_MAX_CHARS)}" + ) + combined_content = "\n\n---\n\n".join(parent_sections) + else: + dir_content = self._format_skill_dir_content(ctx.skill_dirs[0]) + combined_content = _truncate(dir_content or ctx.skill_contents[0], _SKILL_CONTENT_MAX_CHARS) + + prompt = SkillEnginePrompts.evolution_derived( + parent_content=combined_content, + direction=ctx.suggestion.direction, + execution_insights=self._format_analysis_context(ctx.recent_analyses), + metric_summary=ctx.metric_summary, + ) + + # Agent loop + new_content = await self._run_evolution_loop(prompt, ctx) + if not new_content: + return None + + new_content, change_summary = _extract_change_summary(new_content) + + # Determine new skill name from frontmatter, or generate one + new_name = _extract_frontmatter_field(new_content, "name") + if not new_name or new_name == first_parent.name: + suffix = "-merged" if is_merge else "-enhanced" + new_name = f"{first_parent.name}{suffix}" + new_content = _set_frontmatter_field(new_content, "name", new_name) + + # Cap name length to avoid ever-growing chains like + # "panel-component-enhanced-enhanced-merged_abc123" + new_name = _sanitize_skill_name(new_name) + new_content = _set_frontmatter_field(new_content, "name", new_name) + + # Directory name always matches the skill name + target_dir = ctx.skill_dirs[0].parent / new_name + if target_dir.exists(): + new_name = f"{new_name}-{uuid.uuid4().hex[:6]}" + new_name = _sanitize_skill_name(new_name) + target_dir = ctx.skill_dirs[0].parent / new_name + new_content = _set_frontmatter_field(new_content, "name", new_name) + + # Apply-retry cycle for derive_skill + edit_result = await self._apply_with_retry( + apply_fn=lambda content: derive_skill(ctx.skill_dirs, target_dir, content, PatchType.AUTO), + initial_content=new_content, + skill_dir=target_dir, + ctx=ctx, + prompt=prompt, + cleanup_on_retry=target_dir, # Remove failed target dir before retry + ) + if edit_result is None or not edit_result.ok: + return None + + # Extract description from new content + new_desc = _extract_frontmatter_field(new_content, "description") or first_parent.description + + # Collect parent info from ALL parents + parent_ids = [r.skill_id for r in ctx.skill_records] + max_gen = max(r.lineage.generation for r in ctx.skill_records) + all_tool_deps: set = set() + all_critical: set = set() + all_tags: set = set() + for rec in ctx.skill_records: + all_tool_deps.update(rec.tool_dependencies) + all_critical.update(rec.critical_tools) + all_tags.update(rec.tags) + + new_id = f"{new_name}__v0_{uuid.uuid4().hex[:8]}" + model = self._model or self._llm_client.model + + new_record = SkillRecord( + skill_id=new_id, + name=new_name, + description=new_desc, + path=str(target_dir / SKILL_FILENAME), + category=ctx.suggestion.category or first_parent.category, + tags=sorted(all_tags), + visibility=first_parent.visibility, + creator_id=first_parent.creator_id, + lineage=SkillLineage( + origin=SkillOrigin.DERIVED, + generation=max_gen + 1, + parent_skill_ids=parent_ids, + source_task_id=ctx.source_task_id, + change_summary=change_summary or ctx.suggestion.direction, + content_diff=edit_result.content_diff, + content_snapshot=edit_result.content_snapshot, + created_by=model, + ), + tool_dependencies=sorted(all_tool_deps), + critical_tools=sorted(all_critical), + ) + + await self._store.evolve_skill(new_record, parent_ids) + + # Stamp skill_id sidecar so discover() uses this ID on restart + write_skill_id(target_dir, new_id) + + # Register the new skill so it's immediately available for selection + from .registry import SkillMeta + new_meta = SkillMeta( + skill_id=new_id, + name=new_name, + description=new_desc, + path=target_dir / SKILL_FILENAME, + ) + self._registry.add_skill(new_meta) + + parent_names = " + ".join(r.name for r in ctx.skill_records) + logger.info(f"DERIVED: {parent_names} → {new_name} [{new_id}]") + return new_record + + async def _evolve_captured(self, ctx: EvolutionContext) -> Optional[SkillRecord]: + """Capture a novel pattern as a brand-new skill. + + Uses agent loop for information gathering + apply-retry cycle. + """ + # Build prompt and call LLM + # For CAPTURED, we use analyses as context (the tasks where the pattern was observed) + task_descriptions = [] + for a in ctx.recent_analyses[:_ANALYSIS_CONTEXT_MAX]: + if a.execution_note: + task_descriptions.append( + f"- task={a.task_id}: {a.execution_note[:200]}" + ) + + prompt = SkillEnginePrompts.evolution_captured( + direction=ctx.suggestion.direction, + category=(ctx.suggestion.category or SkillCategory.WORKFLOW).value, + execution_highlights="\n".join(task_descriptions) if task_descriptions else "(no task context available)", + ) + + # Agent loop + new_content = await self._run_evolution_loop(prompt, ctx) + if not new_content: + return None + + new_content, change_summary = _extract_change_summary(new_content) + + # Extract name/description from the generated content + new_name = _extract_frontmatter_field(new_content, "name") + new_desc = _extract_frontmatter_field(new_content, "description") + if not new_name: + logger.warning("CAPTURED: LLM did not produce a valid skill name") + return None + + # Sanitize name (enforce length limit + valid chars) + new_name = _sanitize_skill_name(new_name) + new_content = _set_frontmatter_field(new_content, "name", new_name) + + # Create new skill directory via create_skill (handles multi-file FULL) + skill_dirs = self._registry._skill_dirs + if not skill_dirs: + logger.warning("CAPTURED: no skill directories configured") + return None + + # Directory name always matches the skill name + base_dir = skill_dirs[0] # Primary user skill directory + target_dir = base_dir / new_name + if target_dir.exists(): + new_name = f"{new_name}-{uuid.uuid4().hex[:6]}" + new_name = _sanitize_skill_name(new_name) + target_dir = base_dir / new_name + new_content = _set_frontmatter_field(new_content, "name", new_name) + + # Apply-retry cycle for create_skill + edit_result = await self._apply_with_retry( + apply_fn=lambda content: create_skill(target_dir, content, PatchType.AUTO), + initial_content=new_content, + skill_dir=target_dir, + ctx=ctx, + prompt=prompt, + cleanup_on_retry=target_dir, + ) + if edit_result is None or not edit_result.ok: + return None + + snapshot = edit_result.content_snapshot + add_all_diff = edit_result.content_diff + + new_id = f"{new_name}__v0_{uuid.uuid4().hex[:8]}" + model = self._model or self._llm_client.model + + new_record = SkillRecord( + skill_id=new_id, + name=new_name, + description=new_desc or new_name, + path=str(target_dir / SKILL_FILENAME), + category=ctx.suggestion.category or SkillCategory.WORKFLOW, + lineage=SkillLineage( + origin=SkillOrigin.CAPTURED, + generation=0, + parent_skill_ids=[], + source_task_id=ctx.source_task_id, + change_summary=change_summary or ctx.suggestion.direction, + content_diff=add_all_diff, + content_snapshot=snapshot, + created_by=model, + ), + ) + + await self._store.save_record(new_record) + + # Stamp skill_id sidecar so discover() uses this ID on restart + write_skill_id(target_dir, new_id) + + # Register the new skill so it's immediately available + from .registry import SkillMeta + new_meta = SkillMeta( + skill_id=new_id, + name=new_name, + description=new_desc or new_name, + path=target_dir / SKILL_FILENAME, + ) + self._registry.add_skill(new_meta) + + logger.info(f"CAPTURED: {new_name} [{new_id}]") + return new_record + + async def _run_evolution_loop( + self, + prompt: str, + ctx: EvolutionContext, + ) -> Optional[str]: + """Run evolution as a token-driven agent loop. + + Modeled after ``GroundingAgent.process()`` — the loop continues + until the LLM outputs an explicit completion/failure token, NOT + based on whether tools were called. + + Termination signals (checked every iteration, regardless of tool use): + - ``EVOLUTION_COMPLETE`` in assistant content → success, return edit. + - ``EVOLUTION_FAILED`` in assistant content → failure, return None. + + Tool availability: + - Iterations 1 … N-1: tools enabled (LLM may gather information). + - Iteration N (final): tools disabled, LLM must output a decision. + + Each non-final iteration without a token gets a nudge message + telling the LLM which iteration it is on and how many remain. + + Conversations are recorded to ``conversations.jsonl`` via + ``RecordingManager`` (agent_name="SkillEvolver") so the full + evolution dialogue is preserved for debugging and replay. + """ + from openspace.recording import RecordingManager + + model = self._model or self._llm_client.model + + # Merge tools from context and instance-level + evolution_tools: List["BaseTool"] = list(ctx.available_tools or []) + if not evolution_tools: + evolution_tools = list(self._available_tools) + + messages: List[Dict[str, Any]] = [ + {"role": "user", "content": prompt}, + ] + + # Record initial conversation setup + await RecordingManager.record_conversation_setup( + setup_messages=copy.deepcopy(messages), + tools=evolution_tools if evolution_tools else None, + agent_name="SkillEvolver", + extra={ + "evolution_type": ctx.suggestion.evolution_type.value, + "trigger": ctx.trigger.value, + "target_skills": ctx.suggestion.target_skill_ids, + }, + ) + + for iteration in range(_MAX_EVOLUTION_ITERATIONS): + is_last = iteration == _MAX_EVOLUTION_ITERATIONS - 1 + + # Snapshot message count before any additions + LLM call + msg_count_before = len(messages) + + # Final round: disable tools and force a decision + if is_last: + messages.append({ + "role": "system", + "content": ( + f"This is your FINAL round (iteration " + f"{iteration + 1}/{_MAX_EVOLUTION_ITERATIONS}) — " + f"no more tool calls allowed. " + f"You MUST output the skill edit content now based on " + f"all information gathered so far. Follow the output " + f"format specified in the original instructions. " + f"End with {EVOLUTION_COMPLETE} if the edit is satisfactory, " + f"or {EVOLUTION_FAILED} with a reason if you cannot produce one." + ), + }) + + try: + result = await self._llm_client.complete( + messages=messages, + tools=evolution_tools if (evolution_tools and not is_last) else None, + execute_tools=True, + model=model, + ) + except Exception as e: + logger.error(f"Evolution LLM call failed (iter {iteration + 1}): {e}") + return None + + content = result["message"].get("content", "") + updated_messages = result["messages"] + has_tool_calls = result.get("has_tool_calls", False) + + # Record iteration delta + delta = updated_messages[msg_count_before:] + await RecordingManager.record_iteration_context( + iteration=iteration + 1, + delta_messages=copy.deepcopy(delta), + response_metadata={ + "has_tool_calls": has_tool_calls, + "tool_calls_count": len(result.get("tool_results", [])), + "has_completion_token": bool( + content and (EVOLUTION_COMPLETE in content or EVOLUTION_FAILED in content) + ), + }, + agent_name="SkillEvolver", + ) + + messages = updated_messages + + # ── Token check (every iteration, regardless of tool calls) ── + if content and (EVOLUTION_COMPLETE in content or EVOLUTION_FAILED in content): + edit_content, failure_reason = self._parse_evolution_output(content) + if failure_reason is not None: + targets = "+".join(ctx.suggestion.target_skill_ids) or "(new)" + logger.warning( + f"Evolution LLM signalled failure " + f"[{ctx.suggestion.evolution_type.value}] " + f"target={targets}: {failure_reason}" + ) + return None + return edit_content + + # No token found + if is_last: + # Final round exhausted without a decision + logger.warning( + f"Evolution agent finished {_MAX_EVOLUTION_ITERATIONS} iterations " + f"without signalling {EVOLUTION_COMPLETE} or {EVOLUTION_FAILED}" + ) + return None + + if has_tool_calls: + logger.debug( + f"Evolution agent used tools " + f"(iter {iteration + 1}/{_MAX_EVOLUTION_ITERATIONS})" + ) + else: + # No tools, no token — nudge the LLM + logger.debug( + f"Evolution agent produced content without token or tools " + f"(iter {iteration + 1}/{_MAX_EVOLUTION_ITERATIONS})" + ) + + # Iteration guidance + remaining = _MAX_EVOLUTION_ITERATIONS - iteration - 1 + messages.append({ + "role": "system", + "content": ( + f"Iteration {iteration + 1}/{_MAX_EVOLUTION_ITERATIONS} complete " + f"({remaining} remaining). " + f"If your edit is ready, output it and include {EVOLUTION_COMPLETE} " + f"at the end. " + f"If you cannot complete this evolution, output {EVOLUTION_FAILED} " + f"with a reason. " + f"Otherwise, continue gathering information with tools." + ), + }) + + # Should never reach here (is_last handles the final iteration) + return None + + @staticmethod + def _parse_evolution_output(content: str) -> tuple[Optional[str], Optional[str]]: + """Extract edit content or failure reason from LLM output. + + MUST only be called when ``EVOLUTION_COMPLETE`` or + ``EVOLUTION_FAILED`` is present in *content*. + + Returns ``(clean_content, failure_reason)``: + - ``(content, None)`` — ``EVOLUTION_COMPLETE`` found. + - ``(None, reason)`` — ``EVOLUTION_FAILED`` found. + """ + stripped = content.strip() + + # Failure takes priority (if both tokens appear, treat as failure) + if EVOLUTION_FAILED in stripped: + idx = stripped.index(EVOLUTION_FAILED) + reason_part = stripped[idx + len(EVOLUTION_FAILED):].strip() + if reason_part.lower().startswith("reason:"): + reason_part = reason_part[len("reason:"):].strip() + reason = reason_part[:500] if reason_part else "LLM declined to produce edit (no reason given)" + return None, reason + + if EVOLUTION_COMPLETE in stripped: + clean = stripped.replace(EVOLUTION_COMPLETE, "").strip() + clean = _strip_markdown_fences(clean) + return clean, None + + # Caller guarantees a token is present; defensive fallback + return None, "No completion token found (unexpected)" + + async def _apply_with_retry( + self, + *, + apply_fn, + initial_content: str, + skill_dir: Path, + ctx: EvolutionContext, + prompt: str, + cleanup_on_retry: Optional[Path] = None, + ) -> Optional[SkillEditResult]: + """Apply an edit with retry on failure. + + If the first attempt fails (patch parse error, path mismatch, etc.), + feeds the error back to the LLM and asks for a corrected version. + + After successful application, runs structural validation. + + Retry conversations are recorded to ``conversations.jsonl`` under + agent_name="SkillEvolver.retry" so failed apply attempts and LLM + corrections are preserved for debugging. + + Args: + apply_fn: Callable that takes content str and returns SkillEditResult. + initial_content: First LLM-generated content to try. + skill_dir: Skill directory for validation. + ctx: Evolution context (for retry LLM calls). + prompt: Original prompt (for retry context). + cleanup_on_retry: Directory to remove before retrying (for derive/create). + """ + from openspace.recording import RecordingManager + + current_content = initial_content + msg_history: List[Dict[str, Any]] = [ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": initial_content}, + ] + + # Track whether we've recorded the retry setup (only on first retry) + retry_setup_recorded = False + + for attempt in range(_MAX_EVOLUTION_ATTEMPTS): + # Clean up previous failed attempt (for derive/create) + if attempt > 0 and cleanup_on_retry and cleanup_on_retry.exists(): + shutil.rmtree(cleanup_on_retry, ignore_errors=True) + + # Apply the edit + edit_result = apply_fn(current_content) + + if edit_result.ok: + # Validate the result + validation_error = _validate_skill_dir(skill_dir) + if validation_error is None: + if attempt > 0: + logger.info( + f"Apply-retry succeeded on attempt {attempt + 1}/{_MAX_EVOLUTION_ATTEMPTS}" + ) + return edit_result + else: + # Validation failed — treat as error for retry + error_msg = f"Validation failed: {validation_error}" + logger.warning( + f"Apply succeeded but validation failed " + f"(attempt {attempt + 1}/{_MAX_EVOLUTION_ATTEMPTS}): " + f"{validation_error}" + ) + else: + error_msg = edit_result.error or "Unknown apply error" + logger.warning( + f"Apply failed (attempt {attempt + 1}/{_MAX_EVOLUTION_ATTEMPTS}): " + f"{error_msg}" + ) + + # Last attempt? Give up. + if attempt >= _MAX_EVOLUTION_ATTEMPTS - 1: + logger.error( + f"Apply-retry exhausted after {_MAX_EVOLUTION_ATTEMPTS} attempts. " + f"Last error: {error_msg}" + ) + # Clean up any partially created directory + if cleanup_on_retry and cleanup_on_retry.exists(): + shutil.rmtree(cleanup_on_retry, ignore_errors=True) + return None + + # Record retry setup on first retry attempt + if not retry_setup_recorded: + await RecordingManager.record_conversation_setup( + setup_messages=copy.deepcopy(msg_history), + agent_name="SkillEvolver.retry", + extra={ + "evolution_type": ctx.suggestion.evolution_type.value, + "target_skills": ctx.suggestion.target_skill_ids, + "first_error": error_msg[:300], + }, + ) + retry_setup_recorded = True + + # Feed error back to LLM for retry, including current file + # content so the LLM doesn't hallucinate what's on disk. + current_on_disk = self._format_skill_dir_content(skill_dir) if skill_dir.is_dir() else "" + retry_prompt = ( + f"The previous edit was not successful. " + f"This was the error:\n\n{error_msg}\n\n" + ) + if current_on_disk: + retry_prompt += ( + f"Here is the CURRENT content of the skill files on disk " + f"(use this as the ground truth for any SEARCH/REPLACE or " + f"context anchors):\n\n{_truncate(current_on_disk, _SKILL_CONTENT_MAX_CHARS)}\n\n" + ) + retry_prompt += ( + f"Please fix the issue and generate the edit again. " + f"Follow the same output format as before." + ) + msg_history.append({"role": "user", "content": retry_prompt}) + + # Call LLM for corrected version (no tools — just fix the edit) + model = self._model or self._llm_client.model + try: + result = await self._llm_client.complete( + messages=msg_history, + model=model, + ) + new_content = result["message"].get("content", "") + if not new_content: + logger.warning("Retry LLM returned empty content") + continue + + new_content = _strip_markdown_fences(new_content) + # Strip evolution tokens that the LLM may include in retry responses + new_content = new_content.replace(EVOLUTION_COMPLETE, "").replace(EVOLUTION_FAILED, "").strip() + new_content, _ = _extract_change_summary(new_content) + msg_history.append({"role": "assistant", "content": new_content}) + current_content = new_content + + # Record retry iteration + await RecordingManager.record_iteration_context( + iteration=attempt + 1, + delta_messages=[ + {"role": "user", "content": retry_prompt}, + {"role": "assistant", "content": new_content}, + ], + response_metadata={ + "has_tool_calls": False, + "attempt": attempt + 1, + "error": error_msg[:300], + }, + agent_name="SkillEvolver.retry", + ) + + except Exception as e: + logger.error(f"Retry LLM call failed: {e}") + continue + + return None + + def _build_context_from_analysis( + self, + analysis: ExecutionAnalysis, + suggestion: EvolutionSuggestion, + ) -> Optional[EvolutionContext]: + """Build EvolutionContext from a single analysis suggestion. + + Loads all target skills referenced by ``suggestion.target_skill_ids``. + For FIX: exactly 1 parent required. + For DERIVED: 1+ parents (multi-parent = merge). + For CAPTURED: parents list is empty. + """ + records: List[SkillRecord] = [] + contents: List[str] = [] + dirs: List[Path] = [] + + if suggestion.evolution_type in (EvolutionType.FIX, EvolutionType.DERIVED): + if not suggestion.target_skill_ids: + logger.warning("FIX/DERIVED suggestion missing target_skill_ids") + return None + + for target_id in suggestion.target_skill_ids: + rec = self._store.load_record(target_id) + if not rec: + logger.warning(f"Target skill not found: {target_id}") + return None + content = self._load_skill_content(rec) + if not content: + logger.warning(f"Cannot load content for skill: {target_id}") + return None + skill_dir = Path(rec.path).parent if rec.path else None + + records.append(rec) + contents.append(content) + if skill_dir: + dirs.append(skill_dir) + + # FIX must target exactly one skill + if suggestion.evolution_type == EvolutionType.FIX and len(records) != 1: + logger.warning( + f"FIX requires exactly 1 target, got {len(records)}: " + f"{suggestion.target_skill_ids}" + ) + return None + + return EvolutionContext( + trigger=EvolutionTrigger.ANALYSIS, + suggestion=suggestion, + skill_records=records, + skill_contents=contents, + skill_dirs=dirs, + source_task_id=analysis.task_id, + recent_analyses=[analysis], + available_tools=self._available_tools, + ) + + def _load_skill_content(self, record: SkillRecord) -> str: + """Load SKILL.md content from disk via registry or direct read.""" + # Try registry first (uses cache, keyed by skill_id) + content = self._registry.load_skill_content(record.skill_id) + if content: + return content + # Fallback: read directly from path + if record.path: + p = Path(record.path) + if p.exists(): + try: + return p.read_text(encoding="utf-8") + except Exception: + pass + return "" + + @staticmethod + def _format_skill_dir_content(skill_dir: Path) -> str: + """Format all text files in a skill directory for prompt inclusion. + + Returns a multi-file listing if there are auxiliary files beyond + SKILL.md, or just the SKILL.md content for single-file skills. + """ + files = collect_skill_snapshot(skill_dir) + if not files: + return "" + + # Single-file skill: return just the content + if len(files) == 1 and SKILL_FILENAME in files: + return files[SKILL_FILENAME] + + # Multi-file: format as directory listing + parts: list[str] = [] + # SKILL.md first + if SKILL_FILENAME in files: + parts.append(f"### File: {SKILL_FILENAME}\n```markdown\n{files[SKILL_FILENAME]}\n```") + for name, content in sorted(files.items()): + if name == SKILL_FILENAME: + continue + parts.append(f"### File: {name}\n```\n{content}\n```") + + return "\n\n".join(parts) + + @staticmethod + def _format_analysis_context(analyses: List[ExecutionAnalysis]) -> str: + """Format recent analyses into a concise context block for prompts.""" + if not analyses: + return "(no execution history available)" + + parts: List[str] = [] + for a in analyses[:_ANALYSIS_CONTEXT_MAX]: + completed = "completed" if a.task_completed else "failed" + + # Per-skill notes + skill_notes = [] + for j in a.skill_judgments: + applied = "applied" if j.skill_applied else "NOT applied" + note = f" - {j.skill_id}: {applied}" + if j.note: + note += f" — {j.note[:_ANALYSIS_NOTE_MAX_CHARS]}" + skill_notes.append(note) + + # Tool issues + tool_lines = [] + for issue in a.tool_issues[:3]: + tool_lines.append(f" - {issue[:200]}") + + block = f"### Task: {a.task_id} ({completed})\n" + if a.execution_note: + block += f"{a.execution_note[:_ANALYSIS_NOTE_MAX_CHARS]}\n" + if skill_notes: + block += "Skills:\n" + "\n".join(skill_notes) + "\n" + if tool_lines: + block += "Tool issues:\n" + "\n".join(tool_lines) + "\n" + parts.append(block) + + return "\n".join(parts) + + @staticmethod + def _diagnose_skill_health( + record: SkillRecord, + ) -> tuple[Optional[EvolutionType], str]: + """Diagnose what type of evolution a skill needs based on metrics. + + Returns (None, "") if the skill appears healthy. + Thresholds are intentionally relaxed — the LLM confirmation step + filters out false positives. + """ + # High fallback rate → skill frequently selected but not used → FIX candidate + if record.fallback_rate > _FALLBACK_THRESHOLD: + return EvolutionType.FIX, ( + f"High fallback rate ({record.fallback_rate:.0%}): " + f"skill is frequently selected but not applied, " + f"suggesting instructions are unclear or outdated." + ) + + # Applied often but rarely completes → instructions are wrong → FIX candidate + if (record.applied_rate > _HIGH_APPLIED_FOR_FIX + and record.completion_rate < _LOW_COMPLETION_THRESHOLD): + return EvolutionType.FIX, ( + f"Low completion rate ({record.completion_rate:.0%}) despite " + f"high applied rate ({record.applied_rate:.0%}): " + f"skill instructions may be incorrect or incomplete." + ) + + # Moderate effectiveness → could be better → DERIVED candidate + if (record.effective_rate < _MODERATE_EFFECTIVE_THRESHOLD + and record.applied_rate > _MIN_APPLIED_FOR_DERIVED): + return EvolutionType.DERIVED, ( + f"Moderate effectiveness ({record.effective_rate:.0%}): " + f"skill works sometimes but could be enhanced with " + f"better error handling or alternative approaches." + ) + + return None, "" \ No newline at end of file diff --git a/openspace/skill_engine/fuzzy_match.py b/openspace/skill_engine/fuzzy_match.py new file mode 100644 index 0000000000000000000000000000000000000000..f6320e3afe836181d400cc3dbbe967200cba4fde --- /dev/null +++ b/openspace/skill_engine/fuzzy_match.py @@ -0,0 +1,322 @@ +"""Fuzzy matching chain for SEARCH/REPLACE edits. + +The chain degrades gracefully: + Level 1 — exact match + Level 2 — line-trimmed match (per-line strip) + Level 3 — block-anchor match (first/last line + Levenshtein middle) + Level 4 — whitespace-normalized match (collapse whitespace) + Level 5 — indentation-flexible match (strip common indent) + Level 6 — trimmed-boundary match (strip entire block) +""" + +from __future__ import annotations + +import re +from typing import Generator, List, Optional, Tuple + +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +__all__ = [ + "fuzzy_find_match", + "fuzzy_replace", + "REPLACER_CHAIN", +] + +# Type alias — each replacer yields candidate match strings. +Replacer = Generator[str, None, None] + +# Thresholds +SINGLE_CANDIDATE_SIMILARITY_THRESHOLD = 0.0 +MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD = 0.3 + +def levenshtein(a: str, b: str) -> int: + """Compute the Levenshtein edit distance between two strings.""" + if not a or not b: + return max(len(a), len(b)) + rows = len(a) + 1 + cols = len(b) + 1 + matrix = [[0] * cols for _ in range(rows)] + for i in range(rows): + matrix[i][0] = i + for j in range(cols): + matrix[0][j] = j + for i in range(1, rows): + for j in range(1, cols): + cost = 0 if a[i - 1] == b[j - 1] else 1 + matrix[i][j] = min( + matrix[i - 1][j] + 1, + matrix[i][j - 1] + 1, + matrix[i - 1][j - 1] + cost, + ) + return matrix[len(a)][len(b)] + +def simple_replacer(_content: str, find: str) -> Replacer: + """Yield *find* unconditionally; the caller verifies via ``str.find``.""" + yield find + +def line_trimmed_replacer(content: str, find: str) -> Replacer: + """Match by trimming each line, then yield the original substring.""" + original_lines = content.split("\n") + search_lines = find.split("\n") + + # Strip trailing empty line (common LLM artifact) + if search_lines and search_lines[-1] == "": + search_lines.pop() + + if not search_lines: + return + + n_search = len(search_lines) + for i in range(len(original_lines) - n_search + 1): + matches = True + for j in range(n_search): + if original_lines[i + j].strip() != search_lines[j].strip(): + matches = False + break + if matches: + start_idx = sum(len(original_lines[k]) + 1 for k in range(i)) + end_idx = start_idx + for k in range(n_search): + end_idx += len(original_lines[i + k]) + if k < n_search - 1: + end_idx += 1 + yield content[start_idx:end_idx] + +def block_anchor_replacer(content: str, find: str) -> Replacer: + """Anchor on first/last lines (trimmed) and use Levenshtein on middles.""" + original_lines = content.split("\n") + search_lines = find.split("\n") + + if len(search_lines) < 3: + return + if search_lines and search_lines[-1] == "": + search_lines.pop() + if len(search_lines) < 3: + return + + first_search = search_lines[0].strip() + last_search = search_lines[-1].strip() + search_block_size = len(search_lines) + + candidates: List[Tuple[int, int]] = [] + for i, line in enumerate(original_lines): + if line.strip() != first_search: + continue + for j in range(i + 2, len(original_lines)): + if original_lines[j].strip() == last_search: + candidates.append((i, j)) + break + + if not candidates: + return + + def _extract_block(start_line: int, end_line: int) -> str: + s = sum(len(original_lines[k]) + 1 for k in range(start_line)) + e = s + for k in range(start_line, end_line + 1): + e += len(original_lines[k]) + if k < end_line: + e += 1 + return content[s:e] + + if len(candidates) == 1: + start_line, end_line = candidates[0] + actual_size = end_line - start_line + 1 + lines_to_check = min(search_block_size - 2, actual_size - 2) + + if lines_to_check > 0: + similarity = 0.0 + for j in range(1, min(search_block_size - 1, actual_size - 1)): + orig_line = original_lines[start_line + j].strip() + srch_line = search_lines[j].strip() + max_len = max(len(orig_line), len(srch_line)) + if max_len == 0: + continue + dist = levenshtein(orig_line, srch_line) + similarity += (1 - dist / max_len) / lines_to_check + if similarity >= SINGLE_CANDIDATE_SIMILARITY_THRESHOLD: + break + else: + similarity = 1.0 + + if similarity >= SINGLE_CANDIDATE_SIMILARITY_THRESHOLD: + yield _extract_block(start_line, end_line) + return + + # Multiple candidates: pick the best + best_match: Optional[Tuple[int, int]] = None + max_similarity = -1.0 + + for start_line, end_line in candidates: + actual_size = end_line - start_line + 1 + lines_to_check = min(search_block_size - 2, actual_size - 2) + + if lines_to_check > 0: + raw_sim = 0.0 + for j in range(1, min(search_block_size - 1, actual_size - 1)): + orig_line = original_lines[start_line + j].strip() + srch_line = search_lines[j].strip() + max_len = max(len(orig_line), len(srch_line)) + if max_len == 0: + continue + dist = levenshtein(orig_line, srch_line) + raw_sim += 1 - dist / max_len + similarity = raw_sim / lines_to_check + else: + similarity = 1.0 + + if similarity > max_similarity: + max_similarity = similarity + best_match = (start_line, end_line) + + if max_similarity >= MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD and best_match: + yield _extract_block(best_match[0], best_match[1]) + +def whitespace_normalized_replacer(content: str, find: str) -> Replacer: + r"""Normalize whitespace (``\s+`` -> single space) before comparing.""" + + def _normalize(text: str) -> str: + return re.sub(r"\s+", " ", text).strip() + + normalized_find = _normalize(find) + lines = content.split("\n") + + # Single-line matching + for line in lines: + if _normalize(line) == normalized_find: + yield line + else: + normalized_line = _normalize(line) + if normalized_find in normalized_line: + words = find.strip().split() + if words: + pattern = r"\s+".join(re.escape(word) for word in words) + try: + match = re.search(pattern, line) + if match: + yield match.group(0) + except re.error: + pass + + # Multi-line matching + find_lines = find.split("\n") + if len(find_lines) > 1: + for i in range(len(lines) - len(find_lines) + 1): + block = lines[i: i + len(find_lines)] + if _normalize("\n".join(block)) == normalized_find: + yield "\n".join(block) + +def indentation_flexible_replacer(content: str, find: str) -> Replacer: + """Remove the common leading indentation and compare blocks.""" + + def _remove_indent(text: str) -> str: + lines = text.split("\n") + non_empty = [line for line in lines if line.strip()] + if not non_empty: + return text + min_indent = min(len(line) - len(line.lstrip()) for line in non_empty) + return "\n".join( + line[min_indent:] if line.strip() else line for line in lines + ) + + normalized_find = _remove_indent(find) + content_lines = content.split("\n") + find_lines = find.split("\n") + + for i in range(len(content_lines) - len(find_lines) + 1): + block = "\n".join(content_lines[i: i + len(find_lines)]) + if _remove_indent(block) == normalized_find: + yield block + +def trimmed_boundary_replacer(content: str, find: str) -> Replacer: + """Trim the entire find block, then search.""" + trimmed_find = find.strip() + if trimmed_find == find: + return + + if trimmed_find in content: + yield trimmed_find + + lines = content.split("\n") + find_lines = find.split("\n") + for i in range(len(lines) - len(find_lines) + 1): + block = "\n".join(lines[i: i + len(find_lines)]) + if block.strip() == trimmed_find: + yield block + +REPLACER_CHAIN: list = [ + ("simple", simple_replacer), + ("line_trimmed", line_trimmed_replacer), + ("block_anchor", block_anchor_replacer), + ("whitespace_normalized", whitespace_normalized_replacer), + ("indentation_flexible", indentation_flexible_replacer), + ("trimmed_boundary", trimmed_boundary_replacer), +] + +def fuzzy_find_match(content: str, find: str) -> Tuple[str, int]: + """Locate *find* in *content* using the replacer chain. + + Returns ``(matched_text, position)`` where *matched_text* is the + actual substring of *content*, and *position* is its character offset. + Returns ``("", -1)`` when no match is found. + """ + for name, replacer in REPLACER_CHAIN: + for candidate in replacer(content, find): + pos = content.find(candidate) + if pos == -1: + continue + if name != "simple": + logger.debug( + "fuzzy_find_match: matched via '%s' at position %d", + name, pos, + ) + return candidate, pos + + return "", -1 + +def fuzzy_replace( + content: str, + old_string: str, + new_string: str, + replace_all: bool = False, +) -> str: + """Replace *old_string* with *new_string* in *content*. + + Walks the chain until a unique match is found. + + Raises: + ValueError: When old_string not found or match is ambiguous. + """ + if old_string == new_string: + raise ValueError("old_string and new_string are identical") + + not_found = True + + for name, replacer in REPLACER_CHAIN: + for candidate in replacer(content, old_string): + idx = content.find(candidate) + if idx == -1: + continue + + not_found = False + + if replace_all: + return content.replace(candidate, new_string) + + last_idx = content.rfind(candidate) + if idx != last_idx: + continue # ambiguous + + return content[:idx] + new_string + content[idx + len(candidate):] + + if not_found: + raise ValueError( + "Could not find old_string in the file. " + "Must match exactly (including whitespace and indentation)." + ) + raise ValueError( + "Found multiple matches for old_string. " + "Provide more context to make the match unique." + ) diff --git a/openspace/skill_engine/patch.py b/openspace/skill_engine/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..72ce8018f86acfdbf53c0a849f7e8ee4366e4719 --- /dev/null +++ b/openspace/skill_engine/patch.py @@ -0,0 +1,1001 @@ +"""patch — Multi-file patch application and diff generation for Skills. + +A skill is a directory containing ``SKILL.md`` and optional auxiliary files +(scripts, configs, examples). Three LLM output formats are supported: + + FULL — Complete file content (single-file or multi-file via ``*** Begin Files``) + DIFF — SEARCH/REPLACE blocks (single-file, applied to SKILL.md) + PATCH — ``*** Begin Patch`` multi-file format (supports Add / Update / Delete across multiple files) + +Auto-detection (``PatchType.AUTO``) inspects the LLM output and dispatches +to the correct parser. + +Three skill operations: + fix_skill — in-place repair of an existing skill directory + derive_skill — copy directory → apply changes in the copy + create_skill — create a brand-new skill directory (for CAPTURED) +""" + +from __future__ import annotations + +import difflib +import re +import shutil +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple, Union + +from .fuzzy_match import fuzzy_find_match +from .skill_utils import normalize_frontmatter +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +SKILL_FILENAME = "SKILL.md" + +# Sidecar file that stores the persistent skill_id — excluded from diffs/snapshots +_SKILL_ID_FILENAME = ".skill_id" + + +def _normalize_skill_frontmatter(skill_dir: Path) -> None: + """Re-quote SKILL.md frontmatter in-place so values are valid YAML.""" + skill_file = skill_dir / SKILL_FILENAME + if not skill_file.exists(): + return + try: + raw = skill_file.read_text(encoding="utf-8") + normalized = normalize_frontmatter(raw) + if normalized != raw: + skill_file.write_text(normalized, encoding="utf-8") + except Exception as e: + logger.debug(f"frontmatter normalize skipped for {skill_dir.name}: {e}") + + +class PatchType(str, Enum): + """LLM output format for skill edits.""" + AUTO = "auto" # Auto-detect from content + FULL = "full" # Complete content (single or multi-file) + DIFF = "diff" # SEARCH/REPLACE blocks (single-file) + PATCH = "patch" # *** Begin Patch multi-file format + + +@dataclass +class UpdateChunk: + """A single change block inside an Update File hunk.""" + old_lines: List[str] + new_lines: List[str] + change_context: Optional[str] = None + is_end_of_file: bool = False + + +@dataclass +class PatchHunk: + """One file-level operation inside a patch.""" + type: str # "add" | "update" | "delete" + path: str + contents: str = "" # type="add": new file body + move_path: Optional[str] = None # type="update": optional rename target + chunks: List[UpdateChunk] = field(default_factory=list) + + +@dataclass +class PatchResult: + """Parsed representation of a ``*** Begin Patch`` block.""" + hunks: List[PatchHunk] + + +@dataclass +class SkillEditResult: + """Result of a skill edit operation. + + Attributes: + skill_dir: Final skill directory location. + content_diff: Combined unified diff (git diff format) covering all + modified files, for lineage recording. + content_snapshot: Full directory snapshot ``{relative_path: content}``, + for lineage recording. + error: Non-None means the operation failed. + """ + skill_dir: Path = field(default_factory=lambda: Path(".")) + content_diff: str = "" + content_snapshot: Dict[str, str] = field(default_factory=dict) + error: Optional[str] = None + + @property + def ok(self) -> bool: + return self.error is None + + +class PatchError(RuntimeError): + """Raised when a patch cannot be applied.""" + pass + + +class PatchParseError(PatchError): + """Raised when the patch text cannot be parsed.""" + pass + + +PATCH_PATTERN = re.compile( + r"<{7}\s*SEARCH\s*\n(.*?)\n\s*={7}\s*\n(.*?)\n\s*>{7}\s*REPLACE\s*", + re.DOTALL, +) + + +def fix_skill( + skill_dir: Path, + content: str, + patch_type: PatchType = PatchType.AUTO, +) -> SkillEditResult: + """In-place repair of an existing skill directory. + + Applies the LLM output to the skill directory, overwrites files on disk, + returns a combined diff and snapshot for lineage recording. + + Args: + skill_dir: Existing skill directory. + content: LLM output (FULL, DIFF, or PATCH format). + patch_type: Output format (AUTO to auto-detect). + """ + if not skill_dir.is_dir(): + return SkillEditResult(error=f"Skill directory not found: {skill_dir}") + skill_file = skill_dir / SKILL_FILENAME + if not skill_file.exists(): + return SkillEditResult(error=f"SKILL.md not found: {skill_file}") + + # Snapshot before edit + old_files = _collect_files(skill_dir) + + # Resolve patch type + if patch_type == PatchType.AUTO: + patch_type = detect_patch_type(content) + + try: + if patch_type == PatchType.PATCH: + _apply_multi_file_patch(content, skill_dir) + elif patch_type == PatchType.FULL: + _apply_multi_file_full(content, skill_dir) + elif patch_type == PatchType.DIFF: + _apply_search_replace_to_file(content, skill_file) + else: + return SkillEditResult(error=f"Unknown patch type: {patch_type}") + except PatchError as e: + return SkillEditResult(error=str(e)) + except Exception as e: + return SkillEditResult(error=f"Unexpected error: {e}") + + _normalize_skill_frontmatter(skill_dir) + + # Snapshot after edit + new_files = _collect_files(skill_dir) + diff = _compute_files_diff(old_files, new_files) + + logger.info(f"fix_skill: {skill_dir.name} ({patch_type.value})") + return SkillEditResult( + skill_dir=skill_dir, + content_diff=diff, + content_snapshot=new_files, + ) + +def derive_skill( + source_dirs: Union[Path, List[Path]], + target_dir: Path, + content: str, + patch_type: PatchType = PatchType.AUTO, +) -> SkillEditResult: + """Derive a new skill from one or more existing skills. + + **Single parent** (``source_dirs`` is one Path): + Copies parent directory → applies LLM output. Supports PATCH / DIFF / FULL. + + **Multiple parents** (``source_dirs`` is a list of Paths): + Creates a brand-new directory and applies LLM output as FULL or PATCH + (DIFF not supported for multi-parent — no single base to search/replace). + ``content_diff`` is empty for multi-parent (no meaningful single-parent + diff); ``content_snapshot`` captures the full result for lineage tracking. + + Source directories stay unchanged in both cases. + + Args: + source_dirs: Parent skill directory (or list for multi-parent merge). + target_dir: New skill directory (must not exist). + content: LLM output. + patch_type: Output format (AUTO to auto-detect). + """ + # Normalise to list + if isinstance(source_dirs, Path): + sources = [source_dirs] + else: + sources = list(source_dirs) + + if not sources: + return SkillEditResult(error="derive_skill requires at least one source directory") + if target_dir.exists(): + return SkillEditResult(error=f"Target already exists: {target_dir}") + + # Validate all sources + for sd in sources: + if not sd.is_dir(): + return SkillEditResult(error=f"Source does not exist: {sd}") + if not (sd / SKILL_FILENAME).exists(): + return SkillEditResult(error=f"Source SKILL.md not found: {sd / SKILL_FILENAME}") + + first_source = sources[0] + is_multi_parent = len(sources) > 1 + + if is_multi_parent: + # Multi-parent merge: create new directory, apply content (FULL or PATCH) + if patch_type == PatchType.AUTO: + patch_type = detect_patch_type(content) + if patch_type == PatchType.DIFF: + # DIFF (SEARCH/REPLACE) is not meaningful for merge — no single base + patch_type = PatchType.FULL + + try: + target_dir.mkdir(parents=True, exist_ok=True) + if patch_type == PatchType.PATCH: + _apply_multi_file_patch(content, target_dir) + else: + _apply_multi_file_full(content, target_dir) + except (PatchError, Exception) as e: + shutil.rmtree(target_dir, ignore_errors=True) + return SkillEditResult(error=str(e)) + else: + # Single parent: copy → apply (original behaviour) + shutil.copytree(first_source, target_dir) + + if patch_type == PatchType.AUTO: + patch_type = detect_patch_type(content) + + try: + if patch_type == PatchType.PATCH: + _apply_multi_file_patch(content, target_dir) + elif patch_type == PatchType.FULL: + _apply_multi_file_full(content, target_dir) + elif patch_type == PatchType.DIFF: + _apply_search_replace_to_file(content, target_dir / SKILL_FILENAME) + else: + shutil.rmtree(target_dir, ignore_errors=True) + return SkillEditResult(error=f"Unknown patch type: {patch_type}") + except (PatchError, Exception) as e: + shutil.rmtree(target_dir, ignore_errors=True) + return SkillEditResult(error=str(e)) + + _normalize_skill_frontmatter(target_dir) + + new_files = _collect_files(target_dir) + + # Single parent: diff against the parent. Multi-parent: no meaningful diff → empty. + # (content_snapshot already captures the full result for lineage tracking.) + diff = compute_skill_diff(first_source, target_dir) if not is_multi_parent else "" + + src_names = " + ".join(sd.name for sd in sources) + logger.info(f"derive_skill: {src_names} → {target_dir.name} ({patch_type.value})") + return SkillEditResult( + skill_dir=target_dir, + content_diff=diff, + content_snapshot=new_files, + ) + +def create_skill( + target_dir: Path, + content: str, + patch_type: PatchType = PatchType.AUTO, +) -> SkillEditResult: + """Create a brand-new skill directory (for CAPTURED). + + Args: + target_dir: New skill directory (must not exist). + content: LLM output — complete skill content. + patch_type: Output format (AUTO to auto-detect, usually FULL). + """ + if target_dir.exists(): + return SkillEditResult(error=f"Target already exists: {target_dir}") + + if patch_type == PatchType.AUTO: + patch_type = detect_patch_type(content) + + try: + target_dir.mkdir(parents=True, exist_ok=True) + + if patch_type == PatchType.PATCH: + # PATCH with only Add File hunks + _apply_multi_file_patch(content, target_dir) + elif patch_type == PatchType.FULL: + _apply_multi_file_full(content, target_dir) + elif patch_type == PatchType.DIFF: + # For CAPTURED, DIFF doesn't make sense — treat as single-file FULL + (target_dir / SKILL_FILENAME).write_text(content, encoding="utf-8") + else: + shutil.rmtree(target_dir, ignore_errors=True) + return SkillEditResult(error=f"Unknown patch type: {patch_type}") + except (PatchError, Exception) as e: + shutil.rmtree(target_dir, ignore_errors=True) + return SkillEditResult(error=str(e)) + + _normalize_skill_frontmatter(target_dir) + + new_files = _collect_files(target_dir) + # Add-all diff (everything is new) + add_all = "\n".join( + compute_unified_diff("", text, filename=name) + for name, text in sorted(new_files.items()) + if compute_unified_diff("", text, filename=name) + ) + + logger.info(f"create_skill: {target_dir.name} ({patch_type.value})") + return SkillEditResult( + skill_dir=target_dir, + content_diff=add_all, + content_snapshot=new_files, + ) + +def detect_patch_type(content: str) -> PatchType: + """Auto-detect the patch format from LLM output. + + Detection uses **structural markers** that are unlikely to appear in + normal skill prose, checked in specificity order: + + 1. ``*** Begin Patch`` → PATCH (multi-file diff) + 2. ``*** Begin Files`` → FULL (multi-file envelope) + 3. ``*** File:`` as a standalone marker line → FULL (multi-file, no envelope) + 4. ``<<<<<<< SEARCH`` → DIFF (single-file SEARCH/REPLACE) + 5. Default → FULL (single-file complete content) + + For (3), we require ``*** File:`` to appear at the start of a line + (with optional leading whitespace) AND the content to contain at least + two such markers or one marker followed by meaningful content — this + avoids false positives when a skill happens to mention the marker in + prose (e.g. inside a code example). + """ + if "*** Begin Patch" in content: + return PatchType.PATCH + if "*** Begin Files" in content: + return PatchType.FULL + + # Detect bare *** File: markers (no *** Begin Files envelope). + # Must appear at line start; require at least one to look structural. + file_header_hits = _FILE_HEADER_RE.findall(content) + if file_header_hits: + return PatchType.FULL + + if "<<<<<<< SEARCH" in content: + return PatchType.DIFF + return PatchType.FULL + + +# MULTI-FILE FULL FORMAT +# Format: +# *** Begin Files +# *** File: SKILL.md +# (complete file content, no prefix) +# *** File: examples/helper.sh +# (complete file content) +# *** End Files +# +# Or just raw content (single-file fallback to SKILL.md). + +_FILE_HEADER_RE = re.compile(r"^\*\*\*\s*File:\s*(.+)$", re.MULTILINE) + + +def parse_multi_file_full(content: str) -> Dict[str, str]: + """Parse ``*** Begin Files`` format into ``{relative_path: content}``. + + Falls back to ``{SKILL.md: content}`` if no markers are found (single-file). + """ + # Strip optional envelope + stripped = content.strip() + if stripped.startswith("*** Begin Files"): + stripped = stripped[len("*** Begin Files"):].strip() + # Strip *** End Files and anything after it (e.g. stray tokens) + end_files_idx = stripped.rfind("*** End Files") + if end_files_idx != -1: + stripped = stripped[:end_files_idx].strip() + + # Find all *** File: headers + headers = list(_FILE_HEADER_RE.finditer(stripped)) + if not headers: + # No multi-file markers — treat entire content as SKILL.md + return {SKILL_FILENAME: content} + + files: Dict[str, str] = {} + for i, match in enumerate(headers): + file_path = match.group(1).strip() + start = match.end() + # Content extends to the next header or end of string + if i + 1 < len(headers): + end = headers[i + 1].start() + else: + end = len(stripped) + file_content = stripped[start:end].strip("\n") + # Preserve one trailing newline for non-empty files + if file_content and not file_content.endswith("\n"): + file_content += "\n" + files[file_path] = file_content + + return files + + +def _apply_multi_file_full(content: str, skill_dir: Path) -> None: + """Apply multi-file FULL format to a skill directory. + + For each file in the parsed output: + - If the path exists → overwrite + - If the path is new → create (with parent dirs) + """ + files = parse_multi_file_full(content) + + for rel_path, file_content in files.items(): + # Security: ensure path stays within skill_dir + target = (skill_dir / rel_path).resolve() + if not str(target).startswith(str(skill_dir.resolve())): + raise PatchError(f"Path escapes skill directory: {rel_path}") + + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(file_content, encoding="utf-8") + logger.debug(f"FULL write: {rel_path}") + + +# MULTI-FILE PATCH FORMAT +# Format: +# *** Begin Patch +# *** Add File: +# +line1 +# +line2 +# *** Update File: +# @@ context line +# -old line +# +new line +# *** Delete File: +# *** End Patch + +Comparator = Callable[[str, str], bool] + + +def _try_match( + lines: List[str], + pattern: List[str], + start_index: int, + compare: Comparator, + eof: bool, +) -> int: + """Attempt to find *pattern* inside *lines* using *compare*.""" + n = len(lines) + p = len(pattern) + if p == 0: + return -1 + + if eof: + from_end = n - p + if from_end >= start_index: + if all(compare(lines[from_end + j], pattern[j]) for j in range(p)): + return from_end + + for i in range(start_index, n - p + 1): + if all(compare(lines[i + j], pattern[j]) for j in range(p)): + return i + + return -1 + + +# Unicode normalisation (from ShinkaEvolve) +_UNICODE_REPLACEMENTS: Dict[str, str] = { + "\u2018": "'", "\u2019": "'", "\u201A": "'", "\u201B": "'", + "\u201C": '"', "\u201D": '"', "\u201E": '"', "\u201F": '"', + "\u2010": "-", "\u2011": "-", "\u2012": "-", "\u2013": "-", + "\u2014": "-", "\u2015": "-", + "\u2026": "...", + "\u00A0": " ", +} +_UNICODE_RE = re.compile("|".join(re.escape(k) for k in _UNICODE_REPLACEMENTS)) + + +def _normalize_unicode(s: str) -> str: + return _UNICODE_RE.sub(lambda m: _UNICODE_REPLACEMENTS[m.group()], s) + + +def seek_sequence( + lines: List[str], + pattern: List[str], + start_index: int, + eof: bool = False, +) -> int: + """4-level degrading search for a line pattern inside *lines*. + + Returns the 0-based index of the first matching position, or -1. + """ + if not pattern: + return -1 + + # Pass 1: exact + idx = _try_match(lines, pattern, start_index, lambda a, b: a == b, eof) + if idx != -1: + return idx + + # Pass 2: rstrip + idx = _try_match( + lines, pattern, start_index, + lambda a, b: a.rstrip() == b.rstrip(), eof, + ) + if idx != -1: + return idx + + # Pass 3: strip + idx = _try_match( + lines, pattern, start_index, + lambda a, b: a.strip() == b.strip(), eof, + ) + if idx != -1: + return idx + + # Pass 4: Unicode-normalised + strip + idx = _try_match( + lines, pattern, start_index, + lambda a, b: _normalize_unicode(a.strip()) == _normalize_unicode(b.strip()), + eof, + ) + return idx + +def _parse_patch_header( + lines: List[str], idx: int, +) -> Optional[Tuple[str, Optional[str], int]]: + """Parse a ``*** Add/Delete/Update File:`` header line.""" + line = lines[idx] + + if line.startswith("*** Add File:"): + file_path = line.split(":", 1)[1].strip() + return (file_path, None, idx + 1) if file_path else None + + if line.startswith("*** Delete File:"): + file_path = line.split(":", 1)[1].strip() + return (file_path, None, idx + 1) if file_path else None + + if line.startswith("*** Update File:"): + file_path = line.split(":", 1)[1].strip() + if not file_path: + return None + move_path: Optional[str] = None + next_idx = idx + 1 + if next_idx < len(lines) and lines[next_idx].startswith("*** Move to:"): + move_path = lines[next_idx].split(":", 1)[1].strip() + next_idx += 1 + return (file_path, move_path, next_idx) + + return None + +def _parse_add_file_content( + lines: List[str], start_idx: int, +) -> Tuple[str, int]: + """Collect ``+``-prefixed lines for an Add File hunk.""" + content_lines: List[str] = [] + i = start_idx + while i < len(lines) and not lines[i].startswith("***"): + if lines[i].startswith("+"): + content_lines.append(lines[i][1:]) + i += 1 + content = "\n".join(content_lines) + if content.endswith("\n"): + content = content[:-1] + return content, i + +def _parse_update_chunks( + lines: List[str], start_idx: int, +) -> Tuple[List[UpdateChunk], int]: + """Parse ``@@``-delimited change chunks for an Update File hunk.""" + chunks: List[UpdateChunk] = [] + i = start_idx + while i < len(lines) and not lines[i].startswith("***"): + if lines[i].startswith("@@"): + context_line = lines[i][2:].strip() + i += 1 + old_lines: List[str] = [] + new_lines: List[str] = [] + is_end_of_file = False + + while i < len(lines): + cl = lines[i] + if cl.startswith("@@"): + break + if cl.startswith("***") and cl != "*** End of File": + break + if cl == "*** End of File": + is_end_of_file = True + i += 1 + break + if cl.startswith(" "): + old_lines.append(cl[1:]) + new_lines.append(cl[1:]) + elif cl.startswith("-"): + old_lines.append(cl[1:]) + elif cl.startswith("+"): + new_lines.append(cl[1:]) + i += 1 + + chunks.append(UpdateChunk( + old_lines=old_lines, + new_lines=new_lines, + change_context=context_line or None, + is_end_of_file=is_end_of_file, + )) + else: + i += 1 + + return chunks, i + +def parse_patch(patch_text: str) -> PatchResult: + """Parse a ``*** Begin Patch`` / ``*** End Patch`` block. + + Ported from ShinkaEvolve ``shinka/edit/apply_patch.py``. + """ + lines = patch_text.strip().split("\n") + + begin_idx = -1 + end_idx = -1 + for i, line in enumerate(lines): + stripped = line.strip() + if stripped == "*** Begin Patch": + begin_idx = i + elif stripped == "*** End Patch": + end_idx = i + + if begin_idx == -1 or end_idx == -1 or begin_idx >= end_idx: + raise PatchParseError( + "Invalid patch format: missing or mis-ordered " + "*** Begin Patch / *** End Patch markers" + ) + + hunks: List[PatchHunk] = [] + i = begin_idx + 1 + + while i < end_idx: + header = _parse_patch_header(lines, i) + if header is None: + i += 1 + continue + + file_path, move_path, next_idx = header + + if lines[i].startswith("*** Add File:"): + content, next_idx = _parse_add_file_content(lines, next_idx) + hunks.append(PatchHunk(type="add", path=file_path, contents=content)) + i = next_idx + + elif lines[i].startswith("*** Delete File:"): + hunks.append(PatchHunk(type="delete", path=file_path)) + i = next_idx + + elif lines[i].startswith("*** Update File:"): + chunks, next_idx = _parse_update_chunks(lines, next_idx) + hunks.append(PatchHunk( + type="update", + path=file_path, + move_path=move_path, + chunks=chunks, + )) + i = next_idx + else: + i += 1 + + return PatchResult(hunks=hunks) + +def _compute_replacements( + original_lines: List[str], + file_path: str, + chunks: List[UpdateChunk], +) -> List[Tuple[int, int, List[str]]]: + """Compute (start, old_count, new_lines) replacement tuples.""" + replacements: List[Tuple[int, int, List[str]]] = [] + line_index = 0 + + for chunk in chunks: + # Context-based seeking + if chunk.change_context: + ctx_idx = seek_sequence( + original_lines, [chunk.change_context], line_index, + ) + if ctx_idx == -1: + raise PatchError( + f"Cannot locate context anchor " + f"'{chunk.change_context}' in {file_path}" + ) + line_index = ctx_idx + + # Pure addition (no old lines to match) + if not chunk.old_lines: + if original_lines and original_lines[-1] == "": + insert_idx = len(original_lines) - 1 + else: + insert_idx = len(original_lines) + replacements.append((insert_idx, 0, chunk.new_lines)) + continue + + pattern = list(chunk.old_lines) + new_slice = list(chunk.new_lines) + found = seek_sequence( + original_lines, pattern, line_index, chunk.is_end_of_file, + ) + + # Retry without trailing empty line + if found == -1 and pattern and pattern[-1] == "": + pattern = pattern[:-1] + if new_slice and new_slice[-1] == "": + new_slice = new_slice[:-1] + found = seek_sequence( + original_lines, pattern, line_index, chunk.is_end_of_file, + ) + + if found != -1: + replacements.append((found, len(pattern), new_slice)) + line_index = found + len(pattern) + else: + raise PatchError( + f"Cannot find expected lines in {file_path}:\n" + + "\n".join(chunk.old_lines) + ) + + replacements.sort(key=lambda x: x[0]) + return replacements + +def _apply_replacements( + lines: List[str], + replacements: List[Tuple[int, int, List[str]]], +) -> List[str]: + """Apply pre-sorted replacements in reverse order to avoid index shift.""" + result = list(lines) + for start_idx, old_len, new_segment in reversed(replacements): + del result[start_idx: start_idx + old_len] + for j, line in enumerate(new_segment): + result.insert(start_idx + j, line) + return result + +def apply_update_chunks( + file_path: str, + original_content: str, + chunks: List[UpdateChunk], +) -> str: + """Apply *chunks* to *original_content* and return the new content.""" + original_lines = original_content.split("\n") + + # Drop trailing empty element for consistent line counting + if original_lines and original_lines[-1] == "": + original_lines.pop() + + replacements = _compute_replacements(original_lines, file_path, chunks) + new_lines = _apply_replacements(original_lines, replacements) + + # Ensure trailing newline + if not new_lines or new_lines[-1] != "": + new_lines.append("") + + return "\n".join(new_lines) + +def _apply_multi_file_patch(patch_text: str, skill_dir: Path) -> None: + """Parse and apply a ``*** Begin Patch`` block to a skill directory. + + Two-phase: validate all hunks first, then write to disk. + """ + parsed = parse_patch(patch_text) + if not parsed.hunks: + raise PatchParseError("Patch contains no file operations") + + resolved_dir = skill_dir.resolve() + + # Phase 1: validate and compute new contents + changes: List[Tuple[str, Path, str, str]] = [] # (type, abs_path, old, new) + + for hunk in parsed.hunks: + abs_path = (skill_dir / hunk.path).resolve() + + # Security check + if not str(abs_path).startswith(str(resolved_dir)): + raise PatchError(f"Path escapes skill directory: {hunk.path}") + + if hunk.type == "add": + new_content = hunk.contents + if new_content and not new_content.endswith("\n"): + new_content += "\n" + changes.append(("add", abs_path, "", new_content)) + + elif hunk.type == "delete": + if not abs_path.exists(): + raise PatchError(f"Cannot delete non-existent file: {hunk.path}") + changes.append(("delete", abs_path, "", "")) + + elif hunk.type == "update": + if not abs_path.exists(): + raise PatchError(f"Cannot update non-existent file: {hunk.path}") + old_content = abs_path.read_text(encoding="utf-8") + new_content = apply_update_chunks(str(hunk.path), old_content, hunk.chunks) + changes.append(("update", abs_path, old_content, new_content)) + + # Phase 2: write all changes + for change_type, abs_path, _, new_content in changes: + if change_type == "add": + abs_path.parent.mkdir(parents=True, exist_ok=True) + abs_path.write_text(new_content, encoding="utf-8") + logger.debug(f"PATCH add: {abs_path.relative_to(resolved_dir)}") + + elif change_type == "delete": + if abs_path.exists(): + abs_path.unlink() + logger.debug(f"PATCH delete: {abs_path.relative_to(resolved_dir)}") + + elif change_type == "update": + abs_path.write_text(new_content, encoding="utf-8") + logger.debug(f"PATCH update: {abs_path.relative_to(resolved_dir)}") + + +# SEARCH/REPLACE (single-file DIFF) +def apply_search_replace( + patch_text: str, + original: str, + *, + strict: bool = True, +) -> tuple[str, int, Optional[str]]: + """Apply SEARCH/REPLACE blocks to a single file's content. + + Uses the 6-level fuzzy matching chain from ``fuzzy_match`` for + robust matching. + """ + new_text = original + num_applied = 0 + + blocks = list(PATCH_PATTERN.finditer(patch_text)) + if not blocks: + return new_text, 0, None + + for block in blocks: + search = _strip_trailing_ws(block.group(1)) + replace = _strip_trailing_ws(block.group(2)) + + # Empty SEARCH → append at end + if not search.strip(): + new_text = new_text.rstrip("\n") + "\n" + replace + "\n" + num_applied += 1 + continue + + # Use fuzzy matching chain + matched_search, pos = fuzzy_find_match(new_text, search) + + if pos != -1: + new_text = new_text[:pos] + replace + new_text[pos + len(matched_search):] + num_applied += 1 + continue + + # Not found + if strict: + first_line = search.splitlines()[0].strip() if search.splitlines() else "" + similar = _find_similar_lines(first_line, new_text) + msg_parts = [ + f"SEARCH text not found in {SKILL_FILENAME}", + "", + f"Looking for: {first_line!r}", + ] + if similar: + msg_parts.append("") + msg_parts.append("Similar lines found:") + for line, line_num in similar: + msg_parts.append(f" Line {line_num}: {line.strip()}") + msg_parts.extend([ + "", + "Ensure the SEARCH block matches the file content exactly.", + ]) + return new_text, num_applied, "\n".join(msg_parts) + + return new_text, num_applied, None + + +def _apply_search_replace_to_file( + patch_text: str, skill_file: Path, +) -> None: + """Apply SEARCH/REPLACE blocks to a file on disk.""" + original = skill_file.read_text(encoding="utf-8") + updated, num_applied, error = apply_search_replace(patch_text, original) + if error: + raise PatchError(error) + if num_applied == 0: + raise PatchError("No SEARCH/REPLACE blocks found in LLM output") + skill_file.write_text(updated, encoding="utf-8") + + +# DIFF COMPUTATION & SNAPSHOT +def compute_unified_diff( + original: str, + updated: str, + *, + filename: str = SKILL_FILENAME, + context: int = 3, +) -> str: + """Unified diff (git diff format) between two strings.""" + diff_lines = difflib.unified_diff( + original.splitlines(keepends=True), + updated.splitlines(keepends=True), + fromfile=f"a/{filename}", + tofile=f"b/{filename}", + n=context, + ) + return "".join(diff_lines) + +def compute_skill_diff(old_dir: Path, new_dir: Path) -> str: + """Compare all files in two skill directories, return combined diff.""" + old_files = _collect_files(old_dir) if old_dir.is_dir() else {} + new_files = _collect_files(new_dir) if new_dir.is_dir() else {} + + all_names = sorted(set(old_files) | set(new_files)) + parts: list[str] = [] + for name in all_names: + d = compute_unified_diff( + old_files.get(name, ""), + new_files.get(name, ""), + filename=name, + ) + if d: + parts.append(d) + return "\n".join(parts) + +def collect_skill_snapshot(skill_dir: Path) -> Dict[str, str]: + """Collect all text files in a skill directory. + + Returns ``{relative_path: content}``. Binary files are silently skipped. + """ + return _collect_files(skill_dir) + +def _compute_files_diff( + old_files: Dict[str, str], + new_files: Dict[str, str], +) -> str: + """Compute combined unified diff from two snapshot dicts.""" + all_names = sorted(set(old_files) | set(new_files)) + parts: list[str] = [] + for name in all_names: + d = compute_unified_diff( + old_files.get(name, ""), + new_files.get(name, ""), + filename=name, + ) + if d: + parts.append(d) + return "\n".join(parts) + +def _collect_files(directory: Path) -> Dict[str, str]: + """Collect all text files in a directory (recursive). + + Excludes the ``.skill_id`` sidecar (internal metadata, not skill content). + """ + files: Dict[str, str] = {} + for p in sorted(directory.rglob("*")): + if p.is_file() and p.name != _SKILL_ID_FILENAME: + rel = str(p.relative_to(directory)) + try: + files[rel] = p.read_text(encoding="utf-8") + except (UnicodeDecodeError, OSError): + pass + return files + +def _strip_trailing_ws(text: str) -> str: + return "\n".join(line.rstrip() for line in text.splitlines()) + +def _find_similar_lines( + search_line: str, + text: str, + max_suggestions: int = 3, +) -> List[Tuple[str, int]]: + """Fuzzy match for error reporting.""" + import difflib as _dl + + search_clean = search_line.strip() + if not search_clean: + return [] + + results = [] + for i, line in enumerate(text.splitlines()): + line_clean = line.strip() + if not line_clean: + continue + ratio = _dl.SequenceMatcher(None, search_clean, line_clean).ratio() + if ratio > 0.6: + results.append((line, i + 1, ratio)) + + results.sort(key=lambda x: x[2], reverse=True) + return [(line, num) for line, num, _ in results[:max_suggestions]] diff --git a/openspace/skill_engine/registry.py b/openspace/skill_engine/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..bb35dd8e3f18c824d10280fec89864a5df6a78eb --- /dev/null +++ b/openspace/skill_engine/registry.py @@ -0,0 +1,736 @@ +"""SkillRegistry — discover, load, match, and inject skills. + +Skills follow the official SKILL.md format: + - YAML frontmatter with only ``name`` and ``description`` + - Markdown body with instructions (loaded only after selection) + +Skills are discovered from user-configured directories and matched to +tasks via LLM-based selection (with keyword fallback). + +Skill identity: + Every skill directory may contain a ``.skill_id`` sidecar file that + stores the persistent unique identifier. On **first discovery** + (no ``.skill_id`` file present), an ID is generated and written to + the file. On subsequent runs the ID is **read** from the file — + this makes the ID portable (survives directory moves, machine changes) + and deterministic (never regenerated). + + Imported skills: ``{name}__imp_{uuid_hex[:8]}`` + Evolved skills: ``{name}__v{gen}_{uuid_hex[:8]}`` (written by evolver) +""" + +from __future__ import annotations + +import json +import re +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +from openspace.utils.logging import Logger +from .skill_utils import parse_frontmatter, strip_frontmatter, check_skill_safety, is_skill_safe +from .skill_ranker import SkillRanker, SkillCandidate, PREFILTER_THRESHOLD + +if TYPE_CHECKING: + from openspace.llm import LLMClient + +logger = Logger.get_logger(__name__) + +# Sidecar filename that stores the persistent skill_id +SKILL_ID_FILENAME = ".skill_id" + + +def _read_or_create_skill_id(name: str, skill_dir: Path) -> str: + """Read ``skill_id`` from ``.skill_id`` sidecar, or create one. + + The sidecar file is a single-line plain-text file containing only + the ``skill_id`` string. It lives alongside ``SKILL.md`` inside + the skill directory. + + First call (no file): generates ``{name}__imp_{uuid8}`` and writes it. + Subsequent calls: reads and returns the existing ID. + """ + id_file = skill_dir / SKILL_ID_FILENAME + if id_file.exists(): + try: + existing = id_file.read_text(encoding="utf-8").strip() + if existing: + return existing + except OSError: + pass # fall through to generate + + # Generate a new ID and persist + new_id = f"{name}__imp_{uuid.uuid4().hex[:8]}" + try: + id_file.write_text(new_id + "\n", encoding="utf-8") + logger.debug(f"Created .skill_id for '{name}': {new_id}") + except OSError as e: + logger.warning(f"Cannot write {id_file}: {e} — ID will not persist across restarts") + return new_id + + +def write_skill_id(skill_dir: Path, skill_id: str) -> None: + """Write (or overwrite) the ``.skill_id`` sidecar in *skill_dir*. + + Called by ``SkillEvolver`` after FIX / DERIVED / CAPTURED to stamp + the new ``skill_id`` into the skill directory so that the next + ``discover()`` picks it up correctly. + """ + id_file = skill_dir / SKILL_ID_FILENAME + try: + id_file.write_text(skill_id + "\n", encoding="utf-8") + except OSError as e: + logger.warning(f"Cannot write {id_file}: {e}") + + +@dataclass +class SkillMeta: + """Metadata for a discovered skill. + + ``skill_id`` is the globally unique identifier used throughout the + system — LLM prompts, database, evolution, and selection all + reference this field. + """ + + skill_id: str # Unique — persisted in .skill_id sidecar + name: str # Human-readable name (from frontmatter or dirname) + description: str + path: Path # Absolute path to SKILL.md + + +class SkillRegistry: + """Discover, load, select, and inject skills into agent context. + + Args: + skill_dirs: Ordered list of directories to scan. Earlier entries have higher + priority — a skill in the first dir shadows one with the same name + in later dirs. + + All internal maps are keyed by ``skill_id``, not ``name``. + """ + + def __init__(self, skill_dirs: Optional[List[Path]] = None) -> None: + self._skill_dirs: List[Path] = skill_dirs or [] + self._skills: Dict[str, SkillMeta] = {} # skill_id -> SkillMeta + self._content_cache: Dict[str, str] = {} # skill_id -> raw SKILL.md content + self._discovered = False + self._ranker: Optional[SkillRanker] = None # lazy-init on first use + + def discover(self) -> List[SkillMeta]: + """Scan all skill_dirs and populate the registry. + + Each skill is a sub-directory containing a ``SKILL.md`` file. + The ``skill_id`` is read from the ``.skill_id`` sidecar (created + automatically on first discovery). Two skills with the same + ``name`` in different directories get different IDs and can + coexist in the registry and database. + """ + self._skills.clear() + self._content_cache.clear() + + for skill_dir in self._skill_dirs: + if not skill_dir.exists(): + logger.debug(f"Skill dir does not exist, skipping: {skill_dir}") + continue + + for entry in sorted(skill_dir.iterdir()): + if not entry.is_dir(): + continue + skill_file = entry / "SKILL.md" + if not skill_file.exists(): + continue + + try: + content = skill_file.read_text(encoding="utf-8") + + # Safety check on skill content + safety_flags = check_skill_safety(content) + if not is_skill_safe(safety_flags): + logger.warning( + f"BLOCKED skill {entry.name}: " + f"safety flags {safety_flags}" + ) + continue + + meta = self._parse_skill(entry.name, entry, skill_file, content) + sid = meta.skill_id + + if sid in self._skills: + logger.debug(f"Skill '{sid}' already discovered, skipping {skill_file}") + continue + + self._skills[sid] = meta + self._content_cache[sid] = content + if safety_flags: + logger.debug(f"Discovered skill: {sid} (safety: {safety_flags})") + else: + logger.debug(f"Discovered skill: {sid} — {meta.description[:60]}") + except Exception as e: + logger.warning(f"Failed to parse skill {skill_file}: {e}") + + self._discovered = True + logger.info( + f"Skill discovery complete: {len(self._skills)} skill(s) " + f"from {len(self._skill_dirs)} dir(s)" + ) + return list(self._skills.values()) + + def list_skills(self) -> List[SkillMeta]: + """List all discovered skills.""" + self._ensure_discovered() + return list(self._skills.values()) + + def get_skill(self, skill_id: str) -> Optional[SkillMeta]: + """Get a skill by ``skill_id``.""" + self._ensure_discovered() + return self._skills.get(skill_id) + + def get_skill_by_name(self, name: str) -> Optional[SkillMeta]: + """Get a skill by ``name`` (first match). Use ``get_skill`` when possible.""" + self._ensure_discovered() + for meta in self._skills.values(): + if meta.name == name: + return meta + return None + + def update_skill(self, old_skill_id: str, new_meta: SkillMeta) -> None: + """Replace a skill entry after FIX evolution. + + Removes *old_skill_id* from the registry and inserts *new_meta* + under its (new) ``skill_id``. Content cache is refreshed from + the filesystem. + """ + self._skills.pop(old_skill_id, None) + self._content_cache.pop(old_skill_id, None) + + self._skills[new_meta.skill_id] = new_meta + if new_meta.path.exists(): + try: + self._content_cache[new_meta.skill_id] = ( + new_meta.path.read_text(encoding="utf-8") + ) + except Exception: + pass + logger.debug( + f"Registry.update_skill: {old_skill_id} → {new_meta.skill_id}" + ) + + def add_skill(self, meta: SkillMeta) -> None: + """Register a newly-created skill (DERIVED / CAPTURED). + + Does NOT overwrite an existing entry with the same ``skill_id``. + """ + if meta.skill_id in self._skills: + logger.debug( + f"Registry.add_skill: {meta.skill_id} already exists, skipping" + ) + return + self._skills[meta.skill_id] = meta + if meta.path.exists(): + try: + self._content_cache[meta.skill_id] = ( + meta.path.read_text(encoding="utf-8") + ) + except Exception: + pass + logger.debug(f"Registry.add_skill: {meta.skill_id}") + + # Hot-reload API (add external skills at runtime) + def discover_from_dirs(self, extra_dirs: List[Path]) -> List[SkillMeta]: + """Discover skills from additional directories and add to the registry. + + Unlike :meth:`discover`, this does **NOT** clear existing skills — it + only adds new ones from the given directories. Useful for hot-loading + external skills (e.g. host-agent skills, newly downloaded cloud skills). + + Safety: applies the same ``check_skill_safety`` / ``is_skill_safe`` + filtering as :meth:`discover` to prevent malicious external skills. + + Args: + extra_dirs: Additional directories to scan. + """ + added: List[SkillMeta] = [] + for skill_dir in extra_dirs: + if not skill_dir.exists() or not skill_dir.is_dir(): + logger.debug(f"discover_from_dirs: skipping {skill_dir}") + continue + for entry in sorted(skill_dir.iterdir()): + if not entry.is_dir(): + continue + skill_file = entry / "SKILL.md" + if not skill_file.exists(): + continue + try: + content = skill_file.read_text(encoding="utf-8") + + # Safety check (same as discover()) + safety_flags = check_skill_safety(content) + if not is_skill_safe(safety_flags): + logger.warning( + f"BLOCKED external skill {entry.name}: " + f"safety flags {safety_flags}" + ) + continue + + meta = self._parse_skill(entry.name, entry, skill_file, content) + if meta.skill_id in self._skills: + continue + self._skills[meta.skill_id] = meta + self._content_cache[meta.skill_id] = content + added.append(meta) + logger.debug(f"Hot-registered: {meta.skill_id} — {meta.description[:60]}") + except Exception as e: + logger.warning(f"Failed to parse skill {skill_file}: {e}") + + if added: + logger.info( + f"discover_from_dirs: {len(added)} new skill(s) from " + f"{len(extra_dirs)} dir(s)" + ) + return added + + def register_skill_dir(self, skill_dir: Path) -> Optional[SkillMeta]: + """Register a single skill directory (hot-reload). + + Safety: applies ``check_skill_safety`` / ``is_skill_safe`` filtering. + + Args: + skill_dir: Path to a directory containing ``SKILL.md``. + + Returns: + :class:`SkillMeta` if newly registered, ``None`` if already + present, the directory is invalid, or the skill fails safety checks. + """ + skill_file = skill_dir / "SKILL.md" + if not skill_file.exists(): + logger.debug(f"register_skill_dir: no SKILL.md in {skill_dir}") + return None + try: + content = skill_file.read_text(encoding="utf-8") + + # Safety check (same as discover()) + safety_flags = check_skill_safety(content) + if not is_skill_safe(safety_flags): + logger.warning( + f"BLOCKED skill {skill_dir.name}: " + f"safety flags {safety_flags}" + ) + return None + + meta = self._parse_skill(skill_dir.name, skill_dir, skill_file, content) + if meta.skill_id in self._skills: + logger.debug(f"register_skill_dir: {meta.skill_id} already exists") + return None + self._skills[meta.skill_id] = meta + self._content_cache[meta.skill_id] = content + logger.info(f"Hot-registered skill: {meta.skill_id}") + return meta + except Exception as e: + logger.warning(f"Failed to register skill {skill_dir}: {e}") + return None + + @property + def ranker(self) -> SkillRanker: + """Lazy-initialised :class:`SkillRanker` for hybrid pre-filtering.""" + if self._ranker is None: + self._ranker = SkillRanker() + return self._ranker + + async def select_skills_with_llm( + self, + task_description: str, + llm_client: "LLMClient", + max_skills: int = 2, + model: Optional[str] = None, + skill_quality: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> tuple[List[SkillMeta], Optional[Dict[str, Any]]]: + """Use an LLM to select the most relevant skills. + + When the local registry has more than ``PREFILTER_THRESHOLD`` skills, + a **BM25 → embedding** pre-filter narrows the candidate set before + sending to the LLM. This avoids stuffing an overly long catalog + into the prompt. + + Progressive disclosure: the LLM only sees skill *headers* + (skill_id + description + quality stats), not the full SKILL.md + content. Full content is loaded only after selection. + + Args: + task_description: The user's task instruction. + llm_client: An initialised LLMClient used for the selection call. + max_skills: Maximum number of skills to inject. + model: Override model for this selection call. + If None, falls back to ``llm_client``'s default model. + skill_quality: Optional mapping ``{skill_id: {total_applied, total_completions, total_fallbacks}}`` + from :class:`SkillStore`. When provided, skills with high + fallback rates are filtered out and quality signals are + included in the LLM selection prompt. + + Returns: + tuple[list[SkillMeta], dict | None]: (selected_skills, selection_record). + selection_record contains the LLM conversation for logging. + """ + self._ensure_discovered() + if not task_description: + return [], None + + available = list(self._skills.values()) + if not available: + return [], None + + # Quality-based filtering: remove skills that consistently fail + filtered_out: List[str] = [] + if skill_quality: + kept: List[SkillMeta] = [] + for s in available: + q = skill_quality.get(s.skill_id) + if q: + selections = q.get("total_selections", 0) + applied = q.get("total_applied", 0) + completions = q.get("total_completions", 0) + fallbacks = q.get("total_fallbacks", 0) + # Filter 1: selected multiple times but never completed + if selections >= 2 and completions == 0: + filtered_out.append(s.skill_id) + continue + # Filter 2: high fallback rate when applied + if applied >= 2 and fallbacks / applied > 0.5: + filtered_out.append(s.skill_id) + continue + kept.append(s) + if filtered_out: + logger.info( + f"Skill quality filter: removed {len(filtered_out)} " + f"high-fallback skill(s): {filtered_out}" + ) + available = kept + + if not available: + return [], None + + # Pre-filter when skill count exceeds threshold + prefilter_used = False + if len(available) > PREFILTER_THRESHOLD: + available = self._prefilter_skills(task_description, available, max_skills) + prefilter_used = True + + # Build a concise skills catalogue for the LLM (skill_id + description + quality) + catalog_lines: List[str] = [] + for s in available: + q = skill_quality.get(s.skill_id) if skill_quality else None + if q: + selections = q.get("total_selections", 0) + applied = q.get("total_applied", 0) + completions = q.get("total_completions", 0) + if applied > 0: + rate = completions / applied + catalog_lines.append( + f"- **{s.skill_id}**: {s.description} " + f"(success {completions}/{applied} = {rate:.0%})" + ) + elif selections > 0: + catalog_lines.append( + f"- **{s.skill_id}**: {s.description} " + f"(selected {selections}x, never succeeded)" + ) + else: + catalog_lines.append(f"- **{s.skill_id}**: {s.description} (new)") + else: + catalog_lines.append(f"- **{s.skill_id}**: {s.description}") + skills_catalog = "\n".join(catalog_lines) + + prompt = self._build_skill_selection_prompt( + task_description, skills_catalog, max_skills + ) + + selection_record: Dict[str, Any] = { + "method": "llm", + "task": task_description[:500], + "available_skills": [s.skill_id for s in available], + "filtered_out": filtered_out, + "prefilter_used": prefilter_used, + "prompt": prompt, + } + + try: + from gdpval_bench.token_tracker import set_call_source, reset_call_source + _src_tok = set_call_source("skill_select") + except ImportError: + _src_tok = None + + try: + llm_kwargs = {} + if model: + llm_kwargs["model"] = model + resp = await llm_client.complete(prompt, **llm_kwargs) + content = resp["message"]["content"].strip() + selected_ids, brief_plan = self._parse_skill_selection_response(content) + + selection_record["llm_response"] = content + selection_record["parsed_ids"] = selected_ids + selection_record["brief_plan"] = brief_plan + + # Validate ids against registry & cap + result: List[SkillMeta] = [] + for sid in selected_ids: + if len(result) >= max_skills: + break + meta = self._skills.get(sid) + if meta: + result.append(meta) + else: + logger.debug(f"LLM selected unknown skill_id: {sid}") + + selection_record["selected"] = [s.skill_id for s in result] + + if result: + ids = ", ".join(s.skill_id for s in result) + logger.info(f"LLM skill selection: [{ids}]") + else: + logger.info("LLM decided no skills are relevant for this task") + + return result, selection_record + + except Exception as e: + logger.warning(f"LLM skill selection failed: {e} — proceeding without skills") + selection_record["error"] = str(e) + selection_record["method"] = "llm_failed" + selection_record["selected"] = [] + return [], selection_record + finally: + if _src_tok is not None: + reset_call_source(_src_tok) + + def _prefilter_skills( + self, + task: str, + available: List[SkillMeta], + max_skills: int, + ) -> List[SkillMeta]: + """Narrow the candidate set using BM25 + embedding hybrid ranking. + + Keeps at most ``max(15, max_skills * 5)`` candidates for the LLM + selection prompt. + """ + prefilter_top_k = max(15, max_skills * 5) + + # Build SkillCandidate list + candidates: List[SkillCandidate] = [] + for s in available: + body = "" + raw = self._content_cache.get(s.skill_id, "") + if raw: + body = strip_frontmatter(raw) + + candidates.append(SkillCandidate( + skill_id=s.skill_id, + name=s.name, + description=s.description, + body=body, + )) + + ranked = self.ranker.hybrid_rank(task, candidates, top_k=prefilter_top_k) + + # Map back to SkillMeta + ranked_ids = {c.skill_id for c in ranked} + result = [s for s in available if s.skill_id in ranked_ids] + + if len(result) < len(available): + logger.info( + f"Skill pre-filter: {len(available)} → {len(result)} candidates " + f"(BM25+embedding, threshold={PREFILTER_THRESHOLD})" + ) + return result + + def load_skill_content(self, skill_id: str) -> Optional[str]: + """Return the SKILL.md content (with frontmatter stripped) for *skill_id*.""" + self._ensure_discovered() + raw = self._content_cache.get(skill_id) + if raw is None: + return None + return self._strip_frontmatter(raw) + + def build_context_injection( + self, + skills: List[SkillMeta], + backends: Optional[List[str]] = None, + ) -> str: + """Build a prompt fragment with the full content of *skills*. + + Injected as a system message into the agent's messages before the + user instruction so the LLM reads skill guidance first. + + Args: + skills: Skills to inject. + backends: Active backend names (e.g. ``["shell", "mcp"]``). Used to + tailor the guidance so only actually available backends are + mentioned. ``None`` falls back to mentioning all backends. + + Key features: + - Includes the skill directory path so the agent can resolve + relative references to ``scripts/``, ``references/``, ``assets/``. + - Replaces ``{baseDir}`` placeholders with the actual skill + directory path (a convention used in some SKILL.md files). + """ + parts: List[str] = [] + for skill in skills: + content = self.load_skill_content(skill.skill_id) + if content: + # Resolve {baseDir} placeholder to the skill directory + skill_dir = str(skill.path.parent) + content = content.replace("{baseDir}", skill_dir) + + part = ( + f"### Skill: {skill.skill_id}\n" + f"**Skill directory**: `{skill_dir}`\n\n" + f"{content}" + ) + parts.append(part) + + if not parts: + return "" + + # Build a backend hint that only mentions registered backends + scope = set(backends) if backends else {"gui", "shell", "mcp", "web", "system"} + backend_names: List[str] = [] + if "mcp" in scope: + backend_names.append("MCP") + if "shell" in scope: + backend_names.append("shell") + if "gui" in scope: + backend_names.append("GUI") + tool_hint = ", ".join(backend_names) if backend_names else "available" + + # Resource access tips — mention shell_agent only when shell is available + has_shell = "shell" in scope + resource_tip = ( + "Use `read_file` / `list_dir` / `write_file` for file operations" + + (" and `shell_agent` for running scripts" if has_shell else "") + + ". Paths in skill instructions are relative to the skill " + "directory listed under each skill heading.\n\n" + ) + + header = ( + "# Active Skills\n\n" + "The following skills provide **domain knowledge and tested procedures** " + "relevant to this task.\n\n" + "**How to use skills:**\n" + "- If a skill contains **step-by-step procedures or commands**, follow them — " + "they are verified workflows.\n" + "- If a skill provides **reference information, best practices, or tool guides**, " + "use it as context to inform your decisions.\n" + f"- Skills supplement your available tools — you may use **any** tool " + f"({tool_hint}) alongside skill guidance. " + "Choose the best tool for each sub-step.\n\n" + "**Resource access**: Each skill may include bundled resources " + "(scripts, references, assets) in its skill directory. " + + resource_tip + ) + return header + "\n\n---\n\n".join(parts) + + def _ensure_discovered(self) -> None: + if not self._discovered: + self.discover() + + @staticmethod + def _parse_skill( + dir_name: str, + skill_dir: Path, + skill_file: Path, + content: str, + ) -> SkillMeta: + """Parse a SKILL.md file into a SkillMeta. + + Only ``name`` and ``description`` are read from frontmatter + (per the official skill format). ``skill_id`` is read from + the ``.skill_id`` sidecar (created if absent). + """ + frontmatter = parse_frontmatter(content) + name = frontmatter.get("name", dir_name) + description = frontmatter.get("description", name) + skill_id = _read_or_create_skill_id(name, skill_dir) + + return SkillMeta( + skill_id=skill_id, + name=name, + description=description, + path=skill_file, + ) + + # Frontmatter parsing is delegated to skill_utils (single source of truth). + _extract_frontmatter = staticmethod(parse_frontmatter) + _strip_frontmatter = staticmethod(strip_frontmatter) + + @staticmethod + def _build_skill_selection_prompt( + task: str, + skills_catalog: str, + max_skills: int, + ) -> str: + """Build the prompt for LLM skill selection. + + Uses a plan-then-select pattern: the LLM first writes a brief + execution plan, then selects skills that match the plan. + """ + return f"""You are a skill selector for an autonomous agent. + +# Task + +{task} + +# Available Skills + +{skills_catalog} + +# Instructions + +Follow these steps: + +**Step 1 — Plan**: Think about how you would accomplish this task. What are the key deliverables? What file formats are needed (PDF, DOCX, XLSX, etc.)? What tools or libraries would you use? + +**Step 2 — Match**: Check which skills directly teach workflows for the deliverables or file formats identified in your plan. A skill is relevant ONLY if it provides a tested procedure for a core part of your plan. Skills that only share vague topical overlap (e.g. a "PDF checklist" skill for a task that just happens to involve PDFs) add noise and should be excluded. + +**Step 3 — Quality check**: Among matching skills, prefer ones with higher success rates. Avoid skills marked as "never succeeded" or with very low success rates — they waste iterations and actively hurt performance. + +**Step 4 — Decide**: Select at most {max_skills} skill(s). If no skill closely matches your plan, you MUST return an empty list. Selecting an irrelevant or low-quality skill is **worse than selecting none** — it forces the agent down an unproductive path and wastes the entire iteration budget. When in doubt, leave it out. + +Return a JSON object: +{{"brief_plan": "1-2 sentence plan for this task", "skills": ["skill_id_1", "skill_id_2"]}} + +If no skill applies: +{{"brief_plan": "1-2 sentence plan", "skills": []}} + +IMPORTANT: Use the **exact skill_id** from the list above.""" + + @staticmethod + def _parse_skill_selection_response(content: str) -> tuple[List[str], str]: + """Parse the LLM response and extract selected skill IDs + plan. + + Returns: + (skill_ids, brief_plan) + """ + # Handle markdown code blocks + code_block = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", content, re.DOTALL) + if code_block: + content = code_block.group(1).strip() + else: + # Try to find a raw JSON object + json_match = re.search(r"\{.*\}", content, re.DOTALL) + if json_match: + content = json_match.group() + + try: + data = json.loads(content) + except json.JSONDecodeError: + logger.warning(f"Failed to parse LLM skill selection JSON: {content[:200]}") + return [], "" + + brief_plan = data.get("brief_plan", "") + if brief_plan: + logger.info(f"Skill selection plan: {brief_plan}") + + ids = data.get("skills", []) + if not isinstance(ids, list): + return [], brief_plan + return [str(n).strip() for n in ids if n], brief_plan diff --git a/openspace/skill_engine/retrieve_tool.py b/openspace/skill_engine/retrieve_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa874870d27232fb95042843f2c7bc183a75a55 --- /dev/null +++ b/openspace/skill_engine/retrieve_tool.py @@ -0,0 +1,111 @@ +"""RetrieveSkillTool — mid-iteration skill retrieval for GroundingAgent. + +Registered as an internal tool so the LLM can pull in skill guidance +during execution when the initial skill set is insufficient. + +Reuses the same pipeline as initial skill selection: + quality filter → BM25+embedding pre-filter → LLM plan-then-select. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from openspace.grounding.core.tool.local_tool import LocalTool +from openspace.grounding.core.types import BackendType +from openspace.utils.logging import Logger + +if TYPE_CHECKING: + from openspace.llm import LLMClient + from openspace.skill_engine import SkillRegistry + from openspace.skill_engine.store import SkillStore + +logger = Logger.get_logger(__name__) + + +class RetrieveSkillTool(LocalTool): + """Internal tool: mid-iteration skill retrieval. + + Reuses ``SkillRegistry.select_skills_with_llm()`` so the same + quality filter, BM25+embedding pre-filter, and plan-then-select + LLM prompt are applied consistently. + """ + + _name = "retrieve_skill" + _description = ( + "Search for specialized skill guidance when the current approach " + "isn't working or the task requires domain-specific knowledge. " + "Returns step-by-step instructions if a relevant skill is found." + ) + backend_type = BackendType.SYSTEM + + def __init__( + self, + skill_registry: "SkillRegistry", + backends: Optional[List[str]] = None, + llm_client: Optional["LLMClient"] = None, + skill_store: Optional["SkillStore"] = None, + ): + super().__init__() + self._skill_registry = skill_registry + self._backends = backends + self._llm_client = llm_client + self._skill_store = skill_store + + def _load_skill_quality(self) -> Optional[Dict[str, Dict[str, Any]]]: + if not self._skill_store: + return None + try: + rows = self._skill_store.get_summary(active_only=True) + return { + r["skill_id"]: { + "total_selections": r.get("total_selections", 0), + "total_applied": r.get("total_applied", 0), + "total_completions": r.get("total_completions", 0), + "total_fallbacks": r.get("total_fallbacks", 0), + } + for r in rows + } + except Exception: + return None + + async def _arun(self, query: str) -> str: + if self._llm_client: + # Full pipeline: quality filter → BM25+embedding → LLM plan-then-select + quality = self._load_skill_quality() + selected, record = await self._skill_registry.select_skills_with_llm( + query, + llm_client=self._llm_client, + max_skills=1, + skill_quality=quality, + ) + if record: + plan = record.get("brief_plan", "") + if plan: + logger.info(f"retrieve_skill plan: {plan}") + else: + # Fallback: BM25+embedding only (no LLM available) + from openspace.cloud.search import hybrid_search_skills + + results = await hybrid_search_skills( + query=query, + local_skills=self._skill_registry.list_skills(), + source="local", + limit=1, + ) + if not results: + return "No relevant skills found for this query." + + hit_ids = {r["skill_id"] for r in results} + selected = [ + s for s in self._skill_registry.list_skills() + if s.skill_id in hit_ids + ] + + if not selected: + return "No relevant skills found for this query." + + logger.info(f"retrieve_skill matched: {[s.skill_id for s in selected]}") + return self._skill_registry.build_context_injection( + selected, backends=self._backends, + ) diff --git a/openspace/skill_engine/skill_ranker.py b/openspace/skill_engine/skill_ranker.py new file mode 100644 index 0000000000000000000000000000000000000000..503eda5e11726bfc6f8cbad7c39f8c63bea685a2 --- /dev/null +++ b/openspace/skill_engine/skill_ranker.py @@ -0,0 +1,415 @@ +"""SkillRanker — BM25 + embedding hybrid ranking for skills. + +Provides a two-stage retrieval pipeline for skill selection: + Stage 1 (BM25): Fast lexical rough-rank over all skills + Stage 2 (Embedding): Semantic re-rank on BM25 candidates + +Embedding strategy: + - Text = ``name + description + SKILL.md body`` (consistent with MCP + ``search_skills`` and the clawhub cloud platform) + - Model: ``qwen/qwen3-embedding-8b`` via OpenRouter API + - Embeddings are cached in-memory keyed by ``skill_id`` and optionally + persisted to a pickle file for cross-session reuse + +Reused by: + - ``SkillRegistry.select_skills_with_llm`` — pre-filter before LLM selection + - ``mcp_server.search_skills`` — BM25 stage of the MCP search tool +""" + +from __future__ import annotations + +import json +import math +import os +import pickle +import re +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +# Embedding model — must match clawhub platform for vector-space compatibility +SKILL_EMBEDDING_MODEL = "openai/text-embedding-3-small" +SKILL_EMBEDDING_MAX_CHARS = 12_000 + +# Pre-filter threshold: when local skills exceed this count, BM25 pre-filter +# is activated before LLM selection. Below this, all skills go directly to LLM. +PREFILTER_THRESHOLD = 10 + +# How many candidates to keep after BM25 rough-rank (before embedding re-rank) +BM25_CANDIDATES_MULTIPLIER = 3 # top_k * 3 + +# Cache version — increment when format changes +_CACHE_VERSION = 1 + + +@dataclass +class SkillCandidate: + """Lightweight skill representation for ranking.""" + skill_id: str + name: str + description: str + body: str = "" # SKILL.md body (frontmatter stripped) + source: str = "local" # "local" | "cloud" + # Internal ranking fields + embedding: Optional[List[float]] = None + embedding_text: str = "" # text used to compute embedding + score: float = 0.0 + bm25_score: float = 0.0 + vector_score: float = 0.0 + # Pass-through metadata (for MCP search results) + metadata: Dict[str, Any] = field(default_factory=dict) + + +class SkillRanker: + """Hybrid BM25 + embedding ranker for skills. + + Usage:: + + ranker = SkillRanker() + candidates = [SkillCandidate(skill_id=..., name=..., description=..., body=...)] + ranked = ranker.hybrid_rank(query, candidates, top_k=10) + """ + + def __init__( + self, + *, + cache_dir: Optional[Path] = None, + enable_cache: bool = True, + ) -> None: + # Embedding cache: skill_id → List[float] + self._embedding_cache: Dict[str, List[float]] = {} + self._enable_cache = enable_cache + + if cache_dir is None: + try: + from openspace.config.constants import PROJECT_ROOT + cache_dir = PROJECT_ROOT / ".openspace" / "skill_embedding_cache" + except Exception: + cache_dir = Path(".openspace") / "skill_embedding_cache" + self._cache_dir = Path(cache_dir) + + if self._enable_cache: + self._load_cache() + + def hybrid_rank( + self, + query: str, + candidates: List[SkillCandidate], + top_k: int = 10, + ) -> List[SkillCandidate]: + """BM25 rough-rank → embedding re-rank → return top_k. + + Falls back gracefully: + - No BM25 lib → simple token overlap + - No embedding API key → BM25-only + - Both fail → return first top_k candidates + """ + if not candidates or not query.strip(): + return candidates[:top_k] + + # Stage 1: BM25 rough-rank + bm25_top = self._bm25_rank(query, candidates, top_k * BM25_CANDIDATES_MULTIPLIER) + if not bm25_top: + # BM25 found nothing — try embedding on all candidates + emb_results = self._embedding_rank(query, candidates, top_k) + return emb_results if emb_results else candidates[:top_k] + + # Stage 2: Embedding re-rank on BM25 candidates + emb_results = self._embedding_rank(query, bm25_top, top_k) + if emb_results: + return emb_results + + # Embedding unavailable — return BM25 results + logger.debug("Embedding unavailable, using BM25-only results") + return bm25_top[:top_k] + + def bm25_only( + self, + query: str, + candidates: List[SkillCandidate], + top_k: int = 30, + ) -> List[SkillCandidate]: + """BM25-only ranking (for MCP search Phase 1).""" + return self._bm25_rank(query, candidates, top_k) + + def embedding_only( + self, + query: str, + candidates: List[SkillCandidate], + top_k: int = 10, + ) -> List[SkillCandidate]: + """Embedding-only ranking.""" + return self._embedding_rank(query, candidates, top_k) + + def get_or_compute_embedding( + self, candidate: SkillCandidate, + ) -> Optional[List[float]]: + """Get embedding from cache or compute it. + + Returns None if embedding cannot be generated. + """ + # Already has embedding (e.g. cloud pre-computed) + if candidate.embedding: + return candidate.embedding + + # Check cache + cached = self._embedding_cache.get(candidate.skill_id) + if cached: + candidate.embedding = cached + return cached + + # Compute + text = self._build_embedding_text(candidate) + emb = self._generate_embedding(text) + if emb: + candidate.embedding = emb + self._embedding_cache[candidate.skill_id] = emb + self._save_cache() + return emb + + def invalidate_cache(self, skill_id: str) -> None: + """Remove a skill's cached embedding (e.g. after evolution).""" + self._embedding_cache.pop(skill_id, None) + self._save_cache() + + def clear_cache(self) -> None: + """Clear all cached embeddings.""" + self._embedding_cache.clear() + self._save_cache() + + @staticmethod + def _tokenize(text: str) -> List[str]: + """Tokenize text for BM25.""" + tokens = re.split(r"[^\w]+", text.lower()) + return [t for t in tokens if t] + + def _bm25_rank( + self, + query: str, + candidates: List[SkillCandidate], + top_k: int, + ) -> List[SkillCandidate]: + """Rank candidates using BM25.""" + if not candidates: + return [] + + try: + from rank_bm25 import BM25Okapi # type: ignore + except ImportError: + BM25Okapi = None + + # Build corpus: name + description + truncated body for richer matching + corpus_tokens = [] + for c in candidates: + text = f"{c.name} {c.description}" + if c.body: + text += f" {c.body[:2000]}" # include body for BM25 but cap length + corpus_tokens.append(self._tokenize(text)) + + query_tokens = self._tokenize(query) + + if BM25Okapi and corpus_tokens: + bm25 = BM25Okapi(corpus_tokens) + scores = bm25.get_scores(query_tokens) + for c, s in zip(candidates, scores): + c.bm25_score = float(s) + else: + # Fallback: simple token overlap + q_set = set(query_tokens) + for c, toks in zip(candidates, corpus_tokens): + if not toks or not q_set: + c.bm25_score = 0.0 + else: + overlap = q_set.intersection(toks) + c.bm25_score = len(overlap) / len(q_set) + + # Sort and filter + ranked = sorted(candidates, key=lambda c: c.bm25_score, reverse=True) + + # If all scores are 0 (no match), return all candidates (let embedding decide) + if all(c.bm25_score == 0.0 for c in ranked): + logger.debug("BM25 found no matches, passing all candidates to embedding stage") + return candidates[:top_k] + + return ranked[:top_k] + + @staticmethod + def _get_openai_api_key() -> Optional[str]: + """Resolve OpenAI-compatible API key for embedding requests.""" + from openspace.cloud.embedding import resolve_embedding_api + api_key, _ = resolve_embedding_api() + return api_key + + @staticmethod + def _build_embedding_text(candidate: SkillCandidate) -> str: + """Build text for embedding, consistent with MCP search_skills.""" + if candidate.embedding_text: + return candidate.embedding_text + header = "\n".join(filter(None, [candidate.name, candidate.description])) + raw = "\n\n".join(filter(None, [header, candidate.body])) + if len(raw) > SKILL_EMBEDDING_MAX_CHARS: + raw = raw[:SKILL_EMBEDDING_MAX_CHARS] + candidate.embedding_text = raw + return raw + + def _embedding_rank( + self, + query: str, + candidates: List[SkillCandidate], + top_k: int, + ) -> List[SkillCandidate]: + """Rank candidates using embedding cosine similarity.""" + api_key = self._get_openai_api_key() + if not api_key: + return [] + + # Generate query embedding + query_emb = self._generate_embedding(query, api_key=api_key) + if not query_emb: + return [] + + # Ensure all candidates have embeddings + for c in candidates: + if not c.embedding: + cached = self._embedding_cache.get(c.skill_id) + if cached: + c.embedding = cached + else: + text = self._build_embedding_text(c) + emb = self._generate_embedding(text, api_key=api_key) + if emb: + c.embedding = emb + self._embedding_cache[c.skill_id] = emb + + # Save newly computed embeddings + self._save_cache() + + # Score + for c in candidates: + if c.embedding: + c.vector_score = _cosine_similarity(query_emb, c.embedding) + else: + c.vector_score = 0.0 + c.score = c.vector_score + + ranked = sorted(candidates, key=lambda c: c.score, reverse=True) + return ranked[:top_k] + + @staticmethod + def _generate_embedding( + text: str, + api_key: Optional[str] = None, + ) -> Optional[List[float]]: + """Generate embedding via OpenAI-compatible API (text-embedding-3-small). + + Delegates credential / base-URL resolution to + :func:`openspace.cloud.embedding.resolve_embedding_api`. + """ + from openspace.cloud.embedding import resolve_embedding_api + + resolved_key, base_url = resolve_embedding_api() + if not api_key: + api_key = resolved_key + if not api_key: + return None + + import urllib.request + + body = json.dumps({ + "model": SKILL_EMBEDDING_MODEL, + "input": text, + }).encode("utf-8") + + req = urllib.request.Request( + f"{base_url}/embeddings", + data=body, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + }, + method="POST", + ) + import time + last_err = None + for attempt in range(3): + try: + with urllib.request.urlopen(req, timeout=15) as resp: + data = json.loads(resp.read().decode("utf-8")) + return data.get("data", [{}])[0].get("embedding") + except Exception as e: + last_err = e + if attempt < 2: + delay = 2 * (attempt + 1) + logger.debug("Embedding request failed (attempt %d/3), retrying in %ds: %s", attempt + 1, delay, e) + time.sleep(delay) + logger.warning("Skill embedding generation failed after 3 attempts: %s", last_err) + return None + + def _cache_file(self) -> Path: + return self._cache_dir / f"skill_embeddings_v{_CACHE_VERSION}.pkl" + + def _load_cache(self) -> None: + """Load embedding cache from disk.""" + path = self._cache_file() + if not path.exists(): + return + try: + with open(path, "rb") as f: + data = pickle.load(f) + if isinstance(data, dict) and data.get("version") == _CACHE_VERSION: + self._embedding_cache = data.get("embeddings", {}) + logger.debug(f"Loaded {len(self._embedding_cache)} skill embeddings from cache") + except Exception as e: + logger.warning(f"Failed to load skill embedding cache: {e}") + self._embedding_cache = {} + + def _save_cache(self) -> None: + """Persist embedding cache to disk.""" + if not self._enable_cache or not self._embedding_cache: + return + try: + self._cache_dir.mkdir(parents=True, exist_ok=True) + data = { + "version": _CACHE_VERSION, + "model": SKILL_EMBEDDING_MODEL, + "last_updated": datetime.now().isoformat(), + "embeddings": self._embedding_cache, + } + with open(self._cache_file(), "wb") as f: + pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) + except Exception as e: + logger.warning(f"Failed to save skill embedding cache: {e}") + +def _cosine_similarity(a: List[float], b: List[float]) -> float: + """Compute cosine similarity between two vectors.""" + if len(a) != len(b) or not a: + return 0.0 + dot = sum(x * y for x, y in zip(a, b)) + norm_a = math.sqrt(sum(x * x for x in a)) + norm_b = math.sqrt(sum(x * x for x in b)) + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + + +def build_skill_embedding_text( + name: str, + description: str, + readme_body: str, + max_chars: int = SKILL_EMBEDDING_MAX_CHARS, +) -> str: + """Build text for skill embedding: ``name + description + SKILL.md body``. + + Unified strategy matching MCP search_skills and clawhub platform. + """ + header = "\n".join(filter(None, [name, description])) + raw = "\n\n".join(filter(None, [header, readme_body])) + if len(raw) <= max_chars: + return raw + return raw[:max_chars] + diff --git a/openspace/skill_engine/skill_utils.py b/openspace/skill_engine/skill_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d6fc7caa252c399845a9a4bc2039b526931637c --- /dev/null +++ b/openspace/skill_engine/skill_utils.py @@ -0,0 +1,309 @@ +"""Shared utility functions for the skill engine. + +Provides: + - YAML frontmatter parsing/manipulation (unified across registry, evolver, etc.) + - LLM output cleaning (markdown fence stripping, change summary extraction) + - Skill content safety checking (regex-based moderation) + - Skill directory validation + - Text truncation +""" + +from __future__ import annotations + +import re +from pathlib import Path +from typing import Any, Dict, List, Optional + +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + +SKILL_FILENAME = "SKILL.md" + +_SAFETY_RULES = [ + ("blocked.malware", re.compile(r"(ClawdAuthenticatorTool)", re.IGNORECASE)), + ("suspicious.keyword", re.compile(r"(malware|stealer|phish|phishing|keylogger)", re.IGNORECASE)), + ("suspicious.secrets", re.compile(r"(api[-_ ]?key|token|password|private key|secret)", re.IGNORECASE)), + ("suspicious.crypto", re.compile(r"(wallet|seed phrase|mnemonic|crypto)", re.IGNORECASE)), + ("suspicious.webhook", re.compile(r"(discord\.gg|webhook|hooks\.slack)", re.IGNORECASE)), + ("suspicious.script", re.compile(r"(curl[^\n]+\|\s*(sh|bash))", re.IGNORECASE)), + ("suspicious.url_shortener", re.compile(r"(bit\.ly|tinyurl\.com|t\.co|goo\.gl|is\.gd)", re.IGNORECASE)), +] + +_BLOCKING_FLAGS = frozenset({"blocked.malware"}) + + +def check_skill_safety(text: str) -> List[str]: + """Check *text* against safety rules, return list of triggered flag names. + + Returns an empty list if no rules match (= safe). + """ + return [flag for flag, pat in _SAFETY_RULES if pat.search(text)] + + +def is_skill_safe(flags: List[str]) -> bool: + """Return True if *flags* contain no blocking flag. + + ``suspicious.*`` flags are informational (logged / attached to search + results) but do NOT block. Only ``blocked.*`` flags cause rejection. + """ + return not any(f in _BLOCKING_FLAGS for f in flags) + +_FRONTMATTER_RE = re.compile(r"^---\n(.*?)\n---", re.DOTALL) + +# Characters that require YAML value quoting (colon-space, hash-space, +# or values starting with special YAML indicators). +_YAML_NEEDS_QUOTE_RE = re.compile(r"[:\#\[\]{}&*!|>'\"%@`]") + + +def _yaml_quote(value: str) -> str: + """Quote a YAML scalar value if it contains special characters.""" + if not value or not _YAML_NEEDS_QUOTE_RE.search(value): + return value + escaped = value.replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped}"' + + +def _yaml_unquote(value: str) -> str: + """Strip surrounding quotes and unescape a YAML scalar value.""" + if len(value) >= 2: + if (value[0] == '"' and value[-1] == '"') or \ + (value[0] == "'" and value[-1] == "'"): + inner = value[1:-1] + if value[0] == '"': + inner = inner.replace('\\"', '"').replace("\\\\", "\\") + return inner + return value + + +def parse_frontmatter(content: str) -> Dict[str, Any]: + """Parse YAML frontmatter into a flat dict. + + Simple line-by-line parser (no PyYAML dependency). + Handles both quoted and unquoted values. + Returns ``{}`` if no valid frontmatter is found. + """ + if not content.startswith("---"): + return {} + match = _FRONTMATTER_RE.match(content) + if not match: + return {} + fm: Dict[str, Any] = {} + for line in match.group(1).split("\n"): + if ":" in line: + key, value = line.split(":", 1) + key = key.strip() + if key: + fm[key] = _yaml_unquote(value.strip()) + return fm + + +def get_frontmatter_field(content: str, field_name: str) -> Optional[str]: + """Extract a single field value from YAML frontmatter. + + Returns ``None`` if the field is absent or content has no frontmatter. + """ + if not content.startswith("---"): + return None + match = _FRONTMATTER_RE.match(content) + if not match: + return None + for line in match.group(1).split("\n"): + if ":" in line: + key, value = line.split(":", 1) + if key.strip() == field_name: + return _yaml_unquote(value.strip()) + return None + + +def set_frontmatter_field(content: str, field_name: str, value: str) -> str: + """Set (or insert) a field in YAML frontmatter. + + Values containing YAML special characters (``:``, ``#``, etc.) are + automatically double-quoted to produce valid YAML. + + If *content* has no frontmatter, a new one is prepended. + """ + quoted = _yaml_quote(value) + if not content.startswith("---"): + return f"---\n{field_name}: {quoted}\n---\n{content}" + + match = _FRONTMATTER_RE.match(content) + if not match: + return content + + fm_text = match.group(1) + new_line = f"{field_name}: {quoted}" + found = False + new_lines = [] + for line in fm_text.split("\n"): + if ":" in line and line.split(":", 1)[0].strip() == field_name: + new_lines.append(new_line) + found = True + else: + new_lines.append(line) + if not found: + new_lines.append(new_line) + + new_fm = "\n".join(new_lines) + return f"---\n{new_fm}\n---{content[match.end():]}" + + +def normalize_frontmatter(content: str) -> str: + """Re-serialize frontmatter with proper YAML quoting. + + Parses the existing frontmatter, then re-writes each value through + :func:`_yaml_quote` so that colons, hashes, and other special + characters are safely double-quoted. The body after ``---`` is + preserved verbatim. + + Returns *content* unchanged if no frontmatter is found. + """ + if not content.startswith("---"): + return content + match = _FRONTMATTER_RE.match(content) + if not match: + return content + + fm = parse_frontmatter(content) + if not fm: + return content + + safe_lines = [f"{k}: {_yaml_quote(v)}" for k, v in fm.items()] + new_fm = "\n".join(safe_lines) + return f"---\n{new_fm}\n---{content[match.end():]}" + + +def strip_frontmatter(content: str) -> str: + """Remove YAML frontmatter from markdown content.""" + if content.startswith("---"): + match = re.match(r"^---\n.*?\n---\n?", content, re.DOTALL) + if match: + return content[match.end():].strip() + return content + +def strip_markdown_fences(text: str) -> str: + """Remove surrounding markdown code fences if present. + + Handles common LLM wrapping patterns: + - ````` ```markdown ```, ````` ```md ```, ````` ``` ```, ````` ```text ````` + - Nested triple-backtick pairs (outermost only) + - Leading/trailing whitespace around fences + """ + text = text.strip() + + # Pattern: opening ``` with optional language tag, content, closing ``` + m = re.match( + r"^```(?:markdown|md|text|yaml|diff|patch)?\s*\n(.*?)\n```\s*$", + text, + re.DOTALL, + ) + if m: + return m.group(1).strip() + + # Some LLMs emit ``````` (4+ backticks) as outer fence + m = re.match( + r"^`{3,}(?:\w+)?\s*\n(.*?)\n`{3,}\s*$", + text, + re.DOTALL, + ) + if m: + return m.group(1).strip() + + return text + + +_CHANGE_SUMMARY_RE = re.compile( + r"^[\s*_]*(?:CHANGE[\s_-]?SUMMARY)\s*[::]\s*(.+)", + re.IGNORECASE, +) + + +def extract_change_summary(content: str) -> tuple[str, str]: + """Extract ``CHANGE_SUMMARY`` from LLM output. + + Returns ``(clean_content, change_summary)``. + """ + lines = content.split("\n") + + # Find the first non-blank line + first_nonblank = -1 + for i, line in enumerate(lines): + if line.strip(): + first_nonblank = i + break + + if first_nonblank == -1: + return content, "" + + m = _CHANGE_SUMMARY_RE.match(lines[first_nonblank]) + if not m: + return content, "" + + # Strip markdown bold/italic markers (** or __) from both ends + summary = m.group(1).strip().strip("*_").strip() + + # Skip blank lines after the summary line to find content start + content_start = first_nonblank + 1 + while content_start < len(lines) and not lines[content_start].strip(): + content_start += 1 + + rest = "\n".join(lines[content_start:]) + return rest.strip(), summary + +def validate_skill_dir(skill_dir: Path) -> Optional[str]: + """Validate a skill directory after edit application. + + Returns None if valid, or an error message string. + Checks: + 1. Directory exists + 2. SKILL.md exists and is non-empty + 3. SKILL.md has valid YAML frontmatter with ``name`` field + 4. No empty files (warning-level, not blocking) + """ + if not skill_dir.exists(): + return f"Skill directory does not exist: {skill_dir}" + + skill_file = skill_dir / SKILL_FILENAME + if not skill_file.exists(): + return f"SKILL.md not found in {skill_dir}" + + try: + content = skill_file.read_text(encoding="utf-8") + except Exception as e: + return f"Cannot read SKILL.md: {e}" + + if not content.strip(): + return "SKILL.md is empty" + + # Check frontmatter + if not content.startswith("---"): + return "SKILL.md missing YAML frontmatter (should start with '---')" + + m = re.match(r"^---\n(.*?)\n---", content, re.DOTALL) + if not m: + return "SKILL.md has malformed YAML frontmatter (missing closing '---')" + + # Check for required 'name' field in frontmatter + name = get_frontmatter_field(content, "name") + if not name: + return "SKILL.md frontmatter missing 'name' field" + + # Non-blocking checks: log warnings for empty auxiliary files + for p in skill_dir.rglob("*"): + if p.is_file() and p != skill_file: + try: + if p.stat().st_size == 0: + logger.warning(f"Validation: empty auxiliary file: {p.relative_to(skill_dir)}") + except OSError: + pass + + return None + + +def truncate(text: str, max_chars: int) -> str: + """Truncate *text* to *max_chars* with an ellipsis marker.""" + if len(text) <= max_chars: + return text + return text[:max_chars] + f"\n\n... [truncated at {max_chars} chars]" + diff --git a/openspace/skill_engine/store.py b/openspace/skill_engine/store.py new file mode 100644 index 0000000000000000000000000000000000000000..4a90b71cb62d66832651c9a9c8486299ce914b83 --- /dev/null +++ b/openspace/skill_engine/store.py @@ -0,0 +1,1464 @@ +""" +Storage location: /.openspace/openspace.db +Tables: + skill_records — SkillRecord main table + skill_lineage_parents — Lineage parent-child relationships (many-to-many) + execution_analyses — ExecutionAnalysis records (one per task) + skill_judgments — Per-skill judgments within an analysis + skill_tool_deps — Tool dependencies + skill_tags — Auxiliary tags +""" + +from __future__ import annotations + +import asyncio +import json +import os +import libsql_experimental as sqlite3 +import threading +import time +from contextlib import contextmanager +from datetime import datetime +from functools import wraps +from pathlib import Path +from typing import Any, Dict, Generator, List, Optional + +from .patch import collect_skill_snapshot, compute_unified_diff +from .types import ( + EvolutionSuggestion, + ExecutionAnalysis, + SkillCategory, + SkillJudgment, + SkillLineage, + SkillOrigin, + SkillRecord, + SkillVisibility, +) +from openspace.utils.logging import Logger +from openspace.config.constants import PROJECT_ROOT + +logger = Logger.get_logger(__name__) + + +def _db_retry( + max_retries: int = 5, + initial_delay: float = 0.1, + backoff: float = 2.0, +): + """Retry on transient SQLite errors with exponential backoff. + + Catches ``OperationalError`` (e.g. "database is locked") and + ``DatabaseError`` but NOT programming errors like ``InterfaceError``. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + delay = initial_delay + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except (sqlite3.OperationalError, sqlite3.DatabaseError) as exc: + if attempt == max_retries - 1: + logger.error( + f"DB {func.__name__} failed after " + f"{max_retries} retries: {exc}" + ) + raise + logger.warning( + f"DB {func.__name__} retry {attempt + 1}" + f"/{max_retries}: {exc}" + ) + time.sleep(delay) + delay *= backoff + + return wrapper + + return decorator + + +_DDL = """ +CREATE TABLE IF NOT EXISTS skill_records ( + skill_id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '', + path TEXT NOT NULL DEFAULT '', + is_active INTEGER NOT NULL DEFAULT 1, + category TEXT NOT NULL DEFAULT 'workflow', + visibility TEXT NOT NULL DEFAULT 'private', + creator_id TEXT NOT NULL DEFAULT '', + lineage_origin TEXT NOT NULL DEFAULT 'imported', + lineage_generation INTEGER NOT NULL DEFAULT 0, + lineage_source_task_id TEXT, + lineage_change_summary TEXT NOT NULL DEFAULT '', + lineage_content_diff TEXT NOT NULL DEFAULT '', + lineage_content_snapshot TEXT NOT NULL DEFAULT '{}', + lineage_created_at TEXT NOT NULL, + lineage_created_by TEXT NOT NULL DEFAULT '', + total_selections INTEGER NOT NULL DEFAULT 0, + total_applied INTEGER NOT NULL DEFAULT 0, + total_completions INTEGER NOT NULL DEFAULT 0, + total_fallbacks INTEGER NOT NULL DEFAULT 0, + first_seen TEXT NOT NULL, + last_updated TEXT NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_sr_category ON skill_records(category); +CREATE INDEX IF NOT EXISTS idx_sr_updated ON skill_records(last_updated); +CREATE INDEX IF NOT EXISTS idx_sr_active ON skill_records(is_active); +CREATE INDEX IF NOT EXISTS idx_sr_name ON skill_records(name); + +CREATE TABLE IF NOT EXISTS skill_lineage_parents ( + skill_id TEXT NOT NULL + REFERENCES skill_records(skill_id) ON DELETE CASCADE, + parent_skill_id TEXT NOT NULL, + PRIMARY KEY (skill_id, parent_skill_id) +); +CREATE INDEX IF NOT EXISTS idx_lp_parent + ON skill_lineage_parents(parent_skill_id); + +-- One row per task. task_id is UNIQUE (at most one analysis per task). +CREATE TABLE IF NOT EXISTS execution_analyses ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id TEXT NOT NULL UNIQUE, + timestamp TEXT NOT NULL, + task_completed INTEGER NOT NULL DEFAULT 0, + execution_note TEXT NOT NULL DEFAULT '', + tool_issues TEXT NOT NULL DEFAULT '[]', + candidate_for_evolution INTEGER NOT NULL DEFAULT 0, + evolution_suggestions TEXT NOT NULL DEFAULT '[]', + analyzed_by TEXT NOT NULL DEFAULT '', + analyzed_at TEXT NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_ea_task ON execution_analyses(task_id); +CREATE INDEX IF NOT EXISTS idx_ea_ts ON execution_analyses(timestamp); + +-- Per-skill judgments within an analysis. +-- FK to execution_analyses.id (CASCADE delete). +-- skill_id is a plain TEXT — no FK to skill_records so that +-- historical judgments survive skill deletion. +CREATE TABLE IF NOT EXISTS skill_judgments ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + analysis_id INTEGER NOT NULL + REFERENCES execution_analyses(id) ON DELETE CASCADE, + skill_id TEXT NOT NULL, + skill_applied INTEGER NOT NULL DEFAULT 0, + note TEXT NOT NULL DEFAULT '', + UNIQUE(analysis_id, skill_id) +); +CREATE INDEX IF NOT EXISTS idx_sj_skill ON skill_judgments(skill_id); +CREATE INDEX IF NOT EXISTS idx_sj_analysis ON skill_judgments(analysis_id); + +CREATE TABLE IF NOT EXISTS skill_tool_deps ( + skill_id TEXT NOT NULL + REFERENCES skill_records(skill_id) ON DELETE CASCADE, + tool_key TEXT NOT NULL, + critical INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (skill_id, tool_key) +); +CREATE INDEX IF NOT EXISTS idx_td_tool ON skill_tool_deps(tool_key); + +CREATE TABLE IF NOT EXISTS skill_tags ( + skill_id TEXT NOT NULL + REFERENCES skill_records(skill_id) ON DELETE CASCADE, + tag TEXT NOT NULL, + PRIMARY KEY (skill_id, tag) +); +""" + + +class SkillStore: + """SQLite persistence engine — Skill quality tracking and evolution ledger. + + Architecture: + Write path: async method → asyncio.to_thread → _xxx_sync → self._mu lock → self._conn + Read path: sync method → self._reader() → independent short connection (WAL parallel read) + + Lifecycle: ``__init__()`` → use → ``close()`` + Also supports async context manager: + async with SkillStore() as store: + await store.save_record(record) + rec = store.load_record(skill_id) + """ + + def __init__(self, db_path: Optional[Path] = None) -> None: + if db_path is None: + db_dir = PROJECT_ROOT / ".openspace" + db_dir.mkdir(parents=True, exist_ok=True) + db_path = db_dir / "openspace.db" + + self._db_path = Path(db_path) + self._mu = threading.Lock() + self._closed = False + + # Crash recovery: clean up stale WAL/SHM from unclean shutdown + self._cleanup_wal_on_startup() + + # Persistent write connection + self._conn = self._make_connection(read_only=False) + self._init_db() + logger.debug(f"SkillStore ready at {self._db_path}") + + def _make_connection(self, *, read_only: bool) -> sqlite3.Connection: + """Create a tuned SQLite connection. + + Write connection: ``check_same_thread=False`` for cross-thread + usage via ``asyncio.to_thread()``. + + Read connection: ``query_only=ON`` pragma for safety. + """ + db_url = os.environ.get("TURSO_DATABASE_URL") + auth_token = os.environ.get("TURSO_AUTH_TOKEN") + + if db_url and auth_token: + conn = sqlite3.connect(db_url, auth_token=auth_token) + else: + conn = sqlite3.connect( + str(self._db_path), + timeout=30.0, + check_same_thread=False, + ) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA busy_timeout=30000") + conn.execute("PRAGMA synchronous=NORMAL") + conn.execute("PRAGMA cache_size=-16000") # 16 MB + conn.execute("PRAGMA temp_store=MEMORY") + conn.execute("PRAGMA foreign_keys=ON") + if read_only: + conn.execute("PRAGMA query_only=ON") + + conn.row_factory = sqlite3.Row + return conn + + @contextmanager + def _reader(self) -> Generator[sqlite3.Connection, None, None]: + """Open a temporary read-only connection. + + WAL mode allows concurrent readers and one writer. + Each read operation gets its own connection so reads never + block the event loop and never contend with the write lock. + """ + self._ensure_open() + conn = self._make_connection(read_only=True) + try: + yield conn + finally: + conn.close() + + def _cleanup_wal_on_startup(self) -> None: + """Remove stale WAL/SHM left by unclean shutdown. + + If the main DB file is empty (0 bytes) but WAL/SHM companions + exist, the database is unrecoverable — delete the companions + so SQLite can start fresh. + """ + if not self._db_path.exists(): + return + wal = Path(f"{self._db_path}-wal") + shm = Path(f"{self._db_path}-shm") + if self._db_path.stat().st_size == 0 and ( + wal.exists() or shm.exists() + ): + logger.warning( + "Empty DB with WAL/SHM — removing for crash recovery" + ) + for f in (wal, shm): + if f.exists(): + f.unlink() + + @_db_retry() + def _init_db(self) -> None: + """Create tables if they don't exist (idempotent via IF NOT EXISTS).""" + with self._mu: + self._conn.executescript(_DDL) + self._conn.commit() + + # Lifecycle + def close(self) -> None: + """Close the persistent connection. Subsequent ops will raise. + + Performs a WAL checkpoint before closing so that all committed + data is flushed from the WAL file into the main ``.db`` file. + This ensures external tools (DB browsers, backup scripts) see + complete data without needing to understand SQLite WAL mode. + """ + if self._closed: + return + self._closed = True + try: + # Flush WAL → main DB so external readers see all data + db_url = os.environ.get("TURSO_DATABASE_URL") + if not db_url: + self._conn.execute("PRAGMA wal_checkpoint(TRUNCATE)") + self._conn.close() + except Exception: + pass + logger.debug("SkillStore closed (WAL checkpointed)") + + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + self.close() + + @property + def db_path(self) -> Path: + return self._db_path + + def _ensure_open(self) -> None: + if self._closed: + raise RuntimeError("SkillStore is closed") + + # Write API (async, offloaded via asyncio.to_thread) + async def save_record(self, record: SkillRecord) -> None: + """Upsert a single :class:`SkillRecord`.""" + await asyncio.to_thread(self._save_record_sync, record) + + async def save_records(self, records: List[SkillRecord]) -> None: + """Batch upsert in a single transaction.""" + await asyncio.to_thread(self._save_records_sync, records) + + async def sync_from_registry( + self, + discovered_skills: List[Any], + ) -> int: + """Ensure every discovered skill has an initial DB record. + + For each skill in *discovered_skills* (``SkillMeta`` objects + from :meth:`SkillRegistry.discover`), if no record with the + same ``skill_id`` already exists, a new :class:`SkillRecord` is + created (``origin=IMPORTED``, ``generation=0``). + + Existing records (including evolved ones) are left untouched. + + Args: + discovered_skills: List of ``SkillMeta`` objects. + """ + return await asyncio.to_thread( + self._sync_from_registry_sync, discovered_skills, + ) + + @_db_retry() + def _sync_from_registry_sync( + self, discovered_skills: List[Any], + ) -> int: + self._ensure_open() + created = 0 + refreshed = 0 + with self._mu: + self._conn.execute("BEGIN") + try: + # Fetch all existing records keyed by skill_id + rows = self._conn.execute( + "SELECT skill_id, name, description, " + "lineage_content_snapshot " + "FROM skill_records" + ).fetchall() + existing: Dict[str, Any] = {r[0]: r for r in rows} + + # Also fetch all paths with an active record. + # After FIX evolution the DB skill_id changes but the + # filesystem path stays the same. Matching by path + # prevents creating a duplicate imported record on restart. + path_rows = self._conn.execute( + "SELECT path FROM skill_records WHERE is_active=1" + ).fetchall() + existing_active_paths: set = {r[0] for r in path_rows} + + for meta in discovered_skills: + path_str = str(meta.path) + skill_dir = meta.path.parent + + if meta.skill_id in existing: + # Refresh name/description if frontmatter changed, + # and backfill empty content_snapshot + row = existing[meta.skill_id] + updates: List[str] = [] + params: list = [] + + if row["name"] != meta.name: + updates.append("name=?") + params.append(meta.name) + if row["description"] != meta.description: + updates.append("description=?") + params.append(meta.description) + + raw_snap = row["lineage_content_snapshot"] or "" + if raw_snap in ("", "{}"): + try: + snap = collect_skill_snapshot(skill_dir) + if snap: + updates.append("lineage_content_snapshot=?") + params.append(json.dumps(snap, ensure_ascii=False)) + diff = "\n".join( + compute_unified_diff("", text, filename=name) + for name, text in sorted(snap.items()) + if compute_unified_diff("", text, filename=name) + ) + if diff: + updates.append("lineage_content_diff=?") + params.append(diff) + except Exception as e: + logger.warning( + f"sync_from_registry: snapshot backfill failed " + f"for {meta.skill_id}: {e}" + ) + + if updates: + params.append(meta.skill_id) + self._conn.execute( + f"UPDATE skill_records SET {', '.join(updates)} " + f"WHERE skill_id=?", + params, + ) + refreshed += 1 + continue + + # Path already covered by an evolved record + if path_str in existing_active_paths: + continue + + # Snapshot the directory so this version can be restored later + snapshot: Dict[str, str] = {} + content_diff = "" + try: + snapshot = collect_skill_snapshot(skill_dir) + content_diff = "\n".join( + compute_unified_diff("", text, filename=name) + for name, text in sorted(snapshot.items()) + if compute_unified_diff("", text, filename=name) + ) + except Exception as e: + logger.warning( + f"sync_from_registry: failed to snapshot {skill_dir}: {e}" + ) + + record = SkillRecord( + skill_id=meta.skill_id, + name=meta.name, + description=meta.description, + path=path_str, + is_active=True, + lineage=SkillLineage( + origin=SkillOrigin.IMPORTED, + generation=0, + content_snapshot=snapshot, + content_diff=content_diff, + ), + ) + self._upsert(record) + created += 1 + logger.debug( + f"sync_from_registry: created {meta.name} [{meta.skill_id}]" + ) + + self._conn.commit() + except Exception: + self._conn.rollback() + raise + + if created or refreshed: + logger.info( + f"sync_from_registry: {created} new record(s) created, " + f"{refreshed} refreshed, " + f"{len(discovered_skills) - created - refreshed} unchanged" + ) + return created + + async def record_analysis(self, analysis: ExecutionAnalysis) -> None: + """Atomic observation: insert analysis + judgments + increment counters. + + 1. INSERT a row in ``execution_analyses`` (one per task). + 2. INSERT rows in ``skill_judgments`` for each skill assessed. + 3. For each judgment, atomically increment the matching + ``skill_records`` counters: + - total_selections += 1 (always) + - total_applied += 1 (if skill_applied) + - total_completions += 1 (if applied and completed) + - total_fallbacks += 1 (if not applied and not completed) + - last_updated = now + """ + await asyncio.to_thread(self._record_analysis_sync, analysis) + + async def evolve_skill( + self, + new_record: SkillRecord, + parent_skill_ids: List[str], + ) -> None: + """Atomic evolution: insert new version + deactivate old version. + + **FIXED** — Same-name skill fix: + - ``new_record.name`` is the same as parent + - ``new_record.path`` is the same as parent + - parent is set to ``is_active=False`` + - ``new_record.is_active=True`` + + **DERIVED** — New skill derived: + - ``new_record.name`` is a new name + - parent is kept ``is_active=True`` (it is still the latest version of its line) + - ``new_record.is_active=True`` + + In the same SQL transaction, guaranteed by ``self._mu``. + + Args: + new_record : SkillRecord + New version record, ``lineage.parent_skill_ids`` must be non-empty. + parent_skill_ids : list[str] + Parent skill_id list (FIXED exactly 1, DERIVED ≥ 1). + For FIXED, parent is automatically deactivated. + """ + await asyncio.to_thread( + self._evolve_skill_sync, new_record, parent_skill_ids + ) + + async def deactivate_record(self, skill_id: str) -> bool: + """Set a specific record's ``is_active`` to False.""" + return await asyncio.to_thread(self._deactivate_record_sync, skill_id) + + async def reactivate_record(self, skill_id: str) -> bool: + """Set a specific record's ``is_active`` to True (revert / rollback).""" + return await asyncio.to_thread(self._reactivate_record_sync, skill_id) + + async def delete_record(self, skill_id: str) -> bool: + """Delete a skill and all related data (CASCADE).""" + return await asyncio.to_thread(self._delete_record_sync, skill_id) + + # Sync write implementations (thread-safe via self._mu) + @_db_retry() + def _save_record_sync(self, record: SkillRecord) -> None: + self._ensure_open() + with self._mu: + self._conn.execute("BEGIN") + try: + self._upsert(record) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + + @_db_retry() + def _save_records_sync(self, records: List[SkillRecord]) -> None: + self._ensure_open() + with self._mu: + self._conn.execute("BEGIN") + try: + for r in records: + self._upsert(r) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + + @_db_retry() + def _record_analysis_sync(self, analysis: ExecutionAnalysis) -> None: + """Persist an analysis and update skill quality counters. + + ``SkillJudgment.skill_id`` is the **true skill_id** (e.g. + ``weather__imp_a1b2c3d4``), the same identifier used as the DB + primary key. The analysis LLM receives skill_ids in its prompt + and outputs them verbatim. + + We update counters via ``WHERE skill_id = ?`` — exact match, no + ambiguity. + """ + self._ensure_open() + with self._mu: + self._conn.execute("BEGIN") + try: + analysis_id = self._insert_analysis(analysis) + + now_iso = datetime.now().isoformat() + for j in analysis.skill_judgments: + applied = 1 if j.skill_applied else 0 + completed = ( + 1 + if (j.skill_applied and analysis.task_completed) + else 0 + ) + fallback = ( + 1 + if (not j.skill_applied and not analysis.task_completed) + else 0 + ) + self._conn.execute( + """ + UPDATE skill_records SET + total_selections = total_selections + 1, + total_applied = total_applied + ?, + total_completions = total_completions + ?, + total_fallbacks = total_fallbacks + ?, + last_updated = ? + WHERE skill_id = ? + """, + (applied, completed, fallback, now_iso, j.skill_id), + ) + + self._conn.commit() + except Exception: + self._conn.rollback() + raise + + @_db_retry() + def _evolve_skill_sync( + self, + new_record: SkillRecord, + parent_skill_ids: List[str], + ) -> None: + """Atomic: insert new version + deactivate parents (for FIXED).""" + self._ensure_open() + with self._mu: + self._conn.execute("BEGIN") + try: + # For FIXED: deactivate same-name parents + if new_record.lineage.origin == SkillOrigin.FIXED: + for pid in parent_skill_ids: + self._conn.execute( + "UPDATE skill_records SET is_active=0, " + "last_updated=? WHERE skill_id=?", + (datetime.now().isoformat(), pid), + ) + + # Ensure new record has parent refs set + new_record.lineage.parent_skill_ids = list(parent_skill_ids) + new_record.is_active = True + + self._upsert(new_record) + self._conn.commit() + + origin = new_record.lineage.origin.value + logger.info( + f"evolve_skill ({origin}): " + f"{new_record.name}@gen{new_record.lineage.generation} " + f"[{new_record.skill_id}] ← parents={parent_skill_ids}" + ) + except Exception: + self._conn.rollback() + raise + + @_db_retry() + def _deactivate_record_sync(self, skill_id: str) -> bool: + self._ensure_open() + with self._mu: + cur = self._conn.execute( + "UPDATE skill_records SET is_active=0, last_updated=? " + "WHERE skill_id=?", + (datetime.now().isoformat(), skill_id), + ) + self._conn.commit() + return cur.rowcount > 0 + + @_db_retry() + def _reactivate_record_sync(self, skill_id: str) -> bool: + self._ensure_open() + with self._mu: + cur = self._conn.execute( + "UPDATE skill_records SET is_active=1, last_updated=? " + "WHERE skill_id=?", + (datetime.now().isoformat(), skill_id), + ) + self._conn.commit() + return cur.rowcount > 0 + + @_db_retry() + def _delete_record_sync(self, skill_id: str) -> bool: + self._ensure_open() + with self._mu: + # ON DELETE CASCADE automatically cleans up lineage_parents / deps / tags + # skill_judgments are NOT cascade-deleted (no FK to skill_records) + cur = self._conn.execute( + "DELETE FROM skill_records WHERE skill_id=?", (skill_id,) + ) + self._conn.commit() + return cur.rowcount > 0 + + # Read API (sync, each call opens its own read-only conn) + @_db_retry() + def load_record(self, skill_id: str) -> Optional[SkillRecord]: + """Load a single :class:`SkillRecord` by id.""" + with self._reader() as conn: + row = conn.execute( + "SELECT * FROM skill_records WHERE skill_id=?", + (skill_id,), + ).fetchone() + return self._to_record(conn, row) if row else None + + @_db_retry() + def load_all( + self, *, active_only: bool = False + ) -> Dict[str, SkillRecord]: + """Load skill records, keyed by ``skill_id``. + + Args: + active_only: If True, only return records with ``is_active=True``. + """ + with self._reader() as conn: + if active_only: + rows = conn.execute( + "SELECT * FROM skill_records WHERE is_active=1" + ).fetchall() + else: + rows = conn.execute("SELECT * FROM skill_records").fetchall() + result: Dict[str, SkillRecord] = {} + for row in rows: + rec = self._to_record(conn, row) + result[rec.skill_id] = rec + logger.info(f"Loaded {len(result)} skill records (active_only={active_only})") + return result + + @_db_retry() + def load_active(self) -> Dict[str, SkillRecord]: + """Load only active skill records, keyed by ``skill_id``. + + Convenience wrapper for ``load_all(active_only=True)``. + """ + return self.load_all(active_only=True) + + @_db_retry() + def load_record_by_path(self, skill_dir: str) -> Optional[SkillRecord]: + """Load the most recent active SkillRecord whose ``path`` is inside *skill_dir*. + + Used by ``upload_skill`` to retrieve pre-computed upload metadata + (origin, parents, change_summary, etc.) from the DB when + ``.upload_meta.json`` is missing. + + The match uses ``path LIKE '{skill_dir}%'`` so both + ``/a/b/SKILL.md`` and ``/a/b/scenarios/x.md`` match ``/a/b``. + Returns the newest active record (by ``last_updated DESC``). + """ + normalized = skill_dir.rstrip("/") + with self._reader() as conn: + row = conn.execute( + "SELECT * FROM skill_records " + "WHERE path LIKE ? AND is_active=1 " + "ORDER BY last_updated DESC LIMIT 1", + (f"{normalized}%",), + ).fetchone() + return self._to_record(conn, row) if row else None + + @_db_retry() + def get_versions(self, name: str) -> List[SkillRecord]: + """Load all versions of a named skill (active + inactive), sorted by generation.""" + with self._reader() as conn: + rows = conn.execute( + "SELECT * FROM skill_records WHERE name=? " + "ORDER BY lineage_generation ASC", + (name,), + ).fetchall() + return [self._to_record(conn, r) for r in rows] + + @_db_retry() + def load_by_category( + self, category: SkillCategory, *, active_only: bool = True + ) -> List[SkillRecord]: + """Load skill records filtered by category. + + Args: + active_only: If True (default), only return active records. + """ + with self._reader() as conn: + if active_only: + rows = conn.execute( + "SELECT * FROM skill_records " + "WHERE category=? AND is_active=1", + (category.value,), + ).fetchall() + else: + rows = conn.execute( + "SELECT * FROM skill_records WHERE category=?", + (category.value,), + ).fetchall() + return [self._to_record(conn, r) for r in rows] + + @_db_retry() + def load_analyses( + self, + skill_id: Optional[str] = None, + limit: int = 50, + ) -> List[ExecutionAnalysis]: + """Load recent analyses. + + Args: + skill_id: True ``skill_id`` (e.g. ``weather__imp_a1b2c3d4``). + ``skill_judgments.skill_id`` now stores the true skill_id, + so filtering uses exact match. + If None, return pure-execution analyses (no judgments). + """ + with self._reader() as conn: + if skill_id is not None: + rows = conn.execute( + "SELECT ea.* FROM execution_analyses ea " + "JOIN skill_judgments sj ON ea.id = sj.analysis_id " + "WHERE sj.skill_id = ? " + "ORDER BY ea.timestamp DESC LIMIT ?", + (skill_id, limit), + ).fetchall() + else: + rows = conn.execute( + "SELECT ea.* FROM execution_analyses ea " + "LEFT JOIN skill_judgments sj ON ea.id = sj.analysis_id " + "WHERE sj.id IS NULL " + "ORDER BY ea.timestamp DESC LIMIT ?", + (limit,), + ).fetchall() + return [self._to_analysis(conn, r) for r in reversed(rows)] + + @_db_retry() + def load_analyses_for_task( + self, task_id: str + ) -> Optional[ExecutionAnalysis]: + """Load the analysis for a specific task, or None.""" + with self._reader() as conn: + row = conn.execute( + "SELECT * FROM execution_analyses WHERE task_id=?", + (task_id,), + ).fetchone() + return self._to_analysis(conn, row) if row else None + + @_db_retry() + def load_all_analyses(self, limit: int = 200) -> List[ExecutionAnalysis]: + """Load recent analyses across all tasks.""" + with self._reader() as conn: + rows = conn.execute( + "SELECT * FROM execution_analyses " + "ORDER BY timestamp DESC LIMIT ?", + (limit,), + ).fetchall() + return [self._to_analysis(conn, r) for r in reversed(rows)] + + @_db_retry() + def load_evolution_candidates( + self, limit: int = 50 + ) -> List[ExecutionAnalysis]: + """Load analyses marked as evolution candidates.""" + with self._reader() as conn: + rows = conn.execute( + "SELECT * FROM execution_analyses " + "WHERE candidate_for_evolution=1 " + "ORDER BY timestamp DESC LIMIT ?", + (limit,), + ).fetchall() + return [self._to_analysis(conn, r) for r in reversed(rows)] + + @_db_retry() + def find_skills_by_tool(self, tool_key: str) -> List[str]: + """ + Only returns active records — deactivated (superseded) versions + are excluded so that Trigger 2 never re-processes old versions. + """ + with self._reader() as conn: + rows = conn.execute( + "SELECT sd.skill_id " + "FROM skill_tool_deps sd " + "JOIN skill_records sr ON sd.skill_id = sr.skill_id " + "WHERE sd.tool_key=? AND sr.is_active=1", + (tool_key,), + ).fetchall() + return [r["skill_id"] for r in rows] + + @_db_retry() + def find_children(self, parent_skill_id: str) -> List[str]: + """Find skill_ids derived from the given parent.""" + with self._reader() as conn: + rows = conn.execute( + "SELECT skill_id FROM skill_lineage_parents " + "WHERE parent_skill_id=?", + (parent_skill_id,), + ).fetchall() + return [r["skill_id"] for r in rows] + + @_db_retry() + def count(self, *, active_only: bool = False) -> int: + """Total number of skill records.""" + with self._reader() as conn: + if active_only: + return conn.execute( + "SELECT COUNT(*) FROM skill_records WHERE is_active=1" + ).fetchone()[0] + return conn.execute( + "SELECT COUNT(*) FROM skill_records" + ).fetchone()[0] + + # Analytics / Summary + @_db_retry() + def get_summary(self, *, active_only: bool = True) -> List[Dict[str, Any]]: + """Lightweight summary of skills (no analyses/deps loaded). + + Default filters to active skills only. + """ + with self._reader() as conn: + where = "WHERE is_active=1 " if active_only else "" + rows = conn.execute( + f""" + SELECT skill_id, name, description, category, is_active, + visibility, creator_id, + lineage_origin, lineage_generation, + total_selections, total_applied, + total_completions, total_fallbacks, + first_seen, last_updated + FROM skill_records + {where} + ORDER BY last_updated DESC + """ + ).fetchall() + return [dict(r) for r in rows] + + @_db_retry() + def get_stats(self, *, active_only: bool = True) -> Dict[str, Any]: + """Aggregate statistics across skills.""" + with self._reader() as conn: + where = " WHERE is_active=1" if active_only else "" + total = conn.execute( + f"SELECT COUNT(*) FROM skill_records{where}" + ).fetchone()[0] + + by_category = { + r["category"]: r["cnt"] + for r in conn.execute( + f"SELECT category, COUNT(*) AS cnt " + f"FROM skill_records{where} GROUP BY category" + ).fetchall() + } + by_origin = { + r["lineage_origin"]: r["cnt"] + for r in conn.execute( + f"SELECT lineage_origin, COUNT(*) AS cnt " + f"FROM skill_records{where} GROUP BY lineage_origin" + ).fetchall() + } + n_analyses = conn.execute( + "SELECT COUNT(*) FROM execution_analyses" + ).fetchone()[0] + n_candidates = conn.execute( + "SELECT COUNT(*) FROM execution_analyses " + "WHERE candidate_for_evolution=1" + ).fetchone()[0] + agg = conn.execute( + f""" + SELECT SUM(total_selections) AS sel, + SUM(total_applied) AS app, + SUM(total_completions) AS comp, + SUM(total_fallbacks) AS fb + FROM skill_records{where} + """ + ).fetchone() + + # Also report total (including inactive) for context + total_all = conn.execute( + "SELECT COUNT(*) FROM skill_records" + ).fetchone()[0] + + return { + "total_skills": total, + "total_skills_all": total_all, + "by_category": by_category, + "by_origin": by_origin, + "total_analyses": n_analyses, + "evolution_candidates": n_candidates, + "total_selections": agg["sel"] or 0, + "total_applied": agg["app"] or 0, + "total_completions": agg["comp"] or 0, + "total_fallbacks": agg["fb"] or 0, + } + + @_db_retry() + def get_task_skill_summary(self, task_id: str) -> Dict[str, Any]: + """Per-task summary: task-level fields + per-skill judgments. + + Useful for understanding how multiple skills contributed to a + single task execution. + + Returns: + dict: ``{"task_id", "task_completed", "execution_note", + "tool_issues", "judgments": [{skill_id, skill_applied, note}], + ...}`` or empty dict if the task has no analysis. + """ + with self._reader() as conn: + row = conn.execute( + "SELECT * FROM execution_analyses WHERE task_id=?", + (task_id,), + ).fetchone() + if not row: + return {} + + judgment_rows = conn.execute( + "SELECT skill_id, skill_applied, note " + "FROM skill_judgments WHERE analysis_id=?", + (row["id"],), + ).fetchall() + + try: + evo_suggestions = json.loads(row["evolution_suggestions"] or "[]") + except json.JSONDecodeError: + evo_suggestions = [] + + return { + "task_id": row["task_id"], + "timestamp": row["timestamp"], + "task_completed": bool(row["task_completed"]), + "execution_note": row["execution_note"], + "tool_issues": json.loads(row["tool_issues"]), + "candidate_for_evolution": bool(row["candidate_for_evolution"]), + "evolution_suggestions": evo_suggestions, + "analyzed_by": row["analyzed_by"], + "judgments": [ + { + "skill_id": jr["skill_id"], + "skill_applied": bool(jr["skill_applied"]), + "note": jr["note"], + } + for jr in judgment_rows + ], + } + + @_db_retry() + def get_top_skills( + self, + n: int = 10, + metric: str = "effective_rate", + min_selections: int = 1, + *, + active_only: bool = True, + ) -> List[Dict[str, Any]]: + """Top-N skills ranked by the chosen metric. + + Metrics: + ``effective_rate`` — completions / selections + ``applied_rate`` — applied / selections + ``completion_rate`` — completions / applied + ``total_selections``— raw count + """ + rate_exprs = { + "effective_rate": ( + "CAST(total_completions AS REAL) / total_selections" + ), + "applied_rate": ( + "CAST(total_applied AS REAL) / total_selections" + ), + "completion_rate": ( + "CASE WHEN total_applied > 0 " + "THEN CAST(total_completions AS REAL) / total_applied " + "ELSE 0.0 END" + ), + "total_selections": "total_selections", + } + expr = rate_exprs.get(metric, rate_exprs["effective_rate"]) + active_clause = " AND is_active=1" if active_only else "" + + with self._reader() as conn: + rows = conn.execute( + f"SELECT *, ({expr}) AS _rank " + f"FROM skill_records " + f"WHERE total_selections >= ?{active_clause} " + f"ORDER BY _rank DESC LIMIT ?", + (min_selections, n), + ).fetchall() + results = [] + for r in rows: + d = dict(r) + d.pop("_rank", None) + results.append(d) + return results + + @_db_retry() + def get_count_and_timestamp( + self, *, active_only: bool = True + ) -> Dict[str, Any]: + """Skill count + newest ``last_updated`` for cheap change detection.""" + with self._reader() as conn: + where = " WHERE is_active=1" if active_only else "" + row = conn.execute( + f"SELECT COUNT(*) AS cnt, MAX(last_updated) AS max_ts " + f"FROM skill_records{where}" + ).fetchone() + return { + "count": row["cnt"] if row else 0, + "max_last_updated": row["max_ts"] if row else None, + } + + # Lineage / Ancestry + @_db_retry() + def get_ancestry( + self, skill_id: str, max_depth: int = 10 + ) -> List[SkillRecord]: + """Walk up the lineage tree; returns ancestors oldest-first.""" + with self._reader() as conn: + visited: set[str] = set() + ancestors: List[SkillRecord] = [] + frontier = [skill_id] + + for _ in range(max_depth): + next_frontier: List[str] = [] + for sid in frontier: + for pr in conn.execute( + "SELECT parent_skill_id " + "FROM skill_lineage_parents WHERE skill_id=?", + (sid,), + ).fetchall(): + pid = pr["parent_skill_id"] + if pid in visited: + continue + visited.add(pid) + row = conn.execute( + "SELECT * FROM skill_records WHERE skill_id=?", + (pid,), + ).fetchone() + if row: + ancestors.append(self._to_record(conn, row)) + next_frontier.append(pid) + frontier = next_frontier + if not frontier: + break + + ancestors.sort(key=lambda r: r.lineage.generation) + return ancestors + + @_db_retry() + def get_lineage_tree( + self, skill_id: str, max_depth: int = 5 + ) -> Dict[str, Any]: + """Build a JSON-friendly tree rooted at *skill_id* (downward).""" + with self._reader() as conn: + return self._subtree(conn, skill_id, max_depth, set()) + + def _subtree( + self, + conn: sqlite3.Connection, + sid: str, + depth: int, + visited: set, + ) -> Dict[str, Any]: + visited.add(sid) + row = conn.execute( + "SELECT skill_id, name, lineage_generation, lineage_origin, is_active " + "FROM skill_records WHERE skill_id=?", + (sid,), + ).fetchone() + node: Dict[str, Any] = { + "skill_id": sid, + "name": row["name"] if row else "?", + "generation": row["lineage_generation"] if row else -1, + "origin": row["lineage_origin"] if row else "unknown", + "is_active": bool(row["is_active"]) if row else False, + "children": [], + } + if depth <= 0: + return node + for cr in conn.execute( + "SELECT skill_id FROM skill_lineage_parents " + "WHERE parent_skill_id=?", + (sid,), + ).fetchall(): + cid = cr["skill_id"] + if cid not in visited: + node["children"].append( + self._subtree(conn, cid, depth - 1, visited) + ) + return node + + # Maintenance + def clear(self) -> None: + """Delete all data (keeps schema).""" + self._ensure_open() + with self._mu: + self._conn.execute("BEGIN") + try: + # CASCADE on skill_records cleans up: lineage_parents, tool_deps, tags + self._conn.execute("DELETE FROM skill_records") + # execution_analyses CASCADE cleans up skill_judgments + self._conn.execute("DELETE FROM execution_analyses") + self._conn.commit() + logger.info("SkillStore cleared") + except Exception: + self._conn.rollback() + raise + + def vacuum(self) -> None: + """Compact the database file.""" + self._ensure_open() + with self._mu: + self._conn.execute("VACUUM") + + # Internal: Upsert / Insert / Deserialize + def _upsert(self, record: SkillRecord) -> None: + """Insert or update skill_records + sync related rows. + + Called within a transaction holding ``self._mu``. + """ + lin = record.lineage + # content_snapshot is Dict[str, str]; store as JSON text + snapshot_json = json.dumps( + lin.content_snapshot, ensure_ascii=False + ) + self._conn.execute( + """ + INSERT INTO skill_records ( + skill_id, name, description, path, is_active, category, + visibility, creator_id, + lineage_origin, lineage_generation, + lineage_source_task_id, lineage_change_summary, + lineage_content_diff, lineage_content_snapshot, + lineage_created_at, lineage_created_by, + total_selections, total_applied, + total_completions, total_fallbacks, + first_seen, last_updated + ) VALUES (?,?,?,?,?,?, ?,?, ?,?, ?,?, ?,?, ?,?, ?,?,?,?, ?,?) + ON CONFLICT(skill_id) DO UPDATE SET + name=excluded.name, + description=excluded.description, + path=excluded.path, + is_active=excluded.is_active, + category=excluded.category, + visibility=excluded.visibility, + creator_id=excluded.creator_id, + lineage_origin=excluded.lineage_origin, + lineage_generation=excluded.lineage_generation, + lineage_source_task_id=excluded.lineage_source_task_id, + lineage_change_summary=excluded.lineage_change_summary, + lineage_content_diff=excluded.lineage_content_diff, + lineage_content_snapshot=excluded.lineage_content_snapshot, + lineage_created_at=excluded.lineage_created_at, + lineage_created_by=excluded.lineage_created_by, + total_selections=excluded.total_selections, + total_applied=excluded.total_applied, + total_completions=excluded.total_completions, + total_fallbacks=excluded.total_fallbacks, + last_updated=excluded.last_updated + """, + ( + record.skill_id, + record.name, + record.description, + record.path, + int(record.is_active), + record.category.value, + record.visibility.value, + record.creator_id, + lin.origin.value, + lin.generation, + lin.source_task_id, + lin.change_summary, + lin.content_diff, + snapshot_json, + lin.created_at.isoformat(), + lin.created_by, + record.total_selections, + record.total_applied, + record.total_completions, + record.total_fallbacks, + record.first_seen.isoformat(), + record.last_updated.isoformat(), + ), + ) + + # Sync lineage parents + self._conn.execute( + "DELETE FROM skill_lineage_parents WHERE skill_id=?", + (record.skill_id,), + ) + for pid in lin.parent_skill_ids: + self._conn.execute( + "INSERT INTO skill_lineage_parents" + "(skill_id, parent_skill_id) VALUES(?,?)", + (record.skill_id, pid), + ) + + # Sync tool dependencies + self._conn.execute( + "DELETE FROM skill_tool_deps WHERE skill_id=?", + (record.skill_id,), + ) + critical_set = set(record.critical_tools) + for tk in record.tool_dependencies: + self._conn.execute( + "INSERT INTO skill_tool_deps" + "(skill_id, tool_key, critical) VALUES(?,?,?)", + (record.skill_id, tk, 1 if tk in critical_set else 0), + ) + + # Sync tags + self._conn.execute( + "DELETE FROM skill_tags WHERE skill_id=?", + (record.skill_id,), + ) + for tag in record.tags: + self._conn.execute( + "INSERT INTO skill_tags(skill_id, tag) VALUES(?,?)", + (record.skill_id, tag), + ) + + # Sync analyses (insert only NEW ones, dedup by task_id) + for a in record.recent_analyses: + existing = self._conn.execute( + "SELECT id FROM execution_analyses WHERE task_id=?", + (a.task_id,), + ).fetchone() + if existing is None: + self._insert_analysis(a) + + def _insert_analysis(self, a: ExecutionAnalysis) -> int: + """Insert an execution_analyses row + its skill_judgments. + + Called within a transaction holding ``self._mu``. + + Returns: + int: The ``execution_analyses.id`` of the newly inserted row. + """ + cur = self._conn.execute( + """ + INSERT INTO execution_analyses ( + task_id, timestamp, + task_completed, execution_note, + tool_issues, candidate_for_evolution, + evolution_suggestions, analyzed_by, analyzed_at + ) VALUES (?,?, ?,?, ?,?, ?,?,?) + """, + ( + a.task_id, + a.timestamp.isoformat(), + int(a.task_completed), + a.execution_note, + json.dumps(a.tool_issues, ensure_ascii=False), + int(a.candidate_for_evolution), + json.dumps( + [s.to_dict() for s in a.evolution_suggestions], + ensure_ascii=False, + ), + a.analyzed_by, + a.analyzed_at.isoformat(), + ), + ) + analysis_id = cur.lastrowid + + for j in a.skill_judgments: + self._conn.execute( + "INSERT INTO skill_judgments " + "(analysis_id, skill_id, skill_applied, note) " + "VALUES (?,?,?,?)", + (analysis_id, j.skill_id, int(j.skill_applied), j.note), + ) + + return analysis_id + + # Deserialization + def _to_record( + self, conn: sqlite3.Connection, row: sqlite3.Row + ) -> SkillRecord: + """Deserialize a skill_records row + related rows → SkillRecord.""" + sid = row["skill_id"] + + parents = [ + r["parent_skill_id"] + for r in conn.execute( + "SELECT parent_skill_id " + "FROM skill_lineage_parents WHERE skill_id=?", + (sid,), + ).fetchall() + ] + + # Deserialize content_snapshot: stored as JSON dict + # mapping relative file paths to their text content + raw_snapshot = row["lineage_content_snapshot"] or "{}" + snapshot: Dict[str, str] = json.loads(raw_snapshot) + + lineage = SkillLineage( + origin=SkillOrigin(row["lineage_origin"]), + generation=row["lineage_generation"], + parent_skill_ids=parents, + source_task_id=row["lineage_source_task_id"], + change_summary=row["lineage_change_summary"], + content_diff=row["lineage_content_diff"], + content_snapshot=snapshot, + created_at=datetime.fromisoformat(row["lineage_created_at"]), + created_by=row["lineage_created_by"], + ) + + dep_rows = conn.execute( + "SELECT tool_key, critical " + "FROM skill_tool_deps WHERE skill_id=?", + (sid,), + ).fetchall() + + tag_rows = conn.execute( + "SELECT tag FROM skill_tags WHERE skill_id=?", (sid,) + ).fetchall() + + # Load recent analyses involving this skill (via skill_judgments). + # skill_judgments.skill_id stores the true skill_id (same as DB PK). + analysis_rows = conn.execute( + "SELECT ea.* FROM execution_analyses ea " + "JOIN skill_judgments sj ON ea.id = sj.analysis_id " + "WHERE sj.skill_id = ? " + "ORDER BY ea.timestamp DESC LIMIT ?", + (sid, SkillRecord.MAX_RECENT), + ).fetchall() + + return SkillRecord( + skill_id=sid, + name=row["name"], + description=row["description"], + path=row["path"], + is_active=bool(row["is_active"]), + category=SkillCategory(row["category"]), + tags=[r["tag"] for r in tag_rows], + visibility=( + SkillVisibility(row["visibility"]) + if row["visibility"] else SkillVisibility.PRIVATE + ), + creator_id=row["creator_id"] or "", + lineage=lineage, + tool_dependencies=[r["tool_key"] for r in dep_rows], + critical_tools=[ + r["tool_key"] for r in dep_rows if r["critical"] + ], + total_selections=row["total_selections"], + total_applied=row["total_applied"], + total_completions=row["total_completions"], + total_fallbacks=row["total_fallbacks"], + recent_analyses=[ + self._to_analysis(conn, r) for r in reversed(analysis_rows) + ], + first_seen=datetime.fromisoformat(row["first_seen"]), + last_updated=datetime.fromisoformat(row["last_updated"]), + ) + + @staticmethod + def _to_analysis( + conn: sqlite3.Connection, row: sqlite3.Row + ) -> ExecutionAnalysis: + """Deserialize an execution_analyses row + judgments → ExecutionAnalysis.""" + analysis_id = row["id"] + + judgment_rows = conn.execute( + "SELECT skill_id, skill_applied, note " + "FROM skill_judgments WHERE analysis_id=?", + (analysis_id,), + ).fetchall() + + suggestions: list[EvolutionSuggestion] = [] + raw_suggestions = row["evolution_suggestions"] + if raw_suggestions: + try: + suggestions = [ + EvolutionSuggestion.from_dict(s) + for s in json.loads(raw_suggestions) + ] + except (json.JSONDecodeError, KeyError, ValueError): + pass + + return ExecutionAnalysis( + task_id=row["task_id"], + timestamp=datetime.fromisoformat(row["timestamp"]), + task_completed=bool(row["task_completed"]), + execution_note=row["execution_note"], + tool_issues=json.loads(row["tool_issues"]), + skill_judgments=[ + SkillJudgment( + skill_id=jr["skill_id"], + skill_applied=bool(jr["skill_applied"]), + note=jr["note"], + ) + for jr in judgment_rows + ], + evolution_suggestions=suggestions, + analyzed_by=row["analyzed_by"], + analyzed_at=datetime.fromisoformat(row["analyzed_at"]), + ) diff --git a/openspace/skill_engine/types.py b/openspace/skill_engine/types.py new file mode 100644 index 0000000000000000000000000000000000000000..0a21c9c35c8a658d61f5a9200aadfface9cdbd32 --- /dev/null +++ b/openspace/skill_engine/types.py @@ -0,0 +1,464 @@ +"""Data types for skill quality tracking and evolution.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, ClassVar, Dict, List, Optional + + +class SkillCategory(str, Enum): + """Skill primary category.""" + + TOOL_GUIDE = "tool_guide" # Tool guide + WORKFLOW = "workflow" # End-to-end workflow + REFERENCE = "reference" # Reference knowledge + + +class SkillVisibility(str, Enum): + """Cloud visibility of a skill. (`Group` is managed by the cloud platform)""" + + PRIVATE = "private" # Only visible to the creator + PUBLIC = "public" # Visible to all users on the cloud + + +class EvolutionType(str, Enum): + FIX = "fix" # Repair broken / outdated skill instructions + DERIVED = "derived" # Enhance / specialize an existing skill + CAPTURED = "captured" # Capture a novel reusable pattern + + def to_origin(self) -> "SkillOrigin": + """Convert this evolution action to the corresponding SkillOrigin.""" + return _EVOLUTION_TO_ORIGIN[self] + + +class SkillOrigin(str, Enum): + """How this skill was created / entered the system. + + Version DAG model — every change creates a new SkillRecord node:: + + Lineage rules: + IMPORTED / CAPTURED → root node, no parent (parent_skill_ids = []) + DERIVED → 1+ parents, new skill name, new directory + FIXED → exactly 1 parent (previous version of same skill), + same ``name`` & ``path``, new ``skill_id``. + Files on disk are updated in-place; old directory + content (all files) is preserved via + ``content_snapshot`` dict in the DB. + Only the latest version is ``is_active=True``. + """ + + IMPORTED = "imported" # Initial import, no parent + CAPTURED = "captured" # Captured from a successful execution with no parent skill involved + DERIVED = "derived" # Derived from existing skill(s) (upgrade, wrap, compose, etc.) + FIXED = "fixed" # Fix of existing skill — new record, parent = previous version + + +_EVOLUTION_TO_ORIGIN: Dict["EvolutionType", "SkillOrigin"] = { + EvolutionType.FIX: SkillOrigin.FIXED, + EvolutionType.DERIVED: SkillOrigin.DERIVED, + EvolutionType.CAPTURED: SkillOrigin.CAPTURED, +} + +_ORIGIN_TO_EVOLUTION: Dict["SkillOrigin", "EvolutionType"] = { + v: k for k, v in _EVOLUTION_TO_ORIGIN.items() +} + + +@dataclass +class SkillLineage: + """Tracks the evolutionary lineage of a skill. + + ``parent_skill_ids`` may contain multiple parents for DERIVED. + FIXED always has exactly one parent (the previous version). + IMPORTED / CAPTURED have no parents. + + ─── generation ───────────────────────────────────────────────── + + Distance from root in the version DAG. Set by the evolution logic + when creating a new skill record: + + - IMPORTED / CAPTURED → ``generation = 0`` (root node) + - FIXED → ``parent.generation + 1`` + - DERIVED → ``max(p.generation for p in parents) + 1`` + + ─── change_summary ───────────────────────────────────────────── + + LLM-generated free-text description of what changed vs. the parent. + Produced by the evolution LLM when creating FIXED or DERIVED skills. + Examples: + - FIXED: "Fixed curl parameter format in step 3" + - DERIVED: "Composed weather + geocoding guides into an + end-to-end location-aware forecast workflow" + - IMPORTED / CAPTURED: typically empty or a brief import note. + + ─── content_diff / content_snapshot ──────────────────────────── + + ``content_snapshot`` stores the **full directory snapshot** at this + version as a ``Dict[str, str]`` mapping relative file paths to their + text content. + + ``content_diff`` stores a combined unified diff (``git diff`` + format) covering **all** files in the skill directory. + Policy by parent count: + + - **0 parents** (IMPORTED / CAPTURED): + add-all diff — every line prefixed with ``+`` + (like ``git diff /dev/null`` for each file). + - **1 parent** (FIXED, or single-parent DERIVED): + normal unified diff between the parent's directory content + and this version's directory content, covering all files. + - **N parents** (multi-parent DERIVED): + ``""`` (empty string). A multi-parent composition is a + creative act, not a patch — per-parent diffs are large and + unhelpful. The composition intent is captured in + ``change_summary`` instead. Individual parent content can + be retrieved via ``parent_skill_ids`` → each parent's + ``content_snapshot``. + """ + + origin: SkillOrigin + generation: int = 0 # Distance from root (see docstring) + parent_skill_ids: List[str] = field(default_factory=list) # [] for IMPORTED / CAPTURED + source_task_id: Optional[str] = None # Task that triggered evolution / capture + change_summary: str = "" # LLM-generated description of changes + content_diff: str = "" # Combined unified diff of all files (empty for multi-parent DERIVED) + content_snapshot: Dict[str, str] = field(default_factory=dict) # {relative_path: content} full directory snapshot + created_at: datetime = field(default_factory=datetime.now) + created_by: str = "" # "human" | model name (version-level actor) + + def to_dict(self) -> Dict[str, Any]: + return { + "origin": self.origin.value, + "generation": self.generation, + "parent_skill_ids": self.parent_skill_ids, + "source_task_id": self.source_task_id, + "change_summary": self.change_summary, + "content_diff": self.content_diff, + "content_snapshot": self.content_snapshot, + "created_at": self.created_at.isoformat(), + "created_by": self.created_by, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SkillLineage": + return cls( + origin=SkillOrigin(data["origin"]), + generation=data.get("generation", 0), + parent_skill_ids=data.get("parent_skill_ids", []), + source_task_id=data.get("source_task_id"), + change_summary=data.get("change_summary", ""), + content_diff=data.get("content_diff", ""), + content_snapshot=data.get("content_snapshot", {}), + created_at=( + datetime.fromisoformat(data["created_at"]) + if data.get("created_at") else datetime.now() + ), + created_by=data.get("created_by", ""), + ) + + +# Per-skill judgment within a task analysis +@dataclass +class SkillJudgment: + """Per-skill assessment within an :class:`ExecutionAnalysis`. + + One ``ExecutionAnalysis`` (per task) contains zero or more + ``SkillJudgment`` entries — one for each skill that was selected + for that task. + + ``skill_applied`` semantics depend on skill category: + - WORKFLOW: agent followed the prescribed steps + - TOOL_GUIDE: agent used the described tool / approach + - REFERENCE: knowledge influenced agent decisions + """ + + skill_id: str + skill_applied: bool = False # Whether the skill was actually applied + note: str = "" # Per-skill observation (deviation, usage, etc.) + + def to_dict(self) -> Dict[str, Any]: + return { + "skill_id": self.skill_id, + "skill_applied": self.skill_applied, + "note": self.note, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SkillJudgment": + return cls( + skill_id=data["skill_id"], + skill_applied=data.get("skill_applied", False), + note=data.get("note", ""), + ) + + +@dataclass +class EvolutionSuggestion: + """One evolution action suggested by the analysis LLM. + + ``target_skill_ids`` lists the parent skill(s) this action targets + using **true skill_id** values (e.g. ``weather__imp_a1b2c3d4``): + - FIX: exactly 1 parent (the skill to repair in-place) + - DERIVED: 1+ parents (single parent → enhance; multi → merge/fuse) + - CAPTURED: empty list (brand-new skill, no parents) + """ + + evolution_type: EvolutionType + target_skill_ids: List[str] = field(default_factory=list) # True skill_id(s) + category: Optional[SkillCategory] = None # Desired category of the result + direction: str = "" # Free-text: what to evolve / capture + + @property + def target_skill_id(self) -> str: + """Primary (or only) target skill_id. Empty string if none.""" + return self.target_skill_ids[0] if self.target_skill_ids else "" + + def to_dict(self) -> Dict[str, Any]: + return { + "type": self.evolution_type.value, + "target_skills": self.target_skill_ids, + # Keep legacy singular key for backward compat with stored analyses + "target_skill": self.target_skill_id, + "category": self.category.value if self.category else None, + "direction": self.direction, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "EvolutionSuggestion": + cat = None + if data.get("category"): + try: + cat = SkillCategory(data["category"]) + except ValueError: + pass + # Support both new list format and legacy single-string format + raw_targets = data.get("target_skills") + if isinstance(raw_targets, list): + targets = [t for t in raw_targets if t] + else: + legacy = data.get("target_skill", "") + targets = [legacy] if legacy else [] + return cls( + evolution_type=EvolutionType(data["type"]), + target_skill_ids=targets, + category=cat, + direction=data.get("direction", ""), + ) + + +# Task-level execution analysis (1 per task) +@dataclass +class ExecutionAnalysis: + """LLM-produced analysis of a single task execution.""" + + task_id: str + timestamp: datetime + + # Task-level LLM judgments + task_completed: bool = False # Whether the task completed successfully + execution_note: str = "" # Task-level observation + tool_issues: List[str] = field(default_factory=list) # Tool keys that had issues + + # Per-skill judgments (one per selected skill; empty = no skill involved) + skill_judgments: List[SkillJudgment] = field(default_factory=list) + + # Evolution suggestions — 0-N per analysis, each fully specifies an action + evolution_suggestions: List[EvolutionSuggestion] = field(default_factory=list) + + # Analysis metadata + analyzed_by: str = "" # Model name used for analysis + analyzed_at: datetime = field(default_factory=datetime.now) + + def get_judgment(self, skill_id: str) -> Optional[SkillJudgment]: + """Find the judgment for a specific skill, or None.""" + for j in self.skill_judgments: + if j.skill_id == skill_id: + return j + return None + + @property + def skill_ids(self) -> List[str]: + """List of skill_ids that were judged in this analysis.""" + return [j.skill_id for j in self.skill_judgments] + + @property + def candidate_for_evolution(self) -> bool: + """Whether any evolution suggestions exist.""" + return len(self.evolution_suggestions) > 0 + + def suggestions_by_type(self, evo_type: EvolutionType) -> List[EvolutionSuggestion]: + """Filter evolution suggestions by type.""" + return [s for s in self.evolution_suggestions if s.evolution_type == evo_type] + + def to_dict(self) -> Dict[str, Any]: + return { + "task_id": self.task_id, + "timestamp": self.timestamp.isoformat(), + "task_completed": self.task_completed, + "execution_note": self.execution_note, + "tool_issues": self.tool_issues, + "skill_judgments": [j.to_dict() for j in self.skill_judgments], + "evolution_suggestions": [s.to_dict() for s in self.evolution_suggestions], + "analyzed_by": self.analyzed_by, + "analyzed_at": self.analyzed_at.isoformat(), + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecutionAnalysis": + return cls( + task_id=data["task_id"], + timestamp=datetime.fromisoformat(data["timestamp"]), + task_completed=data.get("task_completed", False), + execution_note=data.get("execution_note", ""), + tool_issues=data.get("tool_issues", []), + skill_judgments=[ + SkillJudgment.from_dict(j) + for j in data.get("skill_judgments", []) + ], + evolution_suggestions=[ + EvolutionSuggestion.from_dict(s) + for s in data.get("evolution_suggestions", []) + ], + analyzed_by=data.get("analyzed_by", ""), + analyzed_at=( + datetime.fromisoformat(data["analyzed_at"]) + if data.get("analyzed_at") else datetime.now() + ), + ) + + +# Full skill profile (identity + lineage + deps + quality) +@dataclass +class SkillRecord: + """Comprehensive record for a skill: identity + lineage + quality. + + This is the full profile for a skill within the quality / evolution system. + The lightweight SkillMeta is still used for discovery; SkillRecord is managed by + ExecutionAnalyzer. + """ + + skill_id: str # Unique identifier + name: str # Logical skill name (shared across versions) + description: str + path: str = "" # Path to SKILL.md (shared across FIXED versions) + + is_active: bool = True # Only the latest version is active + + # Category & tags + category: SkillCategory = SkillCategory.WORKFLOW + tags: List[str] = field(default_factory=list) # Auxiliary tags generated by LLM + + # Ownership & visibility (for cloud sync) + visibility: SkillVisibility = SkillVisibility.PRIVATE # Cloud visibility + creator_id: str = "" # User ID of the skill owner / creator + + # Lineage + lineage: SkillLineage = field( + default_factory=lambda: SkillLineage(origin=SkillOrigin.IMPORTED) + ) + + # Tool dependencies + tool_dependencies: List[str] = field(default_factory=list) # All involved tool keys + critical_tools: List[str] = field(default_factory=list) # Required (must-have) tool keys + + # Execution stats (updated by add_analysis or atomically in store) + total_selections: int = 0 # Times this skill was selected by the LLM + total_applied: int = 0 # Times the skill was actually applied by the agent + total_completions: int = 0 # Times task completed when skill was applied + total_fallbacks: int = 0 # Times skill was not applied and task failed + + # Recent analysis history (rolling window of analyses involving this skill) + recent_analyses: List[ExecutionAnalysis] = field(default_factory=list) + MAX_RECENT: ClassVar[int] = 50 + + # Metadata + first_seen: datetime = field(default_factory=datetime.now) + last_updated: datetime = field(default_factory=datetime.now) + + @property + def applied_rate(self) -> float: + """Ratio of selections where the skill was actually applied.""" + return self.total_applied / self.total_selections if self.total_selections else 0.0 + + @property + def completion_rate(self) -> float: + """Ratio of applied uses that led to task completion.""" + return self.total_completions / self.total_applied if self.total_applied else 0.0 + + @property + def effective_rate(self) -> float: + """End-to-end effectiveness: selected → applied → completed.""" + return self.total_completions / self.total_selections if self.total_selections else 0.0 + + @property + def fallback_rate(self) -> float: + """Ratio of selections that fell back (skill unusable signal).""" + return self.total_fallbacks / self.total_selections if self.total_selections else 0.0 + + # NOTE: Counter updates (total_selections, total_applied, etc.) are + # performed atomically in SQL by SkillStore.record_analysis(). + # Do NOT duplicate that logic here in Python. + + def to_dict(self) -> Dict[str, Any]: + return { + "skill_id": self.skill_id, + "name": self.name, + "description": self.description, + "path": self.path, + "is_active": self.is_active, + "category": self.category.value, + "tags": self.tags, + "visibility": self.visibility.value, + "creator_id": self.creator_id, + "lineage": self.lineage.to_dict(), + "tool_dependencies": self.tool_dependencies, + "critical_tools": self.critical_tools, + "total_selections": self.total_selections, + "total_applied": self.total_applied, + "total_completions": self.total_completions, + "total_fallbacks": self.total_fallbacks, + "recent_analyses": [a.to_dict() for a in self.recent_analyses], + "first_seen": self.first_seen.isoformat(), + "last_updated": self.last_updated.isoformat(), + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SkillRecord": + record = cls( + skill_id=data["skill_id"], + name=data["name"], + description=data.get("description", ""), + path=data.get("path", ""), + is_active=data.get("is_active", True), + category=SkillCategory(data["category"]) if data.get("category") else SkillCategory.WORKFLOW, + tags=data.get("tags", []), + visibility=( + SkillVisibility(data["visibility"]) + if data.get("visibility") else SkillVisibility.PRIVATE + ), + creator_id=data.get("creator_id", ""), + lineage=( + SkillLineage.from_dict(data["lineage"]) + if data.get("lineage") + else SkillLineage(origin=SkillOrigin.IMPORTED) + ), + tool_dependencies=data.get("tool_dependencies", []), + critical_tools=data.get("critical_tools", []), + total_selections=data.get("total_selections", 0), + total_applied=data.get("total_applied", 0), + total_completions=data.get("total_completions", 0), + total_fallbacks=data.get("total_fallbacks", 0), + first_seen=( + datetime.fromisoformat(data["first_seen"]) + if data.get("first_seen") else datetime.now() + ), + last_updated=( + datetime.fromisoformat(data["last_updated"]) + if data.get("last_updated") else datetime.now() + ), + ) + for a in data.get("recent_analyses", []): + record.recent_analyses.append(ExecutionAnalysis.from_dict(a)) + return record diff --git a/openspace/skills/README.md b/openspace/skills/README.md new file mode 100644 index 0000000000000000000000000000000000000000..15eadc7c158abbf7683db6f57be1291c515e3c21 --- /dev/null +++ b/openspace/skills/README.md @@ -0,0 +1,30 @@ +# Skills + +Place custom skills here. Each skill is a subdirectory containing a `SKILL.md`: + +``` +skills/ +├── my-skill/ +│ └── SKILL.md +└── another-skill/ + ├── SKILL.md + └── helper.sh (optional auxiliary files) +``` + +`SKILL.md` must start with YAML frontmatter containing `name` and `description`. The markdown body is the agent instruction, loaded only when selected for a task. + +## Discovery & Loading + +This directory is the **lowest-priority** skill source, always scanned at startup: + +1. `OPENSPACE_HOST_SKILL_DIRS` env (highest priority) +2. `config_grounding.json → skills.skill_dirs` +3. **This directory** (lowest priority) + +On first discovery, each skill gets a `.skill_id` sidecar file (`{name}__imp_{uuid[:8]}`) for persistent tracking across restarts. Cloud-downloaded and evolution-captured skills may also land here when no other skill directory is configured. + +Loose files like this README are safely ignored — only subdirectories with `SKILL.md` are scanned. + +## Safety + +All skills pass `check_skill_safety` before loading. Skills with dangerous patterns (prompt injection, credential exfiltration, etc.) are **blocked automatically** and logged as warnings. diff --git a/openspace/tool_layer.py b/openspace/tool_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea419f32c8c6248484a46e9fb1dfb187df372f4 --- /dev/null +++ b/openspace/tool_layer.py @@ -0,0 +1,936 @@ +from __future__ import annotations + +import asyncio +import traceback +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + +from openspace.agents import GroundingAgent +from openspace.llm import LLMClient +from openspace.grounding.core.grounding_client import GroundingClient +from openspace.config import get_config, load_config +from openspace.config.loader import get_agent_config +from openspace.recording import RecordingManager +from openspace.skill_engine import SkillRegistry, ExecutionAnalyzer, SkillStore +from openspace.skill_engine.evolver import SkillEvolver +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +@dataclass +class OpenSpaceConfig: + # LLM Configuration + llm_model: str = "openrouter/anthropic/claude-sonnet-4.5" + llm_enable_thinking: bool = False + llm_timeout: float = 120.0 + llm_max_retries: int = 3 + llm_rate_limit_delay: float = 0.0 + llm_kwargs: Dict[str, Any] = field(default_factory=dict) + + # Separate models for specific tasks (None = use llm_model) + tool_retrieval_model: Optional[str] = None # Model for tool retrieval LLM filter + visual_analysis_model: Optional[str] = None # Model for visual analysis + + # Skill Engine Models — names map to class names (None = use llm_model) + skill_registry_model: Optional[str] = None # SkillRegistry: skill selection + execution_analyzer_model: Optional[str] = None # ExecutionAnalyzer: post-execution analysis + skill_evolver_model: Optional[str] = None # (future) SkillEvolver: skill evolution + + # Grounding Configuration + grounding_config_path: Optional[str] = None + grounding_max_iterations: int = 20 + grounding_system_prompt: Optional[str] = None + + # Backend Configuration + backend_scope: Optional[List[str]] = None # None = All backends ["shell", "gui", "mcp", "web", "system"] + use_clawwork_productivity: bool = False # If True, add ClawWork productivity tools (search_web, create_file, etc.) for fair comparison with ClawWork; requires livebench installed. + + # Workspace Configuration + workspace_dir: Optional[str] = None + + # Recording Configuration + enable_recording: bool = True + recording_backends: Optional[List[str]] = None + recording_log_dir: str = "./logs/recordings" + enable_screenshot: bool = False + enable_video: bool = False + enable_conversation_log: bool = True # Save LLM conversations to conversations.jsonl + + # Skill Evolution + evolution_max_concurrent: int = 3 # Max parallel evolutions per trigger + + # Logging Configuration + log_level: str = "INFO" + log_to_file: bool = False + log_file_path: Optional[str] = None + + def __post_init__(self): + """Validate configuration""" + if not self.llm_model: + raise ValueError("llm_model is required") + + logger.debug(f"OpenSpaceConfig initialized with model: {self.llm_model}") + + +class OpenSpace: + def __init__(self, config: Optional[OpenSpaceConfig] = None): + self.config = config or OpenSpaceConfig() + + self._llm_client: Optional[LLMClient] = None + self._grounding_client: Optional[GroundingClient] = None + self._grounding_config = None # GroundingConfig reference for skill settings + self._grounding_agent: Optional[GroundingAgent] = None + self._recording_manager: Optional[RecordingManager] = None + self._skill_registry: Optional[SkillRegistry] = None + self._skill_store: Optional[SkillStore] = None + self._execution_analyzer: Optional[ExecutionAnalyzer] = None + self._skill_evolver: Optional[SkillEvolver] = None + self._execution_count: int = 0 # For periodic metric-based evolution + self._last_evolved_skills: List[Dict[str, Any]] = [] # Tracks skills evolved during last execute() + + self._initialized = False + self._running = False + self._task_done = asyncio.Event() + self._task_done.set() # Initially not running, so "done" + + logger.debug("OpenSpace instance created") + + async def initialize(self) -> None: + if self._initialized: + logger.warning("OpenSpace already initialized") + return + + logger.info("Initializing OpenSpace...") + + try: + self._llm_client = LLMClient( + model=self.config.llm_model, + enable_thinking=self.config.llm_enable_thinking, + rate_limit_delay=self.config.llm_rate_limit_delay, + max_retries=self.config.llm_max_retries, + timeout=self.config.llm_timeout, + **self.config.llm_kwargs + ) + logger.info(f"✓ LLM Client: {self.config.llm_model}") + + # Load grounding config + # If custom config is provided, merge it with default configs + # load_config supports multiple files and deep merges them (later files override earlier ones) + if self.config.grounding_config_path: + from openspace.config.loader import CONFIG_DIR + from openspace.config.constants import CONFIG_GROUNDING, CONFIG_SECURITY + # Load default configs + custom config (custom values will override defaults) + grounding_config = load_config( + CONFIG_DIR / CONFIG_GROUNDING, + CONFIG_DIR / CONFIG_SECURITY, + self.config.grounding_config_path + ) + logger.info(f"Merged custom grounding config: {self.config.grounding_config_path}") + else: + # Load default configs only + grounding_config = get_config() + + # Optional: enable ClawWork productivity tools for fair benchmark comparison + if getattr(self.config, "use_clawwork_productivity", False): + shell_cfg = grounding_config.shell.model_copy( + update={ + "use_clawwork_productivity": True, + "working_dir": self.config.workspace_dir or grounding_config.shell.working_dir, + } + ) + grounding_config = grounding_config.model_copy(update={"shell": shell_cfg}) + logger.info("ClawWork productivity tools enabled (shell.working_dir used as sandbox root)") + + # Resolve backend_scope early so we can skip initializing + # providers that are not in scope (e.g. web when only shell is needed). + agent_config = get_agent_config("GroundingAgent") + _cli_max_iter = self.config.grounding_max_iterations + _default_max_iter = OpenSpaceConfig().grounding_max_iterations # dataclass default (20) + if agent_config: + cfg_max_iter = agent_config.get("max_iterations", _default_max_iter) + if _cli_max_iter != _default_max_iter: + max_iterations = _cli_max_iter + else: + max_iterations = cfg_max_iter + backend_scope = self.config.backend_scope or agent_config.get("backend_scope") or ["gui", "shell", "mcp", "web", "system"] + visual_analysis_timeout = agent_config.get("visual_analysis_timeout", 30.0) + self.config.grounding_max_iterations = max_iterations + logger.info(f"Loaded GroundingAgent config from config_agents.json (max_iterations={max_iterations}, visual_analysis_timeout={visual_analysis_timeout}s)") + else: + max_iterations = self.config.grounding_max_iterations + backend_scope = self.config.backend_scope or ["gui", "shell", "mcp", "web", "system"] + visual_analysis_timeout = 30.0 + logger.warning(f"config_agents.json not found, using default config (max_iterations={max_iterations})") + + # Filter enabled_backends in grounding config to only those in scope, + # so providers outside scope (e.g. web) are never registered/initialized. + if grounding_config.enabled_backends: + scope_set = set(backend_scope) + filtered = [ + entry for entry in grounding_config.enabled_backends + if entry.get("name", "").lower() in scope_set + ] + if len(filtered) != len(grounding_config.enabled_backends): + skipped = [ + entry.get("name") for entry in grounding_config.enabled_backends + if entry.get("name", "").lower() not in scope_set + ] + logger.info(f"Skipping backends not in scope: {skipped}") + grounding_config = grounding_config.model_copy( + update={"enabled_backends": filtered} + ) + + self._grounding_config = grounding_config + self._grounding_client = GroundingClient(config=grounding_config) + await self._grounding_client.initialize_all_providers() + + backends = list(self._grounding_client.list_providers().keys()) + logger.info(f"✓ Grounding Client: {len(backends)} backends") + logger.debug(f" Available backends: {[b.value for b in backends]}") + + if self.config.enable_recording: + self._recording_manager = RecordingManager( + enabled=True, + task_id="", + log_dir=self.config.recording_log_dir, + backends=self.config.recording_backends, + enable_screenshot=self.config.enable_screenshot, + enable_video=self.config.enable_video, + enable_conversation_log=self.config.enable_conversation_log, + agent_name="OpenSpace", + ) + # Inject recording_manager to grounding_client for GUI intermediate steps + self._grounding_client.recording_manager = self._recording_manager + self._recording_manager.register_to_llm(self._llm_client) + logger.info(f"✓ Recording enabled: {len(self._recording_manager.backends or [])} backends") + + # Create separate LLM client for tool retrieval if configured + # Inherits llm_kwargs (api_key, api_base, etc.) so credentials + # from the host agent are shared across all internal LLM clients. + tool_retrieval_llm = None + if self.config.tool_retrieval_model: + tool_retrieval_llm = LLMClient( + model=self.config.tool_retrieval_model, + timeout=self.config.llm_timeout, + max_retries=self.config.llm_max_retries, + **self.config.llm_kwargs, + ) + logger.info(f"✓ Tool retrieval LLM: {self.config.tool_retrieval_model}") + + self._grounding_agent = GroundingAgent( + name="OpenSpace-GroundingAgent", + backend_scope=backend_scope, + llm_client=self._llm_client, + grounding_client=self._grounding_client, + recording_manager=self._recording_manager, + system_prompt=self.config.grounding_system_prompt, + max_iterations=max_iterations, + visual_analysis_timeout=visual_analysis_timeout, + tool_retrieval_llm=tool_retrieval_llm, + visual_analysis_model=self.config.visual_analysis_model, + ) + logger.info(f"✓ GroundingAgent: {', '.join(backend_scope)}") + + # Initialize SkillRegistry (settings from config_grounding.json → skills) + if self._grounding_config and self._grounding_config.skills.enabled: + self._skill_registry = self._init_skill_registry() + if self._skill_registry: + skills = self._skill_registry.list_skills() + logger.info(f"✓ Skills: {len(skills)} discovered") + self._grounding_agent.set_skill_registry(self._skill_registry) + + # Initialize ExecutionAnalyzer (requires recording + skills) + if self.config.enable_recording and self._skill_registry: + try: + skill_store = SkillStore() + self._skill_store = skill_store # Expose for MCP server reuse + + # Sync filesystem skills → DB (creates initial records + # for newly discovered skills so that analysis stats + # can be recorded against them from the very first run). + await skill_store.sync_from_registry( + self._skill_registry.list_skills() + ) + + # Bridge: pass quality_manager so analysis can feed back + # LLM-identified tool issues to the tool quality system. + quality_mgr = ( + self._grounding_client.quality_manager + if self._grounding_client else None + ) + self._execution_analyzer = ExecutionAnalyzer( + store=skill_store, + llm_client=self._llm_client, + model=self.config.execution_analyzer_model, + skill_registry=self._skill_registry, + quality_manager=quality_mgr, + ) + logger.info("✓ Execution analysis enabled") + + # Share store with GroundingAgent so retrieve_skill + # can access quality metrics for LLM selection. + self._grounding_agent._skill_store = skill_store + + # Initialize SkillEvolver (reuses the same store & registry) + # available_tools will be updated before each evolution cycle + self._skill_evolver = SkillEvolver( + store=skill_store, + registry=self._skill_registry, + llm_client=self._llm_client, + model=self.config.skill_evolver_model, + max_concurrent=self.config.evolution_max_concurrent, + ) + logger.info( + f"✓ Skill evolution enabled " + f"(concurrent={self.config.evolution_max_concurrent})" + ) + except Exception as e: + logger.warning(f"Execution analyzer init failed (non-fatal): {e}") + + self._initialized = True + logger.info("="*60) + logger.info("OpenSpace ready to use!") + logger.info("="*60) + + except Exception as e: + logger.error(f"Failed to initialize OpenSpace: {e}") + await self.cleanup() + raise + + async def execute( + self, + task: str, + context: Optional[Dict[str, Any]] = None, + workspace_dir: Optional[str] = None, + max_iterations: Optional[int] = None, + task_id: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Execute a task with OpenSpace. + + Args: + task: Task instruction + context: Additional context + workspace_dir: Working directory + max_iterations: Max iterations override + task_id: External task ID for recording/logging. If None, generates a random one. + This allows external callers (e.g., OSWorld) to specify their own task ID + so recordings can be easily matched with benchmark results. + """ + if not self._initialized: + raise RuntimeError( + "OpenSpace not initialized. " + "Call await tool_layer.initialize() first or use async with." + ) + + _TASK_WAIT_TIMEOUT = 660 # slightly longer than MCP tool timeout (600s) + if self._running: + logger.info( + "OpenSpace is busy — waiting up to %ds for the current task to finish...", + _TASK_WAIT_TIMEOUT, + ) + try: + await asyncio.wait_for( + self._task_done.wait(), timeout=_TASK_WAIT_TIMEOUT + ) + except asyncio.TimeoutError: + raise RuntimeError( + f"OpenSpace is still running after waiting {_TASK_WAIT_TIMEOUT}s. " + "Please try again later." + ) + + logger.info("="*60) + logger.info(f"Task: {task[:100]}...") + logger.info("="*60) + + self._running = True + self._task_done.clear() + self._last_evolved_skills = [] # Reset per-execution tracking + start_time = asyncio.get_event_loop().time() + # Use external task_id if provided, otherwise generate one + if task_id is None: + task_id = f"task_{uuid.uuid4().hex[:12]}" + logger.info(f"Task ID: {task_id}") + + # Populated inside the try block; used by finally for analysis + result: Dict[str, Any] = {} + + try: + execution_context = context or {} + execution_context["task_id"] = task_id + execution_context["instruction"] = task + + if max_iterations is not None: + execution_context["max_iterations"] = max_iterations + + if self._recording_manager: + if self._recording_manager.recording_status: + await self._recording_manager.stop() + logger.debug("Stopped previous recording session") + + self._recording_manager.task_id = task_id + await self._recording_manager.start() + await self._recording_manager.add_metadata("instruction", task) + logger.info(f"Recording started: {task_id}") + + if workspace_dir: + execution_context["workspace_dir"] = workspace_dir + logger.info(f"Workspace: {workspace_dir}") + elif self.config.workspace_dir: + execution_context["workspace_dir"] = self.config.workspace_dir + logger.info(f"Workspace: {self.config.workspace_dir}") + elif self._recording_manager and self._recording_manager.trajectory_dir: + execution_context["workspace_dir"] = self._recording_manager.trajectory_dir + logger.info(f"Workspace: {execution_context['workspace_dir']}") + else: + import tempfile + from pathlib import Path + workspace = Path(tempfile.gettempdir()) / "openspace_workspace" / task_id + workspace.mkdir(parents=True, exist_ok=True) + execution_context["workspace_dir"] = str(workspace) + logger.info(f"Workspace: {execution_context['workspace_dir']}") + + # Update Shell session's default_working_dir so that + # productivity tools (create_file, create_video) write to the + # correct task workspace instead of the global CWD. + resolved_ws = execution_context["workspace_dir"] + try: + from openspace.grounding.core.types import BackendType as _BT + shell_prov = self._grounding_client._registry.get(_BT.SHELL) + for sess in shell_prov._sessions.values(): + sess.default_working_dir = resolved_ws + except Exception: + pass + + # Resolve iteration budget: use the larger of the caller's value + # and the configured value so external callers can't accidentally + # starve the agent with a too-low budget. + configured_max = self.config.grounding_max_iterations + if max_iterations: + max_iterations = max(max_iterations, configured_max) + else: + max_iterations = configured_max + + # Two-phase execution: Skill-First → Tool-Fallback + has_skills = False + + # Phase 1: Skill-guided execution + if self._skill_registry: + has_skills = await self._select_and_inject_skills(task) + + if has_skills: + logger.info( + f"[Phase 1 — Skill] Executing with skill guidance " + f"(max {max_iterations} iterations)..." + ) + execution_context_p1 = {**execution_context} + execution_context_p1["max_iterations"] = max_iterations + + # Snapshot workspace files before skill-guided execution + workspace_path = execution_context.get("workspace_dir", "") + pre_skill_files: set = set() + if workspace_path: + try: + from pathlib import Path as _P + pre_skill_files = { + f.name for f in _P(workspace_path).iterdir() + } if _P(workspace_path).exists() else set() + except Exception: + pass + + # Capture skill IDs before they get cleared + injected_skill_ids = list(self._grounding_agent._active_skill_ids) + skill_phase_result = await self._grounding_agent.process(execution_context_p1) + skill_status = skill_phase_result.get("status", "unknown") + skill_iterations = skill_phase_result.get("iterations", 0) + + # Clear skill context regardless of outcome + self._grounding_agent.clear_skill_context() + + if skill_status == "success": + result = skill_phase_result + result["active_skills"] = injected_skill_ids + logger.info( + f"[Phase 1 — Skill] Completed successfully " + f"({skill_iterations} iterations)" + ) + else: + # Skill failed — fall back to pure tool execution. + # Fallback gets the full budget because we clean the + # workspace below — it starts completely from scratch + # with no skill context and no leftover artifacts. + logger.warning( + f"[Phase 1 — Skill] {skill_status} after {skill_iterations} iterations, " + f"falling back to tool-only execution " + f"(budget: {max_iterations})" + ) + + # Clean up workspace artifacts created by the failed + # skill-guided phase so the fallback starts fresh. + if workspace_path: + try: + import shutil + from pathlib import Path as _P + ws = _P(workspace_path) + removed = 0 + if ws.exists(): + for f in list(ws.iterdir()): + if f.name not in pre_skill_files: + if f.is_dir(): + shutil.rmtree(f, ignore_errors=True) + else: + f.unlink(missing_ok=True) + removed += 1 + if removed: + logger.info( + f"[Phase 2 — Fallback] Cleaned {removed} artifact(s) " + f"from failed skill-guided phase" + ) + except Exception as e: + logger.debug(f"Workspace cleanup failed: {e}") + + execution_context_p2 = {**execution_context} + execution_context_p2["max_iterations"] = max_iterations + + result = await self._grounding_agent.process(execution_context_p2) + result["active_skills"] = injected_skill_ids + logger.info( + f"[Phase 2 — Fallback] {result.get('status', 'unknown')} " + f"({result.get('iterations', 0)} iterations)" + ) + else: + # No skills matched — standard tool-only execution + logger.info( + f"Executing with GroundingAgent " + f"(max {max_iterations} iterations, no skills)..." + ) + result = await self._grounding_agent.process(execution_context) + + execution_time = asyncio.get_event_loop().time() - start_time + + status = result.get('status', 'unknown') + iterations = result.get('iterations', 0) + tool_count = len(result.get('tool_executions', [])) + + logger.info("="*60) + if status == "success": + logger.info( + f"Task completed successfully! " + f"({iterations} iterations, {tool_count} tool calls, {execution_time:.2f}s)" + ) + elif status == "incomplete": + logger.warning( + f"Task incomplete after {iterations} iterations. " + f"Consider increasing max_iterations." + ) + else: + logger.error(f"Task failed: {result.get('error', 'Unknown error')}") + logger.info("="*60) + + except Exception as e: + execution_time = asyncio.get_event_loop().time() - start_time + tb = traceback.format_exc(limit=10) + logger.error(f"Task execution failed: {e}", exc_info=True) + + result = { + "status": "error", + "error": str(e), + "traceback": tb, + "response": f"Task execution error: {str(e)}", + "execution_time": execution_time, + "task_id": task_id, + "iterations": 0, + "tool_executions": [], + } + + finally: + recording_dir = None + if self._recording_manager and self._recording_manager.recording_status: + recording_dir = self._recording_manager.trajectory_dir + + # Persist execution outcome to metadata.json before finalizing + try: + exec_time = asyncio.get_event_loop().time() - start_time + await self._recording_manager.save_execution_outcome( + status=result.get("status", "unknown"), + iterations=result.get("iterations", 0), + execution_time=exec_time, + ) + except Exception: + pass # best-effort; don't block recording stop + + try: + await self._recording_manager.stop() + logger.debug(f"Recording stopped: {task_id}") + except Exception as e: + logger.warning(f"Failed to stop recording: {e}") + + # Run execution analysis + evolution BEFORE building the return + # value, so evolved_skills is populated. + await self._maybe_analyze_execution( + task_id, recording_dir, result + ) + + # Trigger quality evolution periodically + await self._maybe_evolve_quality() + + final_result = { + **result, + "task_id": task_id, + "execution_time": execution_time, + "skills_used": result.get("active_skills", []), + "evolved_skills": list(self._last_evolved_skills), + } + + self._running = False + self._task_done.set() + + return final_result + + # Skills helpers + def _init_skill_registry(self) -> Optional[SkillRegistry]: + """Build and populate the SkillRegistry from configured directories. + + Discovery order (earlier wins on name collision): + 1. ``OPENSPACE_HOST_SKILL_DIRS`` env — host agent skill directories + 2. ``config_grounding.json → skills.skill_dirs`` — user-specified + 3. ``openspace/skills/`` — built-in skills (always present) + + ``OPENSPACE_HOST_SKILL_DIRS`` is also handled by ``mcp_server.py`` + for the MCP transport path, but we process it here too so that + standalone mode (``python -m openspace``) gets the same skills + discovered and synced to the DB for quality tracking / evolution. + """ + skill_paths: List[Path] = [] + skill_cfg = self._grounding_config.skills if self._grounding_config else None + + # 1. Host agent skill directories from env (standalone mode support) + import os + host_dirs_raw = os.environ.get("OPENSPACE_HOST_SKILL_DIRS", "") + if host_dirs_raw: + for d in host_dirs_raw.split(","): + d = d.strip() + if not d: + continue + p = Path(d) + if p.exists(): + skill_paths.append(p) + logger.info(f"Host skill dir (from env): {p}") + else: + logger.warning(f"Host skill dir does not exist: {d}") + + # 2. User-specified skill directories from config_grounding.json + if skill_cfg and skill_cfg.skill_dirs: + for d in skill_cfg.skill_dirs: + p = Path(d) + if p in skill_paths: + continue # Already added via OPENSPACE_HOST_SKILL_DIRS + if p.exists(): + skill_paths.append(p) + else: + logger.warning(f"Configured skill dir does not exist: {d}") + + # 3. Built-in skills (openspace/skills/) + builtin_skills = Path(__file__).resolve().parent / "skills" + if builtin_skills.exists(): + skill_paths.append(builtin_skills) + + if not skill_paths: + logger.debug("No skill directories found, skills disabled") + return None + + registry = SkillRegistry(skill_dirs=skill_paths) + registry.discover() + return registry + + async def _select_and_inject_skills( + self, + task: str, + ) -> bool: + """Select skills for task via LLM, inject into GroundingAgent. + + When the registry has many skills, a BM25 + embedding pre-filter + narrows the candidate set before LLM selection (see + ``SkillRegistry.select_skills_with_llm``). + + Only selected skills are injected (full SKILL.md content). + Returns True if at least one active skill was injected. + """ + if not self._skill_registry or not self._grounding_agent: + return False + + selection_record = None + + # LLM-based skill selection (preferred) + skill_cfg = self._grounding_config.skills if self._grounding_config else None + max_select = skill_cfg.max_select if skill_cfg else 2 + skill_llm = self._get_skill_selection_llm() + + # Fetch quality metrics so the selector can filter/annotate + skill_quality: Optional[Dict[str, Dict[str, Any]]] = None + if self._skill_store: + try: + rows = self._skill_store.get_summary(active_only=True) + skill_quality = { + r["skill_id"]: { + "total_selections": r.get("total_selections", 0), + "total_applied": r.get("total_applied", 0), + "total_completions": r.get("total_completions", 0), + "total_fallbacks": r.get("total_fallbacks", 0), + } + for r in rows + } + except Exception as e: + logger.debug(f"Could not load skill quality metrics: {e}") + + if skill_llm: + selected, selection_record = await self._skill_registry.select_skills_with_llm( + task, + llm_client=skill_llm, + max_skills=max_select, + skill_quality=skill_quality, + ) + else: + # No LLM client — skip skill selection entirely + logger.info("No LLM client available for skill selection — proceeding without skills") + selected = [] + selection_record = { + "method": "no_llm", + "task": task[:500], + "available_skills": [s.skill_id for s in self._skill_registry.list_skills()], + "selected": [], + } + + # Record skill selection to metadata.json + if self._recording_manager and selection_record: + # Add model info to the record + selection_record["model"] = skill_llm.model if skill_llm else "keyword_only" + await RecordingManager.record_skill_selection(selection_record) + + if not selected: + self._grounding_agent.clear_skill_context() + return False + + # Inject active skills (full SKILL.md content, backend-aware) + agent_backends = self._grounding_agent.backend_scope if self._grounding_agent else None + context_text = self._skill_registry.build_context_injection(selected, backends=agent_backends) + skill_ids = [s.skill_id for s in selected] + self._grounding_agent.set_skill_context(context_text, skill_ids) + logger.info(f"Injected {len(selected)} active skill(s): {skill_ids}") + + return True + + def _get_skill_selection_llm(self) -> Optional[LLMClient]: + """Get the LLM client to use for skill selection. + + Priority: config.skill_registry_model > tool_retrieval_model > llm_model. + """ + # 1. Dedicated skill selection model (OpenSpaceConfig.skill_registry_model) + if self.config.skill_registry_model: + return LLMClient( + model=self.config.skill_registry_model, + timeout=30.0, # skill selection should be fast + max_retries=2, + **self.config.llm_kwargs, + ) + + # 2. Tool retrieval model + if hasattr(self._grounding_agent, '_tool_retrieval_llm') and self._grounding_agent._tool_retrieval_llm: + return self._grounding_agent._tool_retrieval_llm + + # 3. Main LLM client + return self._llm_client + + async def _maybe_analyze_execution( + self, + task_id: str, + recording_dir: Optional[str], + execution_result: Dict[str, Any], + ) -> None: + """Run post-execution analysis if enabled. + + Trigger 1: if the analysis produces evolution suggestions, the + SkillEvolver processes them immediately (FIX / DERIVED / CAPTURED). + Evolved skills are recorded in ``_last_evolved_skills`` so the + caller (MCP ``execute_task``) can include them in the response. + """ + if not self._execution_analyzer or not recording_dir: + return + try: + # Pass the agent's tools so the analyzer can reuse them + # for error reproduction / verification when needed. + agent_tools = getattr( + self._grounding_agent, "_last_tools", [] + ) if self._grounding_agent else [] + + analysis = await self._execution_analyzer.analyze_execution( + task_id=task_id, + recording_dir=recording_dir, + execution_result=execution_result, + available_tools=agent_tools, + ) + if not analysis: + return + + # Trigger 1: post-analysis evolution + if analysis.candidate_for_evolution and self._skill_evolver: + self._skill_evolver.set_available_tools(agent_tools) + + evo_summary = ", ".join( + f"{s.evolution_type.value}({'+'.join(s.target_skill_ids) or 'new'})" + for s in analysis.evolution_suggestions + ) + logger.info(f"[Skill Evolution] Suggestions: {evo_summary}") + evolved_records = await self._skill_evolver.process_analysis(analysis) + + # Track evolved skills for the caller + for rec in evolved_records: + self._last_evolved_skills.append({ + "skill_id": rec.skill_id, + "name": rec.name, + "description": rec.description, + "path": str(rec.path) if rec.path else "", + "origin": rec.lineage.origin.value, + "generation": rec.lineage.generation, + "parent_skill_ids": rec.lineage.parent_skill_ids, + "change_summary": rec.lineage.change_summary, + }) + + except Exception as e: + # Analysis failure must never break the main execution flow + logger.debug(f"Execution analysis skipped: {e}") + + async def _maybe_evolve_quality(self) -> None: + """Trigger quality evolution based on global execution count. + + Includes three sub-triggers: + - Tool quality evolution (ToolQualityManager) + - Trigger 2: tool degradation → fix related skills + - Trigger 3: periodic skill metric check + + Triggers 2 and 3 are always launched as background tasks so they + never block the main execute() flow. They are awaited on shutdown + via ``cleanup() → evolver.wait_background()``. + """ + self._execution_count += 1 + quality_mgr = ( + self._grounding_client.quality_manager + if self._grounding_client else None + ) + + # Ensure evolver has up-to-date tools for agent loop + if self._skill_evolver and self._grounding_agent: + agent_tools = getattr(self._grounding_agent, "_last_tools", []) + if agent_tools: + self._skill_evolver.set_available_tools(agent_tools) + + # Tool quality evolution + if quality_mgr and quality_mgr.should_evolve(): + try: + report = await self._grounding_client.evolve_quality() + if report.get("recommendations"): + logger.info(f"Quality evolution: {report['recommendations']}") + + # Trigger 2: tool degradation → fix skills that depend on bad tools + if self._skill_evolver: + problematic = quality_mgr.get_problematic_tools() + if problematic: + logger.info( + f"[Trigger:tool_degradation] {len(problematic)} " + f"problematic tool(s) detected" + ) + self._skill_evolver.schedule_background( + self._skill_evolver.process_tool_degradation(problematic), + label="trigger2_tool_degradation", + ) + + except Exception as e: + logger.debug(f"Quality evolution skipped: {e}") + + # Trigger 3: periodic skill metric check (every 5 executions) + if self._skill_evolver and self._execution_count % 5 == 0: + try: + self._skill_evolver.schedule_background( + self._skill_evolver.process_metric_check(), + label="trigger3_metric_check", + ) + except Exception as e: + logger.debug(f"Skill metric check skipped: {e}") + + async def cleanup(self) -> None: + """ + Close all sessions and release resources. + Automatically called when using context manager. + """ + logger.info("Cleaning up OpenSpace resources...") + + try: + # Wait for background evolution tasks before tearing down + if self._skill_evolver: + await self._skill_evolver.wait_background() + + if self._grounding_client: + await self._grounding_client.close_all_sessions() + logger.debug("All grounding sessions closed") + + if self._recording_manager and self._recording_manager.recording_status: + try: + await self._recording_manager.stop() + logger.debug("Recording manager stopped") + except Exception as e: + logger.warning(f"Failed to stop recording: {e}") + + if self._execution_analyzer: + try: + self._execution_analyzer.close() + logger.debug("Execution analyzer closed") + except Exception as e: + logger.debug(f"Failed to close execution analyzer: {e}") + + self._initialized = False + self._running = False + self._task_done.set() + + logger.info("OpenSpace cleanup complete") + + except Exception as e: + logger.error(f"Error during cleanup: {e}", exc_info=True) + + def is_initialized(self) -> bool: + return self._initialized + + def is_running(self) -> bool: + return self._running + + def get_config(self) -> OpenSpaceConfig: + return self.config + + def list_backends(self) -> List[str]: + if not self._initialized: + raise RuntimeError("OpenSpace not initialized") + return [backend.value for backend in self._grounding_client.list_providers().keys()] + + def list_sessions(self) -> List[str]: + if not self._initialized: + raise RuntimeError("OpenSpace not initialized") + return self._grounding_client.list_sessions() + + async def __aenter__(self): + """Context manager entry""" + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Context manager exit""" + await self.cleanup() + return False + + def __repr__(self) -> str: + status = "initialized" if self._initialized else "not initialized" + if self._running: + status = "running" + backends = ", ".join(self.config.backend_scope) if self.config.backend_scope else "all" + return f"" \ No newline at end of file diff --git a/openspace/utils/cli_display.py b/openspace/utils/cli_display.py new file mode 100644 index 0000000000000000000000000000000000000000..55720121f9cacacb7342ab00a83c3c0158a049ca --- /dev/null +++ b/openspace/utils/cli_display.py @@ -0,0 +1,220 @@ +"""CLI Display utilities for OpenSpace startup and interaction""" + +from openspace.tool_layer import OpenSpaceConfig +from openspace.utils.display import Box, BoxStyle, colorize + + +class CLIDisplay: + @staticmethod + def print_banner(): + box = Box(width=70, style=BoxStyle.ROUNDED, color='c') + + print() + print(box.top_line(indent=4)) + print(box.empty_line(indent=4)) + + title = colorize("OpenSpace", 'c', bold=True) + print(box.text_line(title, align='center', indent=4, text_color='')) + + subtitle = "Self-Evolving Skill Worker & Community" + print(box.text_line(subtitle, align='center', indent=4, text_color='gr')) + + print(box.empty_line(indent=4)) + print(box.bottom_line(indent=4)) + print() + + @staticmethod + def print_configuration(config: OpenSpaceConfig): + box = Box(width=70, style=BoxStyle.ROUNDED, color='bl') + + print(box.text_line(colorize("◉ System Configuration", 'c', bold=True), align='center', indent=4, text_color='')) + print(box.separator_line(indent=4)) + + configs = [ + ("AI Model", config.llm_model, 'bl'), + ("Max Iterations", str(config.grounding_max_iterations), 'c'), + ("LLM Timeout", f"{config.llm_timeout}s", 'c'), + ] + + for label, value, color in configs: + line = f" {label:20s} {colorize(value, color)}" + print(box.text_line(line, indent=4, text_color='')) + + print(box.bottom_line(indent=4)) + print() + + @staticmethod + def print_initialization_progress(steps: list, show_header: bool = True): + box = Box(width=70, style=BoxStyle.ROUNDED, color='g') + + if show_header: + print(box.text_line(colorize("► Initializing Components", 'g', bold=True), + align='center', indent=4, text_color='')) + print(box.separator_line(indent=4)) + + for step, status in steps: + if status == "ok": + icon = colorize("✓", 'g') + elif status == "error": + icon = colorize("✗", 'rd') + else: + icon = colorize("[...]", 'y') + + line = f" {icon} {step}" + print(box.text_line(line, indent=4, text_color='')) + + print(box.bottom_line(indent=4)) + print() + + @staticmethod + def print_result_summary(result: dict): + box = Box(width=70, style=BoxStyle.ROUNDED, color='c') + + print() + print(box.text_line(colorize("◈ Execution Summary", 'c', bold=True), + align='center', indent=4, text_color='')) + print(box.separator_line(indent=4)) + + status = result.get("status", "unknown") + status_colors = { + "completed": 'g', + "timeout": 'y', + "error": 'rd', + "max_iterations_reached": 'y', + } + status_color = status_colors.get(status, 'gr') + status_display = colorize(status.upper(), status_color, bold=True) + + exec_time = result.get('execution_time', 0) + result_lines = [ + f" Status: {status_display}", + f" Execution Time: {colorize(f'{exec_time:.2f}s', 'c')}", + f" Iterations: {colorize(str(result.get('iterations', 0)), 'y')}", + f" Completed Tasks: {colorize(str(result.get('completed_tasks', 0)), 'g')}", + ] + + if result.get('evaluation_results'): + result_lines.append(f" Evaluations: {colorize(str(len(result['evaluation_results'])), 'bl')}") + + for line in result_lines: + print(box.text_line(line, indent=4, text_color='')) + + print(box.bottom_line(indent=4)) + print() + + # Print user response (the actual answer/result) + if result.get('user_response'): + response_box = Box(width=70, style=BoxStyle.ROUNDED, color='g') + print(response_box.text_line(colorize("◈ Result", 'g', bold=True), + align='center', indent=4, text_color='')) + print(response_box.separator_line(indent=4)) + + user_response = result['user_response'] + for line in user_response.split('\n'): + if line.strip(): + display_line = f" {line.strip()}" + print(response_box.text_line(display_line, indent=4, text_color='')) + + print(response_box.bottom_line(indent=4)) + print() + + @staticmethod + def print_interactive_header(): + box = Box(width=70, style=BoxStyle.ROUNDED, color='c') + + print(box.text_line(colorize("⌨ Interactive Mode", 'c', bold=True), + align='center', indent=4, text_color='')) + print(box.separator_line(indent=4)) + + help_lines = [ + "", + colorize(" Ready to execute your tasks!", 'g'), + "", + colorize(" Available Commands:", 'c', bold=True), + " " + colorize("status", 'bl') + " → View system status", + " " + colorize("help", 'bl') + " → Show available commands", + " " + colorize("quit", 'bl') + " → Exit interactive mode", + "", + colorize(" ▸ Enter your task description below:", 'gr'), + "", + ] + + for line in help_lines: + print(box.text_line(line, indent=4, text_color='')) + + print(box.bottom_line(indent=4)) + print() + + @staticmethod + def print_task_header(query: str, title: str = "▶ Executing Task"): + box = Box(width=70, style=BoxStyle.ROUNDED, color='g') + print() + print(box.text_line(colorize(title, 'g', bold=True), align='center', indent=4, text_color='')) + print(box.separator_line(indent=4)) + print(box.text_line("", indent=4, text_color='')) + print(box.text_line(f" {query}", indent=4, text_color='')) + print(box.text_line("", indent=4, text_color='')) + print(box.bottom_line(indent=4)) + + @staticmethod + def print_system_ready(): + box = Box(width=70, style=BoxStyle.ROUNDED, color='g') + print(box.text_line(colorize("◈ System Ready", 'g', bold=True), + align='center', indent=4, text_color='')) + print(box.separator_line(indent=4)) + print(box.text_line("", indent=4, text_color='')) + print(box.text_line(colorize(" Real-time UI will display:", 'c'), indent=4, text_color='')) + print(box.text_line(" § Agent activities and status", indent=4, text_color='')) + print(box.text_line(" ⊕ Grounding backend operations", indent=4, text_color='')) + print(box.text_line(" ⊞ Execution logs", indent=4, text_color='')) + print(box.text_line("", indent=4, text_color='')) + print(box.bottom_line(indent=4)) + print() + + @staticmethod + def print_status(agent): + box = Box(width=70, style=BoxStyle.ROUNDED, color='bl') + print() + print(box.text_line(colorize("System Status", 'bl', bold=True), + align='center', indent=4, text_color='')) + print(box.separator_line(indent=4)) + + status = agent.get_status() + status_lines = [ + f"Initialized: {colorize('Yes' if status['initialized'] else 'No', 'g' if status['initialized'] else 'rd')}", + f"Running: {colorize('Yes' if status['running'] else 'No', 'y' if status['running'] else 'g')}", + ] + + if "agents" in status: + status_lines.append(f"Agents: {colorize(', '.join(status['agents']), 'c')}") + + for line in status_lines: + print(box.text_line(line, indent=4, text_color='')) + + print(box.bottom_line(indent=4)) + print() + + @staticmethod + def print_help(): + box = Box(width=70, style=BoxStyle.ROUNDED, color='y') + print() + print(box.text_line(colorize("Available Commands", 'y', bold=True), + align='center', indent=4, text_color='')) + print(box.separator_line(indent=4)) + + help_items = [ + (colorize("status", 'c'), "Show system status"), + (colorize("help", 'c'), "Show this help message"), + (colorize("quit/exit", 'c'), "Exit interactive mode"), + ("", ""), + (colorize("Other input", 'gr'), "Execute as task"), + ] + + for cmd, desc in help_items: + if cmd: + print(box.text_line(f" {cmd:20s} {desc}", indent=4, text_color='')) + else: + print(box.separator_line(indent=4)) + + print(box.bottom_line(indent=4)) + print() \ No newline at end of file diff --git a/openspace/utils/display.py b/openspace/utils/display.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd0742ac0ed887d102cb5557b4dfdd1d88f2418 --- /dev/null +++ b/openspace/utils/display.py @@ -0,0 +1,235 @@ +from typing import Optional, List +from enum import Enum +import re + + +class Colors: + RESET = "\033[0m" + BOLD = "\033[1m" + DIM = "\033[2m" + + RED = "\033[91m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + MAGENTA = "\033[95m" + CYAN = "\033[96m" + WHITE = "\033[97m" + GRAY = "\033[90m" + + GREEN_SOFT = '\033[38;5;78m' + BLUE_SOFT = '\033[38;5;39m' + CYAN_SOFT = '\033[38;5;51m' + YELLOW_SOFT = '\033[38;5;222m' + RED_SOFT = '\033[38;5;204m' + MAGENTA_SOFT = '\033[38;5;141m' + GRAY_SOFT = '\033[38;5;246m' + + +class BoxStyle(Enum): + ROUNDED = "rounded" # Rounded corner box ╭─╮╰╯ + SQUARE = "square" # Square corner box ┌─┐└┘ + DOUBLE = "double" # Double line box ╔═╗╚╝ + SIMPLE = "simple" # Simple box === + + +BOX_CHARS = { + BoxStyle.ROUNDED: { + 'tl': '╭', 'tr': '╮', 'bl': '╰', 'br': '╯', + 'h': '─', 'v': '│' + }, + BoxStyle.SQUARE: { + 'tl': '┌', 'tr': '┐', 'bl': '└', 'br': '┘', + 'h': '─', 'v': '│' + }, + BoxStyle.DOUBLE: { + 'tl': '╔', 'tr': '╗', 'bl': '╚', 'br': '╝', + 'h': '═', 'v': '║' + }, +} + + +def strip_ansi(text: str) -> str: + """ + Strip ANSI color codes from text + + Args: + text: Text with potential ANSI codes + + Returns: + Clean text without ANSI codes + """ + ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + return ansi_escape.sub('', text) + + +def colorize(text: str, color: str = '', bold: bool = False) -> str: + try: + color_map = { + 'r': Colors.RESET, + 'b': Colors.BOLD, + 'd': Colors.DIM, + 'g': Colors.GREEN_SOFT, + 'bl': Colors.BLUE_SOFT, + 'c': Colors.CYAN_SOFT, + 'y': Colors.YELLOW_SOFT, + 'rd': Colors.RED_SOFT, + 'm': Colors.MAGENTA_SOFT, + 'gr': Colors.GRAY_SOFT, + } + + prefix = Colors.BOLD if bold else '' + code = color_map.get(color, color) + return f"{prefix}{code}{text}{Colors.RESET}" + except: + return text + + +class Box: + def __init__(self, + width: int = 68, + style: BoxStyle = BoxStyle.ROUNDED, + color: str = 'bl', + padding: int = 2): + + self.width = width + self.style = style + self.color = color + self.padding = padding + self.chars = BOX_CHARS.get(style, BOX_CHARS[BoxStyle.ROUNDED]) + + def top_line(self, indent: int = 2) -> str: + indent_str = " " * indent + if self.style == BoxStyle.SIMPLE: + return colorize(indent_str + "=" * self.width, self.color) + return colorize( + indent_str + self.chars['tl'] + self.chars['h'] * self.width + self.chars['tr'], + self.color + ) + + def bottom_line(self, indent: int = 2) -> str: + indent_str = " " * indent + if self.style == BoxStyle.SIMPLE: + return colorize(indent_str + "=" * self.width, self.color) + return colorize( + indent_str + self.chars['bl'] + self.chars['h'] * self.width + self.chars['br'], + self.color + ) + + def separator_line(self, indent: int = 2) -> str: + indent_str = " " * indent + if self.style == BoxStyle.SIMPLE: + return colorize(indent_str + "-" * self.width, self.color) + return colorize(indent_str + " " + self.chars['h'] * self.width, self.color) + + def empty_line(self, indent: int = 2) -> str: + indent_str = " " * indent + if self.style == BoxStyle.SIMPLE: + return "" + return colorize( + indent_str + self.chars['v'] + " " * self.width + self.chars['v'], + self.color + ) + + def text_line(self, text: str, align: str = 'left', indent: int = 2, text_color: str = '') -> str: + indent_str = " " * indent + content_width = self.width - 2 * self.padding + + # Strip ANSI codes to get actual display length + clean_text = strip_ansi(text) + text_len = len(clean_text) + + # Use original text (may contain colors) or apply new color + display_text = colorize(text, text_color) if text_color else text + + if align == 'center': + left_pad = (content_width - text_len) // 2 + right_pad = content_width - text_len - left_pad + content = " " * left_pad + display_text + " " * right_pad + elif align == 'right': + left_pad = content_width - text_len + content = " " * left_pad + display_text + else: # left + right_pad = content_width - text_len + content = display_text + " " * right_pad + + if self.style == BoxStyle.SIMPLE: + return indent_str + " " * self.padding + content + + padding_str = " " * self.padding + return colorize(indent_str + self.chars['v'], self.color) + \ + padding_str + content + padding_str + \ + colorize(self.chars['v'], self.color) + + def build(self, + title: Optional[str] = None, + lines: List[str] = None, + footer: Optional[str] = None, + indent: int = 2) -> str: + + result = [] + + result.append(self.top_line(indent)) + + if title: + result.append(self.empty_line(indent)) + result.append(self.text_line(title, align='center', indent=indent, text_color='c')) + result.append(self.empty_line(indent)) + + if lines: + for line in lines: + result.append(self.text_line(line, indent=indent)) + + if footer: + result.append(self.empty_line(indent)) + result.append(self.text_line(footer, align='center', indent=indent, text_color='gr')) + + result.append(self.bottom_line(indent)) + + return "\n".join(result) + + +def print_box(title: Optional[str] = None, + lines: List[str] = None, + footer: Optional[str] = None, + width: int = 68, + style: BoxStyle = BoxStyle.ROUNDED, + color: str = 'bl', + indent: int = 2): + + box = Box(width=width, style=style, color=color) + print(box.build(title=title, lines=lines, footer=footer, indent=indent)) + + +def print_banner(title: str, + subtitle: Optional[str] = None, + width: int = 66, + style: BoxStyle = BoxStyle.ROUNDED, + color: str = 'bl', + indent: int = 2): + + box = Box(width=width, style=style, color=color) + print() + print(box.top_line(indent)) + print(box.empty_line(indent)) + print(box.text_line(title, align='center', indent=indent, text_color='c')) + if subtitle: + print(box.text_line(subtitle, align='center', indent=indent, text_color='gr')) + print(box.empty_line(indent)) + print(box.bottom_line(indent)) + print() + + +def print_section(title: str, + content: List[str], + color: str = 'c', + indent: int = 2): + indent_str = " " * indent + print(f"\n{indent_str}{colorize('- ' + title, color, bold=True)}") + for line in content: + print(f"{indent_str} {line}") + + +def print_separator(width: int = 68, color: str = 'bl', indent: int = 2): + indent_str = " " * indent + print(colorize(indent_str + "─" * width, color)) \ No newline at end of file diff --git a/openspace/utils/logging.py b/openspace/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..fabb5b1191042252e95bf63ef8f1f472bc416bd0 --- /dev/null +++ b/openspace/utils/logging.py @@ -0,0 +1,312 @@ +import logging +import os +import sys +import threading +import json +from pathlib import Path +from datetime import datetime +from typing import Optional +from colorama import init + +init(autoreset=True) + + +def _load_log_level_from_config() -> int: + """ + Load log_level from config_grounding.json and convert to OPENSPACE_DEBUG value. + Returns: 0 (WARNING), 1 (INFO), or 2 (DEBUG) + """ + try: + config_path = Path(__file__).parent.parent / "config" / "config_grounding.json" + if config_path.exists(): + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + log_level = config.get("log_level", "INFO").upper() + + # Convert log level string to OPENSPACE_DEBUG value + level_map = { + "DEBUG": 2, + "INFO": 1, + "WARNING": 0, + "ERROR": 0, + "CRITICAL": 0 + } + return level_map.get(log_level, 1) # Default to INFO + except Exception: + # If any error occurs, silently return default INFO level + pass + return 1 # Default to INFO + + +# 0=WARNING, 1=INFO, 2=DEBUG; can be overridden by set_debug / environment variable +# Load from config_grounding.json to ensure consistency +OPENSPACE_DEBUG = _load_log_level_from_config() + +# Default log directory and file pattern +# Use absolute path to openspace/logs directory +DEFAULT_LOG_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "logs") +DEFAULT_LOG_FILE_PATTERN = "openspace_{timestamp}.log" + + +class FlushFileHandler(logging.FileHandler): + """File handler that flushes after each emit for real-time logging""" + + def emit(self, record): + super().emit(record) + self.flush() # Immediately flush to disk + + +class ColoredFormatter(logging.Formatter): + COLORS = { + 'DEBUG': '\033[1;36m', # Bold cyan + 'INFO': '\033[1;32m', # Bold green + 'WARNING': '\033[1;33m', # Bold yellow + 'ERROR': '\033[1;31m', # Bold red + 'CRITICAL': '\033[1;35m', # Bold magenta + 'RESET': '\033[0m', + } + + def format(self, record: logging.LogRecord) -> str: + formatted = super().format(record) + + level_color = self.COLORS.get(record.levelname, self.COLORS["RESET"]) + colored_line = f"{level_color}{formatted}{self.COLORS['RESET']}" + + return colored_line + + +class Logger: + """ + Thread-safe logger facade that: + 1. Configures handlers only once (lazy initialization). + 2. Ensures all subsequent loggers obtained via ``Logger.get_logger()`` + inherit the configured handlers. + 3. Dynamically adapts log levels according to ``OPENSPACE_DEBUG``. + """ + + _ROOT_NAME = "openspace" # Package root name + # Standard format: time with milliseconds | level | file:line number | message + _LOG_FORMAT = ( + "%(asctime)s.%(msecs)03d [%(levelname)-8s] %(filename)s:%(lineno)d - %(message)s" + ) + + _lock = threading.Lock() + _configured = False + _registered: dict[str, logging.Logger] = {} + + @staticmethod + def _get_default_log_file() -> str: + """Generate default log file path with timestamp (to seconds) + + Log files are organized by the running script name: + - logs//openspace_2025-10-24_15-30-00.log + """ + # Get the name of the main script + script_name = "openspace" # Default name + try: + import __main__ + if hasattr(__main__, "__file__") and __main__.__file__: + # Extract script name without extension + script_path = os.path.basename(__main__.__file__) + script_name = os.path.splitext(script_path)[0] + except Exception: + # If can't get script name, use default + pass + + # Create log directory: logs// + log_dir = os.path.join(DEFAULT_LOG_DIR, script_name) + + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + filename = DEFAULT_LOG_FILE_PATTERN.format(timestamp=timestamp) + return os.path.abspath(os.path.join(log_dir, filename)) + + @classmethod + def get_logger(cls, name: Optional[str] = None) -> logging.Logger: + """Return a logger with *name* (defaults to ``openspace``). + The first call triggers :meth:`configure` automatically.""" + if name is None: + name = cls._ROOT_NAME + + # Check if configuration is needed to avoid recursive calls. + need_config = False + with cls._lock: + logger = cls._registered.get(name) + if logger is None: + logger = logging.getLogger(name) + logger.propagate = True + cls._registered[name] = logger + if not cls._configured: + need_config = True + + if need_config: + cls.configure() + return logger + + @classmethod + def configure( + cls, + *, + level: Optional[int] = None, + fmt: Optional[str] = None, + log_to_console: bool = True, + log_to_file: Optional[str] = "auto", + use_colors: bool = True, + force_color: bool = False, + force: bool = False, + attach_to_root: bool = False, + ) -> None: + """ + Configure the logging system. Usually called automatically + on first use; pass ``force=True`` to reconfigure explicitly. + + Args: + level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + fmt: Log format string + log_to_console: Whether to output to console + log_to_file: Log file path ("auto" auto-generate by date, None disable, or specify path) + use_colors: Whether to use colors on console + force_color: Force use of colors (even if not supported) + force: Whether to force reconfiguration + attach_to_root: Whether to attach to root logger + + If *attach_to_root* is ``True``, handlers are attached to the *root* + logger (``""``). This makes every logger—regardless of its name— + inherit the handlers (handy for standalone scripts) but will also + surface logs from third-party libraries. Choose with care. + """ + with cls._lock: + if cls._configured and not force: + # Already configured and no need to force reconfiguration, only update level. + if level is not None: + cls._update_level(level) + return + + resolved_level = cls._resolve_level(level) + fmt_str = fmt or cls._LOG_FORMAT + + # Handle log_to_file parameter + actual_log_file = None + if log_to_file == "auto": + actual_log_file = cls._get_default_log_file() + elif log_to_file is not None: + actual_log_file = log_to_file + + # Select the logger to attach handlers to (root logger or openspace). + target_logger = ( + logging.getLogger() if attach_to_root else logging.getLogger(cls._ROOT_NAME) + ) + target_logger.setLevel(resolved_level) + + # Clean up old handlers. + for h in target_logger.handlers[:]: + target_logger.removeHandler(h) + + # Construct Formatter + date_fmt = "%Y-%m-%d %H:%M:%S" + color_supported = force_color or (use_colors and cls._stdout_supports_color()) + console_formatter = ( + ColoredFormatter(fmt_str, datefmt=date_fmt) if color_supported + else logging.Formatter(fmt_str, datefmt=date_fmt) + ) + file_formatter = logging.Formatter(fmt_str, datefmt=date_fmt) + + # Console Handler + if log_to_console: + ch = logging.StreamHandler(sys.stdout) + ch.setLevel(resolved_level) + ch.setFormatter(console_formatter) + target_logger.addHandler(ch) + + # File Handler (with real-time flush) + if actual_log_file: + dir_path = os.path.dirname(actual_log_file) + if dir_path: + os.makedirs(dir_path, exist_ok=True) + fh = FlushFileHandler(actual_log_file, encoding="utf-8") + fh.setLevel(resolved_level) + fh.setFormatter(file_formatter) + target_logger.addHandler(fh) + + # Record log file location + if not cls._configured: + print(f"Log file enabled: {actual_log_file}") + + cls._configured = True + + @classmethod + def set_debug(cls, debug_level: int = 2) -> None: + """Dynamically switch debug level: 0 = WARNING, 1 = INFO, 2 = DEBUG.""" + global OPENSPACE_DEBUG + OPENSPACE_DEBUG = max(0, min(debug_level, 2)) + cls._update_level(cls._resolve_level(None)) + + @classmethod + def add_file_handler( + cls, + filepath: str, + logger_name: Optional[str] = None + ) -> None: + """ + Append a file handler to the given (default ``openspace``) logger. + + Args: + filepath: Log file path + logger_name: Log logger name + """ + logger = cls.get_logger(logger_name or cls._ROOT_NAME) + + dir_path = os.path.dirname(filepath) + if dir_path: + os.makedirs(dir_path, exist_ok=True) + + fh = FlushFileHandler(filepath, encoding="utf-8") + fh.setLevel(logger.level) + fh.setFormatter(logging.Formatter(cls._LOG_FORMAT, datefmt="%Y-%m-%d %H:%M:%S")) + logger.addHandler(fh) + + @classmethod + def reset_configuration(cls) -> None: + """Remove all handlers and clear registered loggers.""" + with cls._lock: + for lg in cls._registered.values(): + for h in lg.handlers[:]: + lg.removeHandler(h) + cls._registered.clear() + cls._configured = False + + @staticmethod + def _stdout_supports_color() -> bool: + return sys.stdout.isatty() and not os.getenv("NO_COLOR") + + @classmethod + def _resolve_level(cls, level: Optional[int]) -> int: + if level is not None: + # Allow passing logging.INFO / "INFO" / 20 etc. + return getattr(logging, str(level).upper(), level) + return {2: logging.DEBUG, 1: logging.INFO}.get(OPENSPACE_DEBUG, logging.WARNING) + + @classmethod + def _update_level(cls, level: int) -> None: + for lg in cls._registered.values(): + lg.setLevel(level) + for h in lg.handlers: + h.setLevel(level) + + +# Adjust debug level automatically according to the +# ``OPENSPACE_DEBUG`` (preferred) or legacy ``DEBUG`` environment variable. +_env_debug = os.getenv("OPENSPACE_DEBUG") or os.getenv("DEBUG") +if _env_debug is not None: + try: + Logger.set_debug(int(_env_debug)) + except ValueError: + # When not a number, use common format: DEBUG=1/true + Logger.set_debug(2 if _env_debug.strip().lower() in {"1", "true", "yes"} else 0) + +# Initialize logger system, attach to root so all loggers inherit the configuration +# This ensures any logger obtained via Logger.get_logger() will work correctly +Logger.configure(attach_to_root=True) + +# Get openspace logger for internal logging +logger = Logger.get_logger() +logger.debug("OpenSpace logging initialized") \ No newline at end of file diff --git a/openspace/utils/telemetry/__init__.py b/openspace/utils/telemetry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/openspace/utils/telemetry/events.py b/openspace/utils/telemetry/events.py new file mode 100644 index 0000000000000000000000000000000000000000..ddebddf29944111fce40ad0af92fb5f2d7b456de --- /dev/null +++ b/openspace/utils/telemetry/events.py @@ -0,0 +1,93 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + + +class BaseTelemetryEvent(ABC): + """Base class for all telemetry events""" + + @property + @abstractmethod + def name(self) -> str: + """Event name for tracking""" + pass + + @property + @abstractmethod + def properties(self) -> dict[str, Any]: + """Event properties to send with the event""" + pass + + +@dataclass +class MCPAgentExecutionEvent(BaseTelemetryEvent): + """Comprehensive event for tracking complete MCP agent execution""" + + # Execution method and context + execution_method: str # "run" or "astream" + query: str # The actual user query + success: bool + + # Agent configuration + model_provider: str + model_name: str + server_count: int + server_identifiers: list[dict[str, str]] + total_tools_available: int + tools_available_names: list[str] + max_steps_configured: int + memory_enabled: bool + use_server_manager: bool + + # Execution PARAMETERS + max_steps_used: int | None + manage_connector: bool + external_history_used: bool + + # Execution results + steps_taken: int | None = None + tools_used_count: int | None = None + tools_used_names: list[str] | None = None + response: str | None = None # The actual response + execution_time_ms: int | None = None + error_type: str | None = None + + # Context + conversation_history_length: int | None = None + + @property + def name(self) -> str: + return "mcp_agent_execution" + + @property + def properties(self) -> dict[str, Any]: + return { + # Core execution info + "execution_method": self.execution_method, + "query": self.query, + "query_length": len(self.query), + "success": self.success, + # Agent configuration + "model_provider": self.model_provider, + "model_name": self.model_name, + "server_count": self.server_count, + "server_identifiers": self.server_identifiers, + "total_tools_available": self.total_tools_available, + "tools_available_names": self.tools_available_names, + "max_steps_configured": self.max_steps_configured, + "memory_enabled": self.memory_enabled, + "use_server_manager": self.use_server_manager, + # Execution parameters (always include, even if None) + "max_steps_used": self.max_steps_used, + "manage_connector": self.manage_connector, + "external_history_used": self.external_history_used, + # Execution results (always include, even if None) + "steps_taken": self.steps_taken, + "tools_used_count": self.tools_used_count, + "tools_used_names": self.tools_used_names, + "response": self.response, + "response_length": len(self.response) if self.response else None, + "execution_time_ms": self.execution_time_ms, + "error_type": self.error_type, + "conversation_history_length": self.conversation_history_length, + } \ No newline at end of file diff --git a/openspace/utils/telemetry/telemetry.py b/openspace/utils/telemetry/telemetry.py new file mode 100644 index 0000000000000000000000000000000000000000..cab29a0137bceb8f9d89f8d94d04c2756c6c7763 --- /dev/null +++ b/openspace/utils/telemetry/telemetry.py @@ -0,0 +1,316 @@ +import logging +import os +import platform +import uuid +from collections.abc import Callable +from functools import wraps +from pathlib import Path +from typing import Any + +from posthog import Posthog +from scarf import ScarfEventLogger + +from mcp_use.logging import MCP_USE_DEBUG +from mcp_use.telemetry.events import ( + BaseTelemetryEvent, + MCPAgentExecutionEvent, +) +from mcp_use.telemetry.utils import get_package_version + +logger = logging.getLogger(__name__) + + +def singleton(cls): + """A decorator that implements the singleton pattern for a class.""" + instance = [None] + + def wrapper(*args, **kwargs): + if instance[0] is None: + instance[0] = cls(*args, **kwargs) + return instance[0] + + return wrapper + +def requires_telemetry(func: Callable) -> Callable: + """Decorator that skips function execution if telemetry is disabled""" + + @wraps(func) + def wrapper(self, *args, **kwargs): + if not self._posthog_client and not self._scarf_client: + return None + return func(self, *args, **kwargs) + + return wrapper + + +def get_cache_home() -> Path: + """Get platform-appropriate cache directory.""" + # XDG_CACHE_HOME for Linux and manually set envs + env_var: str | None = os.getenv("XDG_CACHE_HOME") + if env_var and (path := Path(env_var)).is_absolute(): + return path + + system = platform.system() + if system == "Windows": + appdata = os.getenv("LOCALAPPDATA") or os.getenv("APPDATA") + if appdata: + return Path(appdata) + return Path.home() / "AppData" / "Local" + elif system == "Darwin": # macOS + return Path.home() / "Library" / "Caches" + else: # Linux or other Unix + return Path.home() / ".cache" + + +@singleton +class Telemetry: + """ + Service for capturing anonymized telemetry data via PostHog and Scarf. + If the environment variable `MCP_USE_ANONYMIZED_TELEMETRY=false`, telemetry will be disabled. + """ + + USER_ID_PATH = str(get_cache_home() / "mcp_use_3" / "telemetry_user_id") + VERSION_DOWNLOAD_PATH = str(get_cache_home() / "mcp_use" / "download_version") + PROJECT_API_KEY = "phc_lyTtbYwvkdSbrcMQNPiKiiRWrrM1seyKIMjycSvItEI" + HOST = "https://eu.i.posthog.com" + SCARF_GATEWAY_URL = "https://mcpuse.gateway.scarf.sh/events" + UNKNOWN_USER_ID = "UNKNOWN_USER_ID" + + _curr_user_id = None + + def __init__(self): + telemetry_disabled = os.getenv("MCP_USE_ANONYMIZED_TELEMETRY", "true").lower() == "false" + + if telemetry_disabled: + self._posthog_client = None + self._scarf_client = None + logger.debug("Telemetry disabled") + else: + logger.info("Anonymized telemetry enabled. Set MCP_USE_ANONYMIZED_TELEMETRY=false to disable.") + + # Initialize PostHog + try: + self._posthog_client = Posthog( + project_api_key=self.PROJECT_API_KEY, + host=self.HOST, + disable_geoip=False, + enable_exception_autocapture=True, + ) + + # Silence posthog's logging unless debug mode (level 2) + if MCP_USE_DEBUG < 2: + posthog_logger = logging.getLogger("posthog") + posthog_logger.disabled = True + + except Exception as e: + logger.warning(f"Failed to initialize PostHog telemetry: {e}") + self._posthog_client = None + + # Initialize Scarf + try: + self._scarf_client = ScarfEventLogger( + endpoint_url=self.SCARF_GATEWAY_URL, + timeout=3.0, + verbose=MCP_USE_DEBUG >= 2, + ) + + # Silence scarf's logging unless debug mode (level 2) + if MCP_USE_DEBUG < 2: + scarf_logger = logging.getLogger("scarf") + scarf_logger.disabled = True + + except Exception as e: + logger.warning(f"Failed to initialize Scarf telemetry: {e}") + self._scarf_client = None + + @property + def user_id(self) -> str: + """Get or create a persistent anonymous user ID""" + if self._curr_user_id: + return self._curr_user_id + + try: + is_first_time = not os.path.exists(self.USER_ID_PATH) + + if is_first_time: + logger.debug(f"Creating user ID path: {self.USER_ID_PATH}") + os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True) + with open(self.USER_ID_PATH, "w") as f: + new_user_id = str(uuid.uuid4()) + f.write(new_user_id) + self._curr_user_id = new_user_id + + logger.debug(f"User ID path created: {self.USER_ID_PATH}") + else: + with open(self.USER_ID_PATH) as f: + self._curr_user_id = f.read().strip() + + # Always check for version-based download tracking + self.track_package_download( + { + "triggered_by": "user_id_property", + } + ) + except Exception as e: + logger.debug(f"Failed to get/create user ID: {e}") + self._curr_user_id = self.UNKNOWN_USER_ID + + return self._curr_user_id + + @requires_telemetry + def capture(self, event: BaseTelemetryEvent) -> None: + """Capture a telemetry event""" + # Send to PostHog + if self._posthog_client: + try: + # Add package version to all events + properties = event.properties.copy() + properties["mcp_use_version"] = get_package_version() + + self._posthog_client.capture(distinct_id=self.user_id, event=event.name, properties=properties) + except Exception as e: + logger.debug(f"Failed to track PostHog event {event.name}: {e}") + + # Send to Scarf + if self._scarf_client: + try: + # Add package version and user_id to all events + properties = {} + properties["mcp_use_version"] = get_package_version() + properties["user_id"] = self.user_id + properties["event"] = event.name + + # Convert complex types to simple types for Scarf compatibility + self._scarf_client.log_event(properties=properties) + except Exception as e: + logger.debug(f"Failed to track Scarf event {event.name}: {e}") + + @requires_telemetry + def track_package_download(self, properties: dict[str, Any] | None = None) -> None: + """Track package download event specifically for Scarf analytics""" + if self._scarf_client: + try: + current_version = get_package_version() + should_track = False + first_download = False + + # Check if version file exists + if not os.path.exists(self.VERSION_DOWNLOAD_PATH): + # First download + should_track = True + first_download = True + + # Create directory and save version + os.makedirs(os.path.dirname(self.VERSION_DOWNLOAD_PATH), exist_ok=True) + with open(self.VERSION_DOWNLOAD_PATH, "w") as f: + f.write(current_version) + else: + # Read saved version + with open(self.VERSION_DOWNLOAD_PATH) as f: + saved_version = f.read().strip() + + # Compare versions (simple string comparison for now) + if current_version > saved_version: + should_track = True + first_download = False + + # Update saved version + with open(self.VERSION_DOWNLOAD_PATH, "w") as f: + f.write(current_version) + + if should_track: + logger.debug(f"Tracking package download event with properties: {properties}") + # Add package version and user_id to event + event_properties = (properties or {}).copy() + event_properties["mcp_use_version"] = current_version + event_properties["user_id"] = self.user_id + event_properties["event"] = "package_download" + event_properties["first_download"] = first_download + + # Convert complex types to simple types for Scarf compatibility + self._scarf_client.log_event(properties=event_properties) + except Exception as e: + logger.debug(f"Failed to track Scarf package_download event: {e}") + + @requires_telemetry + def track_agent_execution( + self, + execution_method: str, + query: str, + success: bool, + model_provider: str, + model_name: str, + server_count: int, + server_identifiers: list[dict[str, str]], + total_tools_available: int, + tools_available_names: list[str], + max_steps_configured: int, + memory_enabled: bool, + use_server_manager: bool, + max_steps_used: int | None, + manage_connector: bool, + external_history_used: bool, + steps_taken: int | None = None, + tools_used_count: int | None = None, + tools_used_names: list[str] | None = None, + response: str | None = None, + execution_time_ms: int | None = None, + error_type: str | None = None, + conversation_history_length: int | None = None, + ) -> None: + """Track comprehensive agent execution""" + event = MCPAgentExecutionEvent( + execution_method=execution_method, + query=query, + success=success, + model_provider=model_provider, + model_name=model_name, + server_count=server_count, + server_identifiers=server_identifiers, + total_tools_available=total_tools_available, + tools_available_names=tools_available_names, + max_steps_configured=max_steps_configured, + memory_enabled=memory_enabled, + use_server_manager=use_server_manager, + max_steps_used=max_steps_used, + manage_connector=manage_connector, + external_history_used=external_history_used, + steps_taken=steps_taken, + tools_used_count=tools_used_count, + tools_used_names=tools_used_names, + response=response, + execution_time_ms=execution_time_ms, + error_type=error_type, + conversation_history_length=conversation_history_length, + ) + self.capture(event) + + @requires_telemetry + def flush(self) -> None: + """Flush any queued telemetry events""" + # Flush PostHog + if self._posthog_client: + try: + self._posthog_client.flush() + logger.debug("PostHog client telemetry queue flushed") + except Exception as e: + logger.debug(f"Failed to flush PostHog client: {e}") + + # Scarf events are sent immediately, no flush needed + if self._scarf_client: + logger.debug("Scarf telemetry events sent immediately (no flush needed)") + + @requires_telemetry + def shutdown(self) -> None: + """Shutdown telemetry clients and flush remaining events""" + # Shutdown PostHog + if self._posthog_client: + try: + self._posthog_client.shutdown() + logger.debug("PostHog client shutdown successfully") + except Exception as e: + logger.debug(f"Error shutting down PostHog client: {e}") + + # Scarf doesn't require explicit shutdown + if self._scarf_client: + logger.debug("Scarf telemetry client shutdown (no action needed)") diff --git a/openspace/utils/telemetry/utils.py b/openspace/utils/telemetry/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..245eb9dc3a2bc6eccb0cd62817d67405ddabced1 --- /dev/null +++ b/openspace/utils/telemetry/utils.py @@ -0,0 +1,54 @@ +""" +Utility functions for extracting model information from LangChain LLMs. + +This module provides utilities to extract provider and model information +from LangChain language models for telemetry purposes. +""" + +import importlib.metadata +from typing import Any + +try: + from langchain_core.language_models.base import BaseLanguageModel # type: ignore[import-untyped] +except ImportError: + BaseLanguageModel = None # type: ignore[misc, assignment] + + +def get_package_version() -> str: + """Get the current mcp-use package version.""" + try: + return importlib.metadata.version("mcp-use") + except importlib.metadata.PackageNotFoundError: + return "unknown" + + +def get_model_provider(llm: Any) -> str: + """Extract the model provider from LangChain LLM using BaseChatModel standards.""" + if BaseLanguageModel is None: + return getattr(llm, "__class__", type(llm)).__name__.lower() + # Use LangChain's standard _llm_type property for identification + return getattr(llm, "_llm_type", llm.__class__.__name__.lower()) + + +def get_model_name(llm: Any) -> str: + """Extract the model name from LangChain LLM using BaseChatModel standards.""" + # First try _identifying_params which may contain model info + if hasattr(llm, "_identifying_params"): + identifying_params = llm._identifying_params + if isinstance(identifying_params, dict): + # Common keys that contain model names + for key in ["model", "model_name", "model_id", "deployment_name"]: + if key in identifying_params: + return str(identifying_params[key]) + + # Fallback to direct model attributes + return getattr(llm, "model", getattr(llm, "model_name", llm.__class__.__name__)) + + +def extract_model_info(llm: Any) -> tuple[str, str]: + """Extract both provider and model name from LangChain LLM. + + Returns: + Tuple of (provider, model_name) + """ + return get_model_provider(llm), get_model_name(llm) diff --git a/openspace/utils/ui.py b/openspace/utils/ui.py new file mode 100644 index 0000000000000000000000000000000000000000..22729ec2543d48cd0472a44c168f630fdc802816 --- /dev/null +++ b/openspace/utils/ui.py @@ -0,0 +1,446 @@ +""" +OpenSpace Terminal UI System + +Provides real-time CLI visualization for OpenSpace execution flow. +Displays agent activities, grounding backends, and detailed logs. + +Uses native ANSI colors and custom box drawing for a clean, lightweight interface. +""" + +from typing import Optional, Dict, Any, List, Tuple +from datetime import datetime +from enum import Enum +import asyncio +import sys +import shutil + +from openspace.utils.display import Box, BoxStyle, colorize + + +class AgentStatus(Enum): + """Agent execution status""" + IDLE = "idle" + THINKING = "thinking" + EXECUTING = "executing" + WAITING = "waiting" + + +class OpenSpaceUI: + """ + OpenSpace Terminal UI + + Provides real-time visualization of: + - Agent activities and status + - Grounding backend operations + - Execution logs + - System metrics + + Design Philosophy: + - Lightweight and fast (no heavy dependencies) + - Clean ANSI-based rendering + - Minimal CPU overhead + - Easy to customize + """ + + def __init__(self, enable_live: bool = True, compact: bool = False): + """ + Initialize UI + + Args: + enable_live: Whether to enable live display updates + compact: Use compact layout (for smaller terminals) + """ + self.enable_live = enable_live + self.compact = compact + + # Terminal dimensions + self.term_width, self.term_height = self._get_terminal_size() + + # State tracking + self.agent_status: Dict[str, AgentStatus] = {} + self.agent_activities: Dict[str, List[str]] = {} + self.grounding_operations: List[Dict[str, Any]] = [] + self.grounding_backends: List[Dict[str, Any]] = [] # Backend info (type, servers, etc.) + self.log_buffer: List[Tuple[str, str, datetime]] = [] # (message, level, timestamp) + + # Metrics + self.metrics: Dict[str, Any] = { + "start_time": None, + "iterations": 0, + "completed_tasks": 0, + "llm_calls": 0, + "grounding_calls": 0, + } + + # Live display state + self._live_running = False + self._live_task: Optional[asyncio.Task] = None + self._last_render: List[str] = [] + + def _get_terminal_size(self) -> Tuple[int, int]: + """Get terminal size""" + try: + size = shutil.get_terminal_size((80, 24)) + return size.columns, size.lines + except: + return 80, 24 + + def _clear_screen(self): + """Clear screen""" + if self.enable_live: + # Clear entire screen and move cursor to top-left + sys.stdout.write('\033[2J\033[H') + sys.stdout.flush() + + def _move_cursor_home(self): + """Move cursor to home position""" + sys.stdout.write('\033[H') + sys.stdout.flush() + + def _hide_cursor(self): + """Hide cursor""" + sys.stdout.write('\033[?25l') + sys.stdout.flush() + + def _show_cursor(self): + """Show cursor""" + sys.stdout.write('\033[?25h') + sys.stdout.flush() + + # Banner and Startup + def print_banner(self): + """Print startup banner""" + box = Box(width=70, style=BoxStyle.ROUNDED, color='c') + + print() + print(box.top_line(indent=4)) + print(box.empty_line(indent=4)) + + # Title + title = colorize("OpenSpace", 'c', bold=True) + print(box.text_line(title, align='center', indent=4, text_color='')) + + # Subtitle + subtitle = "Self-Evolving Skill Worker & Community" + print(box.text_line(subtitle, align='center', indent=4, text_color='gr')) + + print(box.empty_line(indent=4)) + print(box.bottom_line(indent=4)) + print() + + def print_initialization(self, steps: List[Tuple[str, str]]): + """ + Print initialization steps + + Args: + steps: List of (component_name, status) tuples + """ + box = Box(width=70, style=BoxStyle.ROUNDED, color='bl') + + print(box.text_line("Initializing Components", align='center', indent=4, text_color='c')) + print(box.separator_line(indent=4)) + + for component, status in steps: + icon = colorize("✓", 'g') if status == "ok" else colorize("✗", 'rd') + line = f"{icon} {component}" + print(box.text_line(line, indent=4)) + + print(box.bottom_line(indent=4)) + print() + + async def start_live_display(self): + """Start live display""" + if not self.enable_live or self._live_running: + return + + self._live_running = True + self.metrics["start_time"] = datetime.now() + self._clear_screen() + self._hide_cursor() + + # Start update loop + self._live_task = asyncio.create_task(self._live_update_loop()) + + async def stop_live_display(self): + """Stop live display""" + if not self._live_running: + return + + self._live_running = False + + if self._live_task: + self._live_task.cancel() + try: + await self._live_task + except asyncio.CancelledError: + pass + + self._show_cursor() + print() # Add newline after live display + + async def _live_update_loop(self): + """Live update loop""" + while self._live_running: + try: + self.render() + await asyncio.sleep(2.0) + except asyncio.CancelledError: + break + except Exception as e: + print(f"UI render error: {e}") + + def render(self): + """Render entire UI""" + if not self.enable_live or not self._live_running: + return + + # Clear and redraw + self._clear_screen() + + lines = [] + + # Header + lines.extend(self._render_header()) + lines.append("") + + # Stack all panels vertically + lines.extend(self._render_agents()) + lines.append("") + lines.extend(self._render_grounding()) + lines.append("") + lines.extend(self._render_logs()) + + output = "\n".join(lines) + sys.stdout.write(output) + sys.stdout.flush() + + def update_display(self): + """Update display (alias for render())""" + self.render() + + def _render_header(self) -> List[str]: + """Render header section""" + lines = [] + + # Calculate elapsed time + elapsed = "0s" + if self.metrics["start_time"]: + delta = datetime.now() - self.metrics["start_time"] + minutes = delta.seconds // 60 + seconds = delta.seconds % 60 + if minutes > 0: + elapsed = f"{minutes}m{seconds}s" + else: + elapsed = f"{seconds}s" + + status_text = ( + f"▶ {colorize('RUNNING', 'g')} | " + f"Time: {colorize(elapsed, 'c')} | " + f"Iter: {colorize(str(self.metrics['iterations']), 'y')} | " + f"Tasks: {colorize(str(self.metrics['completed_tasks']), 'g')} | " + f"LLM: {colorize(str(self.metrics['llm_calls']), 'bl')} | " + f"Grounding: {colorize(str(self.metrics['grounding_calls']), 'm')}" + ) + + lines.append(" " + status_text) + lines.append(" " + "─" * 60) + + return lines + + def _render_agents(self) -> List[str]: + """Render agents section""" + lines = [] + + lines.append(" " + colorize("§ Agents", 'c', bold=True)) + + # Agent info + agents = [ + ("GroundingAgent", 'c', self.agent_status.get("GroundingAgent", AgentStatus.IDLE)), + ] + + for agent_name, color, status in agents: + # Status icon + status_icons = { + AgentStatus.IDLE: "○", + AgentStatus.THINKING: "◐", + AgentStatus.EXECUTING: "◉", + AgentStatus.WAITING: "◷", + } + icon = status_icons.get(status, "○") + + # Recent activity + activities = self.agent_activities.get(agent_name, []) + activity = activities[-1][:40] if activities else "idle" + + # Format line + line = f" {colorize(icon, 'y')} {colorize(agent_name, color):<20s} {activity}" + lines.append(line) + + return lines + + + def _render_grounding(self) -> List[str]: + """Render grounding operations section""" + lines = [] + + lines.append(" " + colorize("⊕ Grounding Backends", 'c', bold=True)) + + # Show backend types and servers + if self.grounding_backends: + for backend_info in self.grounding_backends: + backend_name = backend_info.get("name", "unknown") + backend_type = backend_info.get("type", "unknown") + servers = backend_info.get("servers", []) + + # Backend type icon + type_icons = { + "gui": "■", + "shell": "$", + "mcp": "◆", + "system": "●", + "web": "◉", + } + icon = type_icons.get(backend_type, "○") + + # Format backend line + if backend_type == "mcp" and servers: + servers_str = ", ".join([s[:15] for s in servers]) + line = f" {icon} {colorize(backend_name, 'y')} ({backend_type}): {colorize(servers_str, 'gr')}" + else: + line = f" {icon} {colorize(backend_name, 'y')} ({backend_type})" + + lines.append(line) + + # Show last 3 operations + recent_ops = self.grounding_operations[-3:] if self.grounding_operations else [] + + if recent_ops: + lines.append(" " + colorize("Recent Operations:", 'gr')) + for op in recent_ops: + backend = op.get("backend", "unknown") + action = op.get("action", "unknown")[:40] + status = op.get("status", "pending") + + # Status icon + if status == "success": + icon = colorize("✓", 'g') + elif status == "pending": + icon = colorize("⏳", 'y') + else: + icon = colorize("✗", 'rd') + + line = f" {icon} {colorize(backend, 'bl')}: {action}" + lines.append(line) + + return lines + + def _render_logs(self) -> List[str]: + """Render logs section""" + lines = [] + + lines.append(" " + colorize("⊞ Recent Events", 'c', bold=True)) + + # Show last 5 logs + recent_logs = self.log_buffer[-5:] if self.log_buffer else [] + + if recent_logs: + for message, level, timestamp in recent_logs: + time_str = timestamp.strftime("%H:%M:%S") + + # Truncate long messages + msg_display = message[:55] + + log_line = f" {colorize(time_str, 'gr')} | {msg_display}" + lines.append(log_line) + + return lines + + + def update_agent_status(self, agent_name: str, status: AgentStatus): + """Update agent status""" + self.agent_status[agent_name] = status + + def add_agent_activity(self, agent_name: str, activity: str): + """Add agent activity""" + if agent_name not in self.agent_activities: + self.agent_activities[agent_name] = [] + + self.agent_activities[agent_name].append(activity) + + # Keep only last 10 activities + if len(self.agent_activities[agent_name]) > 10: + self.agent_activities[agent_name] = self.agent_activities[agent_name][-10:] + + def update_grounding_backends(self, backends: List[Dict[str, Any]]): + """ + Update grounding backends information + + Args: + backends: List of backend info dicts with keys: + - name: backend name + - type: backend type (gui, shell, mcp, system, web) + - servers: list of server names (for mcp) + """ + self.grounding_backends = backends + + def add_grounding_operation(self, backend: str, action: str, status: str = "pending"): + """Add grounding operation""" + self.grounding_operations.append({ + "backend": backend, + "action": action, + "status": status, + "timestamp": datetime.now(), + }) + + self.metrics["grounding_calls"] += 1 + + def add_log(self, message: str, level: str = "info"): + """Add log message""" + self.log_buffer.append((message, level, datetime.now())) + + # Keep only last 100 logs + if len(self.log_buffer) > 100: + self.log_buffer = self.log_buffer[-100:] + + def update_metrics(self, **kwargs): + """Update metrics""" + self.metrics.update(kwargs) + + def print_summary(self, result: Dict[str, Any]): + """Print execution summary""" + box = Box(width=70, style=BoxStyle.ROUNDED, color='c') + + print() + print(box.text_line(colorize("◈ Execution Summary", 'c', bold=True), align='center', indent=4, text_color='')) + print(box.separator_line(indent=4)) + + # Status + status = result.get("status", "unknown") + status_display = { + "completed": colorize("COMPLETED", 'g', bold=True), + "timeout": colorize("TIMEOUT", 'y', bold=True), + "error": colorize("ERROR", 'rd', bold=True), + } + status_text = status_display.get(status, status) + + print(box.text_line(f" Status: {status_text}", indent=4, text_color='')) + print(box.text_line(f" Execution Time: {colorize(f'{result.get('execution_time', 0):.2f}s', 'c')}", indent=4, text_color='')) + print(box.text_line(f" Iterations: {colorize(str(result.get('iterations', 0)), 'y')}", indent=4, text_color='')) + print(box.text_line(f" Completed Tasks: {colorize(str(result.get('completed_tasks', 0)), 'g')}", indent=4, text_color='')) + + if result.get('evaluation_results'): + print(box.text_line(f" Evaluations: {colorize(str(len(result['evaluation_results'])), 'bl')}", indent=4, text_color='')) + + print(box.bottom_line(indent=4)) + print() + + +def create_ui(enable_live: bool = True, compact: bool = False) -> OpenSpaceUI: + """ + Create OpenSpace UI instance + + Args: + enable_live: Whether to enable live display updates + compact: Use compact layout for smaller terminals + """ + return OpenSpaceUI(enable_live=enable_live, compact=compact) \ No newline at end of file diff --git a/openspace/utils/ui_integration.py b/openspace/utils/ui_integration.py new file mode 100644 index 0000000000000000000000000000000000000000..9cbf180e3dad7037baa75369d9a2196a00449a06 --- /dev/null +++ b/openspace/utils/ui_integration.py @@ -0,0 +1,298 @@ +""" +OpenSpace UI Integration + +Integrates the UI system with OpenSpace core components. +Provides hooks and callbacks to update UI in real-time. +""" + +import asyncio +from typing import Optional + +from openspace.utils.ui import OpenSpaceUI, AgentStatus +from openspace.utils.logging import Logger + +logger = Logger.get_logger(__name__) + + +class UIIntegration: + """ + UI Integration for OpenSpace + + Connects OpenSpace components with the UI system to provide real-time + visualization of agent activities and execution flow. + """ + + def __init__(self, ui: OpenSpaceUI): + """ + Initialize UI integration + + Args: + ui: OpenSpaceUI instance + """ + self.ui = ui + self._update_task: Optional[asyncio.Task] = None + self._running = False + + # Tracked components + self._llm_client = None + self._grounding_client = None + + def attach_llm_client(self, llm_client): + """ + Attach LLM client + + Args: + llm_client: LLMClient instance + """ + self._llm_client = llm_client + logger.debug("UI attached to LLMClient") + + def attach_grounding_client(self, grounding_client): + """ + Attach grounding client + + Args: + grounding_client: GroundingClient instance + """ + self._grounding_client = grounding_client + logger.debug("UI attached to GroundingClient") + + async def start_monitoring(self, poll_interval: float = 0.5): + """ + Start monitoring and updating UI + + Args: + poll_interval: Update interval in seconds + """ + if self._running: + logger.warning("UI monitoring already running") + return + + self._running = True + + # Immediately update UI once before starting the loop + await self._update_ui() + + self._update_task = asyncio.create_task( + self._monitor_loop(poll_interval) + ) + logger.info("UI monitoring started") + + async def stop_monitoring(self): + """Stop monitoring""" + if not self._running: + return + + self._running = False + + if self._update_task: + self._update_task.cancel() + try: + await self._update_task + except asyncio.CancelledError: + pass + + logger.info("UI monitoring stopped") + + async def _monitor_loop(self, poll_interval: float): + """ + Main monitoring loop + + Args: + poll_interval: Update interval in seconds + """ + while self._running: + try: + await self._update_ui() + await asyncio.sleep(poll_interval) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"UI update error: {e}", exc_info=True) + + async def _update_ui(self): + """Update UI with current state""" + # Update grounding backends info + if self._grounding_client: + backends = [] + try: + # Get list of providers + providers = self._grounding_client.list_providers() + + for backend_type, provider in providers.items(): + backend_name = backend_type.value if hasattr(backend_type, 'value') else str(backend_type) + + backend_info = { + "name": backend_name, + "type": backend_name, # gui, shell, mcp, system, web + "servers": [] + } + + # For MCP provider, get server names + if backend_name == "mcp": + try: + # Try to get MCP sessions from provider + if hasattr(provider, '_sessions'): + backend_info["servers"] = list(provider._sessions.keys()) + except Exception: + pass + + backends.append(backend_info) + + self.ui.update_grounding_backends(backends) + except Exception as e: + logger.debug(f"Failed to update grounding backends: {e}") + + # Refresh display + self.ui.update_display() + + # Event handlers - to be called by agents + + def on_agent_start(self, agent_name: str, activity: str): + """ + Called when agent starts an activity + + Args: + agent_name: Agent name + activity: Activity description + """ + self.ui.update_agent_status(agent_name, AgentStatus.EXECUTING) + self.ui.add_agent_activity(agent_name, activity) + self.ui.add_log(f"{agent_name}: {activity}", level="info") + + def on_agent_thinking(self, agent_name: str): + """ + Called when agent is thinking + + Args: + agent_name: Agent name + """ + self.ui.update_agent_status(agent_name, AgentStatus.THINKING) + + def on_agent_complete(self, agent_name: str, result: str = ""): + """ + Called when agent completes an activity + + Args: + agent_name: Agent name + result: Result description + """ + self.ui.update_agent_status(agent_name, AgentStatus.IDLE) + if result: + self.ui.add_log(f"{agent_name}: {result}", level="success") + + def on_llm_call(self, model: str, prompt_length: int): + """ + Called when LLM is called + + Args: + model: Model name + prompt_length: Prompt length + """ + self.ui.update_metrics( + llm_calls=self.ui.metrics.get("llm_calls", 0) + 1 + ) + self.ui.add_log(f"LLM call: {model} (prompt: {prompt_length} chars)", level="debug") + + def on_grounding_call(self, backend: str, action: str): + """ + Called when grounding backend is called + + Args: + backend: Backend name + action: Action description + """ + self.ui.add_grounding_operation(backend, action, status="pending") + self.ui.add_log(f"Grounding [{backend}]: {action}", level="info") + + def on_grounding_complete(self, backend: str, action: str, success: bool): + """ + Called when grounding operation completes + + Args: + backend: Backend name + action: Action description + success: Whether operation succeeded + """ + status = "success" if success else "error" + + # Update last operation status + for op in reversed(self.ui.grounding_operations): + if op["backend"] == backend and op["action"] == action and op["status"] == "pending": + op["status"] = status + break + + level = "success" if success else "error" + result = "✓" if success else "✗" + self.ui.add_log(f"Grounding [{backend}]: {action} {result}", level=level) + + + def on_iteration(self, iteration: int): + """ + Called on each iteration + + Args: + iteration: Iteration number + """ + self.ui.update_metrics(iterations=iteration) + + def on_error(self, message: str): + """ + Called when an error occurs + + Args: + message: Error message + """ + self.ui.add_log(f"ERROR: {message}", level="error") + + +class UILoggingHandler: + """ + Logging handler that forwards logs to UI + """ + + def __init__(self, ui: OpenSpaceUI): + """ + Initialize logging handler + + Args: + ui: OpenSpaceUI instance + """ + self.ui = ui + + def emit(self, record): + """ + Emit a log record to UI + + Args: + record: Log record + """ + level_map = { + "DEBUG": "debug", + "INFO": "info", + "WARNING": "warning", + "ERROR": "error", + "CRITICAL": "error", + } + + level = level_map.get(record.levelname, "info") + message = record.getMessage() + + # Filter out noisy logs + if any(skip in message.lower() for skip in ["processing card", "workflow poll"]): + return + + self.ui.add_log(message, level=level) + + +def create_integration(ui: OpenSpaceUI) -> UIIntegration: + """ + Create UI integration instance + + Args: + ui: OpenSpaceUI instance + + Returns: + UIIntegration instance + """ + return UIIntegration(ui) \ No newline at end of file