Spaces:
Running
Running
Upload 160 files
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- openspace/.env.example +53 -0
- openspace/__init__.py +71 -0
- openspace/__main__.py +473 -0
- openspace/agents/__init__.py +9 -0
- openspace/agents/base.py +194 -0
- openspace/agents/grounding_agent.py +1212 -0
- openspace/cloud/__init__.py +31 -0
- openspace/cloud/auth.py +102 -0
- openspace/cloud/cli/__init__.py +0 -0
- openspace/cloud/cli/download_skill.py +63 -0
- openspace/cloud/cli/upload_skill.py +83 -0
- openspace/cloud/client.py +497 -0
- openspace/cloud/embedding.py +129 -0
- openspace/cloud/search.py +393 -0
- openspace/config/README.md +115 -0
- openspace/config/__init__.py +32 -0
- openspace/config/config_agents.json +11 -0
- openspace/config/config_dev.json.example +12 -0
- openspace/config/config_grounding.json +82 -0
- openspace/config/config_mcp.json.example +11 -0
- openspace/config/config_security.json +68 -0
- openspace/config/constants.py +23 -0
- openspace/config/grounding.py +311 -0
- openspace/config/loader.py +177 -0
- openspace/config/utils.py +30 -0
- openspace/dashboard_server.py +639 -0
- openspace/grounding/backends/__init__.py +34 -0
- openspace/grounding/backends/gui/__init__.py +25 -0
- openspace/grounding/backends/gui/anthropic_client.py +575 -0
- openspace/grounding/backends/gui/anthropic_utils.py +241 -0
- openspace/grounding/backends/gui/config.py +76 -0
- openspace/grounding/backends/gui/provider.py +143 -0
- openspace/grounding/backends/gui/session.py +188 -0
- openspace/grounding/backends/gui/tool.py +712 -0
- openspace/grounding/backends/gui/transport/actions.py +232 -0
- openspace/grounding/backends/gui/transport/connector.py +389 -0
- openspace/grounding/backends/gui/transport/local_connector.py +364 -0
- openspace/grounding/backends/mcp/__init__.py +41 -0
- openspace/grounding/backends/mcp/client.py +409 -0
- openspace/grounding/backends/mcp/config.py +132 -0
- openspace/grounding/backends/mcp/installer.py +697 -0
- openspace/grounding/backends/mcp/provider.py +473 -0
- openspace/grounding/backends/mcp/session.py +75 -0
- openspace/grounding/backends/mcp/tool_cache.py +254 -0
- openspace/grounding/backends/mcp/tool_converter.py +194 -0
- openspace/grounding/backends/mcp/transport/connectors/__init__.py +20 -0
- openspace/grounding/backends/mcp/transport/connectors/base.py +374 -0
- openspace/grounding/backends/mcp/transport/connectors/http.py +705 -0
- openspace/grounding/backends/mcp/transport/connectors/sandbox.py +251 -0
- openspace/grounding/backends/mcp/transport/connectors/stdio.py +76 -0
openspace/.env.example
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============================================
# OpenSpace Environment Variables
# Copy this file to .env and fill in your keys
# ============================================

# ---- LLM API Keys ----
# At least one LLM API key is required for OpenSpace to function.
# OpenSpace uses LiteLLM for model routing, so the key you need depends on your chosen model.
# See https://docs.litellm.ai/docs/providers for supported providers.

# Anthropic (for anthropic/claude-* models)
# ANTHROPIC_API_KEY=

# OpenAI (for openai/gpt-* models)
# OPENAI_API_KEY=

# OpenRouter (for openrouter/* models, e.g. openrouter/anthropic/claude-sonnet-4.5)
OPENROUTER_API_KEY=

# ---- OpenSpace Cloud (optional) ----
# Register at https://open-space.cloud to get your key.
# Enables cloud skill search & upload; local features work without it.
# Keys look like: sk_xxxxxxxxxxxxxxxx (left empty here so a copied .env
# never ships a placeholder that could be mistaken for a real secret).

OPENSPACE_API_KEY=

# ---- GUI Backend (Anthropic Computer Use) ----
# Required only if using the GUI backend. Uses the same ANTHROPIC_API_KEY above.
# Optional backup key for rate limit fallback:
# ANTHROPIC_API_KEY_BACKUP=

# ---- Web Backend (Deep Research) ----
# Required only if using the Web backend for deep research.
# Uses OpenRouter API by default:
# OPENROUTER_API_KEY=

# ---- Embedding (Optional) ----
# For remote embedding API instead of local model.
# If not set, OpenSpace uses a local embedding model (BAAI/bge-small-en-v1.5).
# EMBEDDING_BASE_URL=
# EMBEDDING_API_KEY=
# EMBEDDING_MODEL=openai/text-embedding-3-small

# ---- E2B Sandbox (Optional) ----
# Required only if sandbox mode is enabled in security config.
# E2B_API_KEY=

# ---- Local Server (Optional) ----
# Override the default local server URL (default: http://127.0.0.1:5000)
# Useful for remote VM integration (e.g., OSWorld).
# LOCAL_SERVER_URL=http://127.0.0.1:5000

# ---- Debug (Optional) ----
# OPENSPACE_DEBUG=true
|
openspace/__init__.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from importlib import import_module as _imp
from typing import Dict as _Dict, Any as _Any, TYPE_CHECKING as _TYPE_CHECKING

# These imports exist only for static type checkers; they are never executed
# at runtime, so `import openspace` stays cheap and does not require the
# optional/heavy sub-modules to be installed.
if _TYPE_CHECKING:
    from openspace.tool_layer import OpenSpace as OpenSpace, OpenSpaceConfig as OpenSpaceConfig
    from openspace.agents import GroundingAgent as GroundingAgent
    from openspace.llm import LLMClient as LLMClient
    from openspace.recording import RecordingManager as RecordingManager

# Package version (single source of truth for the distribution).
__version__ = "0.1.0"

# Public API of the `openspace` package. Every name listed here (except
# __version__) is resolved lazily via __getattr__ below.
__all__ = [
    # Version
    "__version__",

    # Main API
    "OpenSpace",
    "OpenSpaceConfig",

    # Core Components
    "GroundingAgent",
    "GroundingClient",
    "LLMClient",
    "BaseTool",
    "ToolResult",
    "BackendType",

    # Recording System
    "RecordingManager",
    "RecordingViewer",
]

# Map attribute → sub-module that provides it. Consulted by __getattr__ to
# import the providing module only on first access (PEP 562 lazy loading).
_attr_to_module: _Dict[str, str] = {
    # Main API
    "OpenSpace": "openspace.tool_layer",
    "OpenSpaceConfig": "openspace.tool_layer",

    # Core Components
    "GroundingAgent": "openspace.agents",
    "GroundingClient": "openspace.grounding.core.grounding_client",
    "LLMClient": "openspace.llm",
    "BaseTool": "openspace.grounding.core.tool.base",
    "ToolResult": "openspace.grounding.core.types",
    "BackendType": "openspace.grounding.core.types",

    # Recording System
    "RecordingManager": "openspace.recording",
    "RecordingViewer": "openspace.recording.viewer",
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def __getattr__(name: str) -> _Any:
    """Resolve public attributes lazily on first access (PEP 562).

    The providing sub-module is imported only when one of its exported
    names is actually requested. This keeps the *initial* package import
    lightweight and defers `ModuleNotFoundError` for optional / heavy
    dependencies until the corresponding functionality is explicitly used.
    """
    module_path = _attr_to_module.get(name)
    if module_path is None:
        raise AttributeError(f"module 'openspace' has no attribute '{name}'")

    resolved = getattr(_imp(module_path), name)
    # Cache the resolved object on the module so subsequent lookups hit
    # globals() directly and bypass __getattr__ entirely.
    globals()[name] = resolved
    return resolved
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def __dir__():
    """Return the module's attribute names, including lazily-provided ones.

    Uses a set union so that names already materialized into ``globals()``
    by ``__getattr__`` are not reported twice (the previous list
    concatenation duplicated any lazily-loaded attribute after first use).
    """
    return sorted(set(globals()) | set(_attr_to_module))
|
openspace/__main__.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import argparse
|
| 3 |
+
import sys
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
from openspace.tool_layer import OpenSpace, OpenSpaceConfig
|
| 8 |
+
from openspace.utils.logging import Logger
|
| 9 |
+
from openspace.utils.ui import create_ui, OpenSpaceUI
|
| 10 |
+
from openspace.utils.ui_integration import UIIntegration
|
| 11 |
+
from openspace.utils.cli_display import CLIDisplay
|
| 12 |
+
from openspace.utils.display import colorize
|
| 13 |
+
|
| 14 |
+
logger = Logger.get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class UIManager:
    """Coordinates the live terminal UI and its monitoring integration.

    When the UI is disabled (``--no-ui``) both ``ui`` and ``ui_integration``
    are ``None`` and the display methods silently do nothing.
    """

    def __init__(self, ui: Optional[OpenSpaceUI], ui_integration: Optional[UIIntegration]):
        self.ui = ui
        self.ui_integration = ui_integration
        # Logger name -> level saved by _suppress_logs, restored later.
        self._original_log_levels = {}

    async def start_live_display(self):
        """Announce and start the live display plus background monitoring."""
        if not (self.ui and self.ui_integration):
            return

        print()
        print(colorize(" ▣ Starting real-time visualization...", 'c'))
        print()
        # Brief pause so the announcement is visible before the UI takes over.
        await asyncio.sleep(1)

        # Silence loggers first — the live display owns the terminal now.
        self._suppress_logs()

        await self.ui.start_live_display()
        await self.ui_integration.start_monitoring(poll_interval=2.0)

    async def stop_live_display(self):
        """Stop monitoring and the live display, then restore logging."""
        if not (self.ui and self.ui_integration):
            return

        # Tear down in reverse order of start.
        await self.ui_integration.stop_monitoring()
        await self.ui.stop_live_display()
        self._restore_logs()

    def print_summary(self, result: dict):
        """Print an execution summary via the UI, or plain CLI when disabled."""
        if self.ui:
            self.ui.print_summary(result)
        else:
            CLIDisplay.print_result_summary(result)

    def _suppress_logs(self):
        """Raise openspace loggers to CRITICAL, remembering prior levels."""
        for name in ("openspace", "openspace.grounding", "openspace.agents"):
            log = logging.getLogger(name)
            self._original_log_levels[name] = log.level
            log.setLevel(logging.CRITICAL)

    def _restore_logs(self):
        """Put every suppressed logger back to its saved level."""
        for name, level in self._original_log_levels.items():
            logging.getLogger(name).setLevel(level)
        self._original_log_levels.clear()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
async def _execute_task(openspace: OpenSpace, query: str, ui_manager: UIManager):
    """Run a single query with the live UI wrapped around it.

    The live display is now torn down in a ``finally`` block, so the
    terminal is never left in live mode (with loggers suppressed) when
    ``openspace.execute`` raises — previously an exception skipped
    ``stop_live_display`` entirely.

    Returns the execution result dict from ``openspace.execute``.
    """
    await ui_manager.start_live_display()
    try:
        result = await openspace.execute(query)
    finally:
        await ui_manager.stop_live_display()
    ui_manager.print_summary(result)
    return result
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
async def interactive_mode(openspace: OpenSpace, ui_manager: UIManager):
    """Interactive REPL: read a line, dispatch built-in commands, or execute.

    Built-ins: exit/quit/q (leave), status (system status), help (usage).
    Anything else is treated as a task query. Ctrl-C exits the loop.
    """
    CLIDisplay.print_interactive_header()

    # The prompt text never changes, so render it once outside the loop.
    prompt = colorize(">>> ", 'c', bold=True)

    while True:
        try:
            query = input(f"\n{prompt}").strip()
            if not query:
                continue

            lowered = query.lower()

            if lowered in ('exit', 'quit', 'q'):
                print("\nExiting...")
                break

            if lowered == 'status':
                _print_status(openspace)
                continue

            if lowered == 'help':
                CLIDisplay.print_help()
                continue

            # Not a built-in: run it as a task.
            CLIDisplay.print_task_header(query)
            await _execute_task(openspace, query, ui_manager)

        except KeyboardInterrupt:
            print("\n\nInterrupt signal detected, exiting...")
            break
        except Exception as e:
            # Keep the REPL alive on task failures; log with traceback.
            logger.error(f"Error: {e}", exc_info=True)
            print(f"\nError: {e}")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
async def single_query_mode(openspace: OpenSpace, query: str, ui_manager: UIManager):
    """Execute one query passed via --query, then return to the caller."""
    CLIDisplay.print_task_header(query, title="▶ Single Query Execution")
    await _execute_task(openspace, query, ui_manager)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _print_status(openspace: OpenSpace):
    """Print system status"""
    # Imported here to avoid paying the display-module cost at import time.
    from openspace.utils.display import Box, BoxStyle

    initialized = openspace.is_initialized()
    running = openspace.is_running()

    box = Box(width=70, style=BoxStyle.ROUNDED, color='bl')
    print()
    print(box.text_line(colorize("System Status", 'bl', bold=True),
                        align='center', indent=4, text_color=''))
    print(box.separator_line(indent=4))

    # Base rows shown in every state.
    status_lines = [
        f"Initialized: {colorize('Yes' if initialized else 'No', 'g' if initialized else 'rd')}",
        f"Running: {colorize('Yes' if running else 'No', 'y' if running else 'g')}",
        f"Model: {colorize(openspace.config.llm_model, 'c')}",
    ]

    # Backend/session details are only meaningful once initialized.
    if initialized:
        backends = openspace.list_backends()
        status_lines.append(f"Backends: {colorize(', '.join(backends), 'c')}")
        sessions = openspace.list_sessions()
        status_lines.append(f"Active Sessions: {colorize(str(len(sessions)), 'y')}")

    for line in status_lines:
        print(box.text_line(f" {line}", indent=4, text_color=''))

    print(box.bottom_line(indent=4))
    print()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _create_argument_parser() -> argparse.ArgumentParser:
|
| 143 |
+
"""Create command-line argument parser"""
|
| 144 |
+
parser = argparse.ArgumentParser(
|
| 145 |
+
description='OpenSpace - Self-Evolving Skill Worker & Community',
|
| 146 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# Subcommands
|
| 150 |
+
subparsers = parser.add_subparsers(dest='command', help='Available commands')
|
| 151 |
+
|
| 152 |
+
# refresh-cache subcommand
|
| 153 |
+
cache_parser = subparsers.add_parser(
|
| 154 |
+
'refresh-cache',
|
| 155 |
+
help='Refresh MCP tool cache (starts all servers once)'
|
| 156 |
+
)
|
| 157 |
+
cache_parser.add_argument(
|
| 158 |
+
'--config', '-c', type=str,
|
| 159 |
+
help='MCP configuration file path'
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Basic arguments (for run mode)
|
| 163 |
+
parser.add_argument('--config', '-c', type=str, help='Configuration file path (JSON format)')
|
| 164 |
+
parser.add_argument('--query', '-q', type=str, help='Single query mode: execute query directly')
|
| 165 |
+
|
| 166 |
+
# LLM arguments
|
| 167 |
+
parser.add_argument('--model', '-m', type=str, help='LLM model name')
|
| 168 |
+
|
| 169 |
+
# Logging arguments
|
| 170 |
+
parser.add_argument('--log-level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], help='Log level')
|
| 171 |
+
|
| 172 |
+
# Execution arguments
|
| 173 |
+
parser.add_argument('--max-iterations', type=int, help='Maximum iteration count')
|
| 174 |
+
parser.add_argument('--timeout', type=float, help='LLM API call timeout (seconds)')
|
| 175 |
+
|
| 176 |
+
# UI arguments
|
| 177 |
+
parser.add_argument('--interactive', '-i', action='store_true', help='Force interactive mode')
|
| 178 |
+
parser.add_argument('--no-ui', action='store_true', help='Disable visualization UI')
|
| 179 |
+
parser.add_argument('--ui-compact', action='store_true', help='Use compact UI layout')
|
| 180 |
+
|
| 181 |
+
return parser
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
async def refresh_mcp_cache(config_path: Optional[str] = None):
    """Refresh the MCP tool cache by visiting each server once.

    For every configured MCP server: start it, list its tools, save the
    metadata to the cache, and close the session immediately. Saves are
    incremental, so an interrupted run resumes by skipping servers already
    present in the cache. Each server gets a hard per-server timeout so one
    hung server cannot stall the whole refresh.

    Args:
        config_path: Optional path to a config file; when omitted the
            globally loaded configuration is used.
    """
    from openspace.grounding.backends.mcp import MCPProvider, get_tool_cache
    from openspace.grounding.core.types import SessionConfig, BackendType
    from openspace.config import load_config, get_config

    print("Refreshing MCP tool cache...")
    print("Servers will be started one by one (start -> get tools -> close).")
    print()

    # Load config
    config = load_config(config_path) if config_path else get_config()

    # Get MCP config as a plain dict (pydantic models expose model_dump)
    mcp_config = getattr(config, 'mcp', None) or {}
    if hasattr(mcp_config, 'model_dump'):
        mcp_config = mcp_config.model_dump()

    # Skip dependency checks for refresh-cache (servers are pre-validated)
    mcp_config["check_dependencies"] = False

    # Create provider
    provider = MCPProvider(config=mcp_config)
    await provider.initialize()

    servers = provider.list_servers()
    total = len(servers)
    print(f"Found {total} MCP servers configured")
    print()

    cache = get_tool_cache()
    cache.set_server_order(servers)  # Preserve config order when saving
    total_tools = 0
    success_count = 0
    skipped_count = 0
    failed_servers = []

    # Load existing cache to skip already processed servers (resume support)
    existing_cache = cache.get_all_tools()

    # Timeout for each server (in seconds)
    SERVER_TIMEOUT = 60

    # Process servers one by one
    for i, server_name in enumerate(servers, 1):
        # Skip if already cached (resume support)
        if server_name in existing_cache:
            cached_tools = existing_cache[server_name]
            total_tools += len(cached_tools)
            skipped_count += 1
            print(f"[{i}/{total}] {server_name}... ⏭ cached ({len(cached_tools)} tools)")
            continue

        print(f"[{i}/{total}] {server_name}...", end=" ", flush=True)
        session_id = f"mcp-{server_name}"

        try:
            # Create session and get tools with timeout protection
            async with asyncio.timeout(SERVER_TIMEOUT):
                cfg = SessionConfig(
                    session_name=session_id,
                    backend_type=BackendType.MCP,
                    connection_params={"server": server_name},
                )
                session = await provider.create_session(cfg)

                # Get tools from this server
                tools = await session.list_tools()

                # Convert to the cache's metadata format
                tool_metadata = [
                    {
                        "name": tool.schema.name,
                        "description": tool.schema.description or "",
                        "parameters": tool.schema.parameters or {},
                    }
                    for tool in tools
                ]

                # Save to cache (incremental), then free resources right away
                cache.save_server(server_name, tool_metadata)
                await provider.close_session(session_id)

                total_tools += len(tools)
                success_count += 1
                print(f"✓ {len(tools)} tools")

        except Exception as e:
            # Unified failure path for timeouts and all other errors; the two
            # previous except branches were byte-identical apart from the
            # message. Timeout messages are shorter than 50 chars, so the
            # shared [:50] truncation preserves the original output exactly.
            if isinstance(e, asyncio.TimeoutError):
                error_msg = f"Timeout after {SERVER_TIMEOUT}s"
            else:
                error_msg = str(e)
            failed_servers.append((server_name, error_msg))
            print(f"✗ {error_msg[:50]}")

            # Save failed server info to cache
            cache.save_failed_server(server_name, error_msg)

            # Best-effort cleanup in case the session was partially created
            try:
                await provider.close_session(session_id)
            except Exception:
                pass

    print()
    print(f"{'='*50}")
    print(f"✓ Collected {total_tools} tools from {success_count + skipped_count}/{total} servers")
    if skipped_count > 0:
        print(f"  (skipped {skipped_count} cached, processed {success_count} new)")
    print(f"✓ Cache saved to: {cache.cache_path}")

    if failed_servers:
        print(f"✗ Failed servers ({len(failed_servers)}):")
        for name, err in failed_servers[:10]:
            print(f"  - {name}: {err[:60]}")
        if len(failed_servers) > 10:
            print(f"  ... and {len(failed_servers) - 10} more (see cache file for details)")

    print()
    print("Done! Future list_tools() calls will use cache (no server startup).")
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def _load_config(args) -> OpenSpaceConfig:
    """Build an OpenSpaceConfig from an optional JSON file plus CLI flags.

    CLI flags always win over file values. On any loading failure the
    process exits with status 1 (CLI boundary behavior).
    """
    # Collect only the overrides the user actually supplied. Note the
    # original semantics: model/log_level are truthiness-checked, while
    # numeric options use an explicit None check (0/0.0 are valid values).
    cli_overrides = {
        key: value
        for key, value, supplied in (
            ('llm_model', args.model, bool(args.model)),
            ('grounding_max_iterations', args.max_iterations, args.max_iterations is not None),
            ('llm_timeout', args.timeout, args.timeout is not None),
            ('log_level', args.log_level, bool(args.log_level)),
        )
        if supplied
    }

    try:
        if args.config:
            import json
            with open(args.config, 'r', encoding='utf-8') as f:
                config_dict = json.load(f)

            # CLI flags take precedence over file contents.
            config_dict.update(cli_overrides)
            config = OpenSpaceConfig(**config_dict)

            print(f"✓ Loaded from config file: {args.config}")
        else:
            config = OpenSpaceConfig(**cli_overrides)
            print("✓ Using default configuration")

        if cli_overrides:
            print(f"✓ CLI overrides: {', '.join(cli_overrides.keys())}")

        if args.log_level:
            Logger.set_level(args.log_level)

        return config

    except Exception as e:
        logger.error(f"Failed to load configuration: {e}")
        sys.exit(1)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _setup_ui(args) -> tuple[Optional[OpenSpaceUI], Optional[UIIntegration]]:
    """Create the live UI and its integration; (None, None) for --no-ui."""
    if args.no_ui:
        # Plain-text banner only; the caller treats None as "UI disabled".
        CLIDisplay.print_banner()
        return None, None

    ui = create_ui(enable_live=True, compact=args.ui_compact)
    ui.print_banner()
    return ui, UIIntegration(ui)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
async def _initialize_openspace(config: OpenSpaceConfig, args) -> OpenSpace:
    """Create and initialize an OpenSpace instance, reporting progress.

    When no config file was supplied, openspace loggers are temporarily
    raised to WARNING during initialization to keep console output clean.
    Fix over the previous version: each logger's *own* prior level is saved
    and restored (previously all three were reset to the root "openspace"
    logger's level, clobbering any per-logger configuration), and the
    restore runs in a ``finally`` block so levels are not leaked if
    ``initialize()`` raises.
    """
    openspace = OpenSpace(config)

    CLIDisplay.print_initialization_progress(
        [("Initializing OpenSpace...", "loading")], show_header=False
    )

    # Quiet the loggers only when using the default configuration.
    saved_levels = {}
    if not args.config:
        for log_name in ["openspace", "openspace.grounding", "openspace.agents"]:
            log = Logger.get_logger(log_name)
            saved_levels[log_name] = log.level
            log.setLevel(logging.WARNING)

    try:
        await openspace.initialize()
    finally:
        # Restore each logger to its own original level.
        for log_name, level in saved_levels.items():
            Logger.get_logger(log_name).setLevel(level)

    # Print initialization results
    backends = openspace.list_backends()
    init_steps = [
        ("LLM Client", "ok"),
        (f"Grounding Backends ({len(backends)} available)", "ok"),
        ("Grounding Agent", "ok"),
    ]

    if config.enable_recording:
        init_steps.append(("Recording Manager", "ok"))

    CLIDisplay.print_initialization_progress(init_steps, show_header=True)

    return openspace
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
async def main():
    """CLI entry point.

    Dispatches subcommands (refresh-cache), otherwise loads config, sets up
    the optional UI, initializes OpenSpace, and runs either single-query
    mode (--query) or the interactive REPL.

    Returns:
        Process exit code: 0 on success/interrupt, 1 on unhandled error.
    """
    parser = _create_argument_parser()
    args = parser.parse_args()

    # Handle subcommands before any heavy initialization.
    if args.command == 'refresh-cache':
        await refresh_mcp_cache(args.config)
        return 0

    # Load configuration (exits the process itself on failure).
    config = _load_config(args)

    # Setup UI (both None when --no-ui).
    ui, ui_integration = _setup_ui(args)

    # Print configuration
    CLIDisplay.print_configuration(config)

    openspace = None

    try:
        # Initialize OpenSpace
        openspace = await _initialize_openspace(config, args)

        # Connect UI (if enabled)
        if ui_integration:
            ui_integration.attach_llm_client(openspace._llm_client)
            ui_integration.attach_grounding_client(openspace._grounding_client)
        CLIDisplay.print_system_ready()

        ui_manager = UIManager(ui, ui_integration)

        # Run appropriate mode
        if args.query:
            await single_query_mode(openspace, args.query, ui_manager)
        else:
            await interactive_mode(openspace, ui_manager)

    except KeyboardInterrupt:
        print("\n\nInterrupt signal detected")
    except Exception as e:
        logger.error(f"Error: {e}", exc_info=True)
        print(f"\nError: {e}")
        return 1
    finally:
        # Always release backend resources, even after errors.
        if openspace:
            print("\nCleaning up resources...")
            await openspace.cleanup()

    print("\nGoodbye!")
    return 0
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def run_main():
    """Synchronous entry point: run main() and exit with its return code."""
    try:
        sys.exit(asyncio.run(main()))
    except KeyboardInterrupt:
        # A Ctrl-C that escapes main() is a normal exit, not an error.
        print("\n\nProgram interrupted")
        sys.exit(0)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
if __name__ == "__main__":
|
| 473 |
+
run_main()
|
openspace/agents/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openspace.agents.base import BaseAgent, AgentStatus, AgentRegistry
from openspace.agents.grounding_agent import GroundingAgent

# Public API of the agents package.
__all__ = [
    "BaseAgent",
    "AgentStatus",
    "AgentRegistry",
    "GroundingAgent",
]
|
openspace/agents/base.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Any
|
| 6 |
+
|
| 7 |
+
from openspace.utils.logging import Logger
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from openspace.llm import LLMClient
|
| 11 |
+
from openspace.grounding.core.grounding_client import GroundingClient
|
| 12 |
+
from openspace.recording import RecordingManager
|
| 13 |
+
|
| 14 |
+
logger = Logger.get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BaseAgent(ABC):
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
name: str,
|
| 21 |
+
backend_scope: Optional[List[str]] = None,
|
| 22 |
+
llm_client: Optional[LLMClient] = None,
|
| 23 |
+
grounding_client: Optional[GroundingClient] = None,
|
| 24 |
+
recording_manager: Optional[RecordingManager] = None,
|
| 25 |
+
) -> None:
|
| 26 |
+
"""
|
| 27 |
+
Initialize the BaseAgent.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
name: Unique name for the agent
|
| 31 |
+
backend_scope: List of backend types this agent can access (e.g., ["gui", "shell", "mcp", "web", "system"])
|
| 32 |
+
llm_client: LLM client for agent reasoning (optional, can be set later)
|
| 33 |
+
grounding_client: Reference to GroundingClient for tool execution
|
| 34 |
+
recording_manager: RecordingManager for recording execution
|
| 35 |
+
"""
|
| 36 |
+
self._name = name
|
| 37 |
+
self._grounding_client: Optional[GroundingClient] = grounding_client
|
| 38 |
+
self._backend_scope = backend_scope or []
|
| 39 |
+
self._llm_client = llm_client
|
| 40 |
+
self._recording_manager: Optional[RecordingManager] = recording_manager
|
| 41 |
+
self._step = 0
|
| 42 |
+
self._status = AgentStatus.ACTIVE
|
| 43 |
+
|
| 44 |
+
self._register_self()
|
| 45 |
+
logger.info(f"Initialized {self.__class__.__name__}: {name}")
|
| 46 |
+
|
| 47 |
+
@property
|
| 48 |
+
def name(self) -> str:
|
| 49 |
+
return self._name
|
| 50 |
+
|
| 51 |
+
@property
|
| 52 |
+
def grounding_client(self) -> Optional[GroundingClient]:
|
| 53 |
+
"""Get the grounding client."""
|
| 54 |
+
return self._grounding_client
|
| 55 |
+
|
| 56 |
+
@property
|
| 57 |
+
def backend_scope(self) -> List[str]:
|
| 58 |
+
return self._backend_scope
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def llm_client(self) -> Optional[LLMClient]:
|
| 62 |
+
return self._llm_client
|
| 63 |
+
|
| 64 |
+
@llm_client.setter
|
| 65 |
+
def llm_client(self, client: LLMClient) -> None:
|
| 66 |
+
self._llm_client = client
|
| 67 |
+
|
| 68 |
+
@property
|
| 69 |
+
def recording_manager(self) -> Optional[RecordingManager]:
|
| 70 |
+
"""Get the recording manager."""
|
| 71 |
+
return self._recording_manager
|
| 72 |
+
|
| 73 |
+
@property
|
| 74 |
+
def step(self) -> int:
|
| 75 |
+
return self._step
|
| 76 |
+
|
| 77 |
+
@property
|
| 78 |
+
def status(self) -> str:
|
| 79 |
+
return self._status
|
| 80 |
+
|
| 81 |
+
@abstractmethod
|
| 82 |
+
async def process(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
| 83 |
+
pass
|
| 84 |
+
|
| 85 |
+
@abstractmethod
|
| 86 |
+
def construct_messages(self, context: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 87 |
+
"""
|
| 88 |
+
Construct messages for LLM reasoning.
|
| 89 |
+
Context must contain 'instruction' key.
|
| 90 |
+
"""
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
async def get_llm_response(
|
| 94 |
+
self,
|
| 95 |
+
messages: List[Dict[str, Any]],
|
| 96 |
+
tools: Optional[List] = None,
|
| 97 |
+
**kwargs
|
| 98 |
+
) -> Dict[str, Any]:
|
| 99 |
+
if not self._llm_client:
|
| 100 |
+
raise ValueError(f"LLM client not initialized for agent {self.name}")
|
| 101 |
+
|
| 102 |
+
try:
|
| 103 |
+
response = await self._llm_client.complete(
|
| 104 |
+
messages=messages,
|
| 105 |
+
tools=tools,
|
| 106 |
+
**kwargs
|
| 107 |
+
)
|
| 108 |
+
return response
|
| 109 |
+
except Exception as e:
|
| 110 |
+
logger.error(f"{self.name}: LLM call failed: {e}", exc_info=True)
|
| 111 |
+
raise
|
| 112 |
+
|
| 113 |
+
def response_to_dict(self, response: str) -> Dict[str, Any]:
|
| 114 |
+
try:
|
| 115 |
+
if response.strip().startswith("```json") or response.strip().startswith("```"):
|
| 116 |
+
lines = response.strip().split('\n')
|
| 117 |
+
if lines and lines[0].startswith('```'):
|
| 118 |
+
lines = lines[1:]
|
| 119 |
+
end_idx = len(lines)
|
| 120 |
+
for i, line in enumerate(lines):
|
| 121 |
+
if line.strip() == '```':
|
| 122 |
+
end_idx = i
|
| 123 |
+
break
|
| 124 |
+
response = '\n'.join(lines[:end_idx])
|
| 125 |
+
|
| 126 |
+
return json.loads(response)
|
| 127 |
+
except json.JSONDecodeError as e:
|
| 128 |
+
# If parsing fails, try to find and extract just the JSON object/array
|
| 129 |
+
if "Extra data" in str(e):
|
| 130 |
+
try:
|
| 131 |
+
decoder = json.JSONDecoder()
|
| 132 |
+
obj, idx = decoder.raw_decode(response)
|
| 133 |
+
logger.warning(
|
| 134 |
+
f"{self.name}: Successfully extracted JSON but found extra text after position {idx}. "
|
| 135 |
+
f"Extra text: {response[idx:idx+100]}..."
|
| 136 |
+
)
|
| 137 |
+
return obj
|
| 138 |
+
except Exception as e2:
|
| 139 |
+
logger.error(f"{self.name}: Failed to extract JSON even with raw_decode: {e2}")
|
| 140 |
+
|
| 141 |
+
logger.error(f"{self.name}: Failed to parse response: {e}")
|
| 142 |
+
logger.error(f"{self.name}: Response content: {response[:500]}")
|
| 143 |
+
return {"error": "Failed to parse response", "raw": response}
|
| 144 |
+
|
| 145 |
+
def increment_step(self) -> None:
|
| 146 |
+
self._step += 1
|
| 147 |
+
|
| 148 |
+
@classmethod
|
| 149 |
+
def _register_self(cls) -> None:
|
| 150 |
+
"""Register the agent class in the registry upon instantiation."""
|
| 151 |
+
# Get the actual instance class, not BaseAgent
|
| 152 |
+
if cls.__name__ != "BaseAgent" and cls.__name__ not in AgentRegistry._registry:
|
| 153 |
+
AgentRegistry.register(cls.__name__, cls)
|
| 154 |
+
|
| 155 |
+
def __repr__(self) -> str:
|
| 156 |
+
return f"<{self.__class__.__name__}(name={self.name}, step={self.step}, status={self.status})>"
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class AgentStatus:
|
| 160 |
+
"""Constants for agent status."""
|
| 161 |
+
ACTIVE = "active"
|
| 162 |
+
IDLE = "idle"
|
| 163 |
+
WAITING = "waiting"
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class AgentRegistry:
|
| 167 |
+
"""
|
| 168 |
+
Registry for managing agent classes.
|
| 169 |
+
Allows dynamic registration and retrieval of agent types.
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
_registry: Dict[str, Type[BaseAgent]] = {}
|
| 173 |
+
|
| 174 |
+
@classmethod
|
| 175 |
+
def register(cls, name: str, agent_cls: Type[BaseAgent]) -> None:
|
| 176 |
+
if name in cls._registry:
|
| 177 |
+
logger.warning(f"Agent class '{name}' already registered, overwriting")
|
| 178 |
+
cls._registry[name] = agent_cls
|
| 179 |
+
logger.debug(f"Registered agent class: {name}")
|
| 180 |
+
|
| 181 |
+
@classmethod
|
| 182 |
+
def get_cls(cls, name: str) -> Type[BaseAgent]:
|
| 183 |
+
if name not in cls._registry:
|
| 184 |
+
raise ValueError(f"No agent class registered under '{name}'")
|
| 185 |
+
return cls._registry[name]
|
| 186 |
+
|
| 187 |
+
@classmethod
|
| 188 |
+
def list_registered(cls) -> List[str]:
|
| 189 |
+
return list(cls._registry.keys())
|
| 190 |
+
|
| 191 |
+
@classmethod
|
| 192 |
+
def clear(cls) -> None:
|
| 193 |
+
cls._registry.clear()
|
| 194 |
+
logger.debug("Agent registry cleared")
|
openspace/agents/grounding_agent.py
ADDED
|
@@ -0,0 +1,1212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import json
|
| 5 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
from openspace.agents.base import BaseAgent
|
| 8 |
+
from openspace.grounding.core.types import BackendType, ToolResult
|
| 9 |
+
from openspace.platforms.screenshot import ScreenshotClient
|
| 10 |
+
from openspace.prompts import GroundingAgentPrompts
|
| 11 |
+
from openspace.utils.logging import Logger
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from openspace.llm import LLMClient
|
| 15 |
+
from openspace.grounding.core.grounding_client import GroundingClient
|
| 16 |
+
from openspace.recording import RecordingManager
|
| 17 |
+
from openspace.skill_engine import SkillRegistry
|
| 18 |
+
|
| 19 |
+
logger = Logger.get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class GroundingAgent(BaseAgent):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
name: str = "GroundingAgent",
|
| 26 |
+
backend_scope: Optional[List[str]] = None,
|
| 27 |
+
llm_client: Optional[LLMClient] = None,
|
| 28 |
+
grounding_client: Optional[GroundingClient] = None,
|
| 29 |
+
recording_manager: Optional[RecordingManager] = None,
|
| 30 |
+
system_prompt: Optional[str] = None,
|
| 31 |
+
max_iterations: int = 15,
|
| 32 |
+
visual_analysis_timeout: float = 30.0,
|
| 33 |
+
tool_retrieval_llm: Optional[LLMClient] = None,
|
| 34 |
+
visual_analysis_model: Optional[str] = None,
|
| 35 |
+
) -> None:
|
| 36 |
+
"""
|
| 37 |
+
Initialize the Grounding Agent.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
name: Agent name
|
| 41 |
+
backend_scope: List of backends this agent can access (None = all available)
|
| 42 |
+
llm_client: LLM client for reasoning
|
| 43 |
+
grounding_client: GroundingClient for tool execution
|
| 44 |
+
recording_manager: RecordingManager for recording execution
|
| 45 |
+
system_prompt: Custom system prompt
|
| 46 |
+
max_iterations: Maximum LLM reasoning iterations for self-correction
|
| 47 |
+
visual_analysis_timeout: Timeout for visual analysis LLM calls in seconds
|
| 48 |
+
tool_retrieval_llm: LLM client for tool retrieval filter (None = use llm_client)
|
| 49 |
+
visual_analysis_model: Model name for visual analysis (None = use llm_client.model)
|
| 50 |
+
"""
|
| 51 |
+
super().__init__(
|
| 52 |
+
name=name,
|
| 53 |
+
backend_scope=backend_scope or ["gui", "shell", "mcp", "web", "system"],
|
| 54 |
+
llm_client=llm_client,
|
| 55 |
+
grounding_client=grounding_client,
|
| 56 |
+
recording_manager=recording_manager
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
self._system_prompt = system_prompt or self._default_system_prompt()
|
| 60 |
+
self._max_iterations = max_iterations
|
| 61 |
+
self._visual_analysis_timeout = visual_analysis_timeout
|
| 62 |
+
self._tool_retrieval_llm = tool_retrieval_llm
|
| 63 |
+
self._visual_analysis_model = visual_analysis_model
|
| 64 |
+
|
| 65 |
+
# Skill context injection (set externally before process())
|
| 66 |
+
self._skill_context: Optional[str] = None
|
| 67 |
+
self._active_skill_ids: List[str] = []
|
| 68 |
+
|
| 69 |
+
# Skill registry for mid-iteration retrieve_skill tool
|
| 70 |
+
self._skill_registry: Optional["SkillRegistry"] = None
|
| 71 |
+
|
| 72 |
+
# Tools from the last execution (available for post-execution analysis)
|
| 73 |
+
self._last_tools: List = []
|
| 74 |
+
|
| 75 |
+
logger.info(f"Grounding Agent initialized: {name}")
|
| 76 |
+
logger.info(f"Backend scope: {self._backend_scope}")
|
| 77 |
+
logger.info(f"Max iterations: {self._max_iterations}")
|
| 78 |
+
logger.info(f"Visual analysis timeout: {self._visual_analysis_timeout}s")
|
| 79 |
+
if tool_retrieval_llm:
|
| 80 |
+
logger.info(f"Tool retrieval model: {tool_retrieval_llm.model}")
|
| 81 |
+
if visual_analysis_model:
|
| 82 |
+
logger.info(f"Visual analysis model: {visual_analysis_model}")
|
| 83 |
+
|
| 84 |
+
def set_skill_context(
|
| 85 |
+
self,
|
| 86 |
+
context: str,
|
| 87 |
+
skill_ids: Optional[List[str]] = None,
|
| 88 |
+
) -> None:
|
| 89 |
+
"""Inject skill guidance into the agent's system prompt.
|
| 90 |
+
|
| 91 |
+
Called by ``OpenSpace.execute()`` before ``process()`` when skills
|
| 92 |
+
are matched. The context is a formatted string built by
|
| 93 |
+
``SkillRegistry.build_context_injection()``.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
context: Formatted skill content for system prompt injection.
|
| 97 |
+
skill_ids: skill_id values of injected skills.
|
| 98 |
+
"""
|
| 99 |
+
self._skill_context = context if context else None
|
| 100 |
+
self._active_skill_ids = skill_ids or []
|
| 101 |
+
if self._skill_context:
|
| 102 |
+
logger.info(f"Skill context set: {', '.join(self._active_skill_ids) or '(unnamed)'}")
|
| 103 |
+
|
| 104 |
+
def clear_skill_context(self) -> None:
|
| 105 |
+
"""Remove skill guidance (used before fallback execution)."""
|
| 106 |
+
if self._skill_context:
|
| 107 |
+
logger.info(f"Skill context cleared (was: {', '.join(self._active_skill_ids)})")
|
| 108 |
+
self._skill_context = None
|
| 109 |
+
self._active_skill_ids = []
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def has_skill_context(self) -> bool:
|
| 113 |
+
return self._skill_context is not None
|
| 114 |
+
|
| 115 |
+
def set_skill_registry(self, registry: Optional["SkillRegistry"]) -> None:
|
| 116 |
+
"""Attach a SkillRegistry so the agent can offer ``retrieve_skill`` as a tool."""
|
| 117 |
+
self._skill_registry = registry
|
| 118 |
+
if registry:
|
| 119 |
+
count = len(registry.list_skills())
|
| 120 |
+
logger.info(f"Skill registry attached ({count} skill(s) available for mid-iteration retrieval)")
|
| 121 |
+
|
| 122 |
+
_MAX_SINGLE_CONTENT_CHARS = 30_000
|
| 123 |
+
|
| 124 |
+
@classmethod
|
| 125 |
+
def _cap_message_content(cls, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 126 |
+
"""Truncate oversized individual message contents in-place.
|
| 127 |
+
|
| 128 |
+
Targets tool-result messages and assistant messages that can
|
| 129 |
+
carry enormous file contents (read_file on large CSVs/scripts).
|
| 130 |
+
System messages and the first user instruction are never touched.
|
| 131 |
+
"""
|
| 132 |
+
cap = cls._MAX_SINGLE_CONTENT_CHARS
|
| 133 |
+
trimmed = 0
|
| 134 |
+
for msg in messages:
|
| 135 |
+
content = msg.get("content")
|
| 136 |
+
if not isinstance(content, str) or len(content) <= cap:
|
| 137 |
+
continue
|
| 138 |
+
if msg.get("role") == "system":
|
| 139 |
+
continue
|
| 140 |
+
original_len = len(content)
|
| 141 |
+
msg["content"] = (
|
| 142 |
+
content[: cap // 2]
|
| 143 |
+
+ f"\n\n... [truncated {original_len - cap:,} chars] ...\n\n"
|
| 144 |
+
+ content[-(cap // 2):]
|
| 145 |
+
)
|
| 146 |
+
trimmed += 1
|
| 147 |
+
if trimmed:
|
| 148 |
+
logger.info(f"Capped {trimmed} oversized message(s) to {cap:,} chars each")
|
| 149 |
+
return messages
|
| 150 |
+
|
| 151 |
+
def _truncate_messages(
|
| 152 |
+
self,
|
| 153 |
+
messages: List[Dict[str, Any]],
|
| 154 |
+
keep_recent: int = 8,
|
| 155 |
+
max_tokens_estimate: int = 120000
|
| 156 |
+
) -> List[Dict[str, Any]]:
|
| 157 |
+
# First: cap any single oversized message to prevent one huge
|
| 158 |
+
# tool-result from dominating the context window.
|
| 159 |
+
messages = self._cap_message_content(messages)
|
| 160 |
+
|
| 161 |
+
if len(messages) <= keep_recent + 2: # +2 for system and initial user
|
| 162 |
+
return messages
|
| 163 |
+
|
| 164 |
+
total_text = json.dumps(messages, ensure_ascii=False)
|
| 165 |
+
estimated_tokens = len(total_text) // 4
|
| 166 |
+
|
| 167 |
+
if estimated_tokens < max_tokens_estimate:
|
| 168 |
+
return messages
|
| 169 |
+
|
| 170 |
+
logger.info(f"Truncating message history: {len(messages)} messages, "
|
| 171 |
+
f"~{estimated_tokens:,} tokens -> keeping recent {keep_recent} rounds")
|
| 172 |
+
|
| 173 |
+
system_messages = []
|
| 174 |
+
user_instruction = None
|
| 175 |
+
conversation_messages = []
|
| 176 |
+
|
| 177 |
+
for msg in messages:
|
| 178 |
+
role = msg.get("role")
|
| 179 |
+
if role == "system":
|
| 180 |
+
system_messages.append(msg)
|
| 181 |
+
elif role == "user" and user_instruction is None:
|
| 182 |
+
user_instruction = msg
|
| 183 |
+
else:
|
| 184 |
+
conversation_messages.append(msg)
|
| 185 |
+
|
| 186 |
+
recent_messages = conversation_messages[-(keep_recent * 2):] if conversation_messages else []
|
| 187 |
+
|
| 188 |
+
truncated = system_messages.copy()
|
| 189 |
+
if user_instruction:
|
| 190 |
+
truncated.append(user_instruction)
|
| 191 |
+
truncated.extend(recent_messages)
|
| 192 |
+
|
| 193 |
+
logger.info(f"After truncation: {len(truncated)} messages, "
|
| 194 |
+
f"~{len(json.dumps(truncated, ensure_ascii=False))//4:,} tokens (estimated)")
|
| 195 |
+
|
| 196 |
+
return truncated
|
| 197 |
+
|
| 198 |
+
async def process(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
| 199 |
+
"""
|
| 200 |
+
Process a task execution request with multi-round iteration control.
|
| 201 |
+
"""
|
| 202 |
+
instruction = context.get("instruction", "")
|
| 203 |
+
if not instruction:
|
| 204 |
+
logger.error("Grounding Agent: No instruction provided")
|
| 205 |
+
return {"error": "No instruction provided", "status": "error"}
|
| 206 |
+
|
| 207 |
+
# Store current instruction for visual analysis context
|
| 208 |
+
self._current_instruction = instruction
|
| 209 |
+
|
| 210 |
+
logger.info(f"Grounding Agent: Processing instruction at step {self.step}")
|
| 211 |
+
|
| 212 |
+
# Exist workspace files check
|
| 213 |
+
workspace_info = await self._check_workspace_artifacts(context)
|
| 214 |
+
if workspace_info["has_files"]:
|
| 215 |
+
context["workspace_artifacts"] = workspace_info
|
| 216 |
+
logger.info(f"Workspace has {len(workspace_info['files'])} existing files: {workspace_info['files']}")
|
| 217 |
+
|
| 218 |
+
# Get available tools (auto-search with cap)
|
| 219 |
+
tools = await self._get_available_tools(instruction)
|
| 220 |
+
self._last_tools = tools # expose for post-execution analysis
|
| 221 |
+
|
| 222 |
+
# Get search debug info (similarity scores, LLM selections)
|
| 223 |
+
search_debug_info = None
|
| 224 |
+
if self.grounding_client:
|
| 225 |
+
search_debug_info = self.grounding_client.get_last_search_debug_info()
|
| 226 |
+
|
| 227 |
+
# Build retrieved tools list for return value
|
| 228 |
+
retrieved_tools_list = []
|
| 229 |
+
for tool in tools:
|
| 230 |
+
tool_info = {
|
| 231 |
+
"name": getattr(tool, "name", str(tool)),
|
| 232 |
+
"description": getattr(tool, "description", ""),
|
| 233 |
+
}
|
| 234 |
+
# Prefer runtime_info.backend
|
| 235 |
+
# over backend_type (may be NOT_SET for cached RemoteTools)
|
| 236 |
+
runtime_info = getattr(tool, "_runtime_info", None)
|
| 237 |
+
if runtime_info and hasattr(runtime_info, "backend"):
|
| 238 |
+
tool_info["backend"] = runtime_info.backend.value if hasattr(runtime_info.backend, "value") else str(runtime_info.backend)
|
| 239 |
+
tool_info["server_name"] = runtime_info.server_name
|
| 240 |
+
elif hasattr(tool, "backend_type"):
|
| 241 |
+
tool_info["backend"] = tool.backend_type.value if hasattr(tool.backend_type, "value") else str(tool.backend_type)
|
| 242 |
+
|
| 243 |
+
# Add similarity score if available
|
| 244 |
+
if search_debug_info and search_debug_info.get("tool_scores"):
|
| 245 |
+
for score_info in search_debug_info["tool_scores"]:
|
| 246 |
+
if score_info["name"] == tool_info["name"]:
|
| 247 |
+
tool_info["similarity_score"] = score_info["score"]
|
| 248 |
+
break
|
| 249 |
+
|
| 250 |
+
retrieved_tools_list.append(tool_info)
|
| 251 |
+
|
| 252 |
+
# Record retrieved tools
|
| 253 |
+
if self._recording_manager:
|
| 254 |
+
from openspace.recording import RecordingManager
|
| 255 |
+
await RecordingManager.record_retrieved_tools(
|
| 256 |
+
task_instruction=instruction,
|
| 257 |
+
tools=tools,
|
| 258 |
+
search_debug_info=search_debug_info,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
# Initialize iteration state
|
| 262 |
+
max_iterations = context.get("max_iterations", self._max_iterations)
|
| 263 |
+
current_iteration = 0
|
| 264 |
+
all_tool_results = []
|
| 265 |
+
iteration_contexts = []
|
| 266 |
+
consecutive_empty_responses = 0 # Track consecutive empty LLM responses
|
| 267 |
+
MAX_CONSECUTIVE_EMPTY = 5 # Exit after this many empty responses
|
| 268 |
+
|
| 269 |
+
# Build initial messages
|
| 270 |
+
messages = self.construct_messages(context)
|
| 271 |
+
|
| 272 |
+
# Record initial conversation setup once (system prompts + user instruction + tool definitions)
|
| 273 |
+
from openspace.recording import RecordingManager
|
| 274 |
+
await RecordingManager.record_conversation_setup(
|
| 275 |
+
setup_messages=copy.deepcopy(messages),
|
| 276 |
+
tools=tools,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
try:
|
| 280 |
+
while current_iteration < max_iterations:
|
| 281 |
+
current_iteration += 1
|
| 282 |
+
logger.info(f"Grounding Agent: Iteration {current_iteration}/{max_iterations}")
|
| 283 |
+
|
| 284 |
+
# Strip skill context after the first iteration to save prompt tokens.
|
| 285 |
+
# Skills only need to guide the first LLM call; subsequent iterations
|
| 286 |
+
# already have the plan and tool results in context.
|
| 287 |
+
if current_iteration == 2 and self._skill_context:
|
| 288 |
+
skill_ctx = self._skill_context
|
| 289 |
+
messages = [
|
| 290 |
+
m for m in messages
|
| 291 |
+
if not (m.get("role") == "system" and m.get("content") == skill_ctx)
|
| 292 |
+
]
|
| 293 |
+
logger.info("Skill context removed from messages after first iteration")
|
| 294 |
+
|
| 295 |
+
# Cap oversized individual messages every iteration to prevent
|
| 296 |
+
# a single huge tool result from ballooning all subsequent calls.
|
| 297 |
+
if current_iteration >= 2:
|
| 298 |
+
messages = self._cap_message_content(messages)
|
| 299 |
+
|
| 300 |
+
# Truncate message history to prevent context length issues
|
| 301 |
+
# Start truncating after 5 iterations to keep context manageable
|
| 302 |
+
if current_iteration >= 5:
|
| 303 |
+
messages = self._truncate_messages(
|
| 304 |
+
messages,
|
| 305 |
+
keep_recent=8,
|
| 306 |
+
max_tokens_estimate=120000
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
messages_input_snapshot = copy.deepcopy(messages)
|
| 310 |
+
|
| 311 |
+
# [DISABLED] Iteration summary generation
|
| 312 |
+
# Tool results (including visual analysis) are already in context,
|
| 313 |
+
# LLM can make decisions directly without separate summary.
|
| 314 |
+
# To re-enable, uncomment below and pass iteration_summary_prompt to complete()
|
| 315 |
+
# iteration_summary_prompt = GroundingAgentPrompts.iteration_summary(
|
| 316 |
+
# instruction=instruction,
|
| 317 |
+
# iteration=current_iteration,
|
| 318 |
+
# max_iterations=max_iterations
|
| 319 |
+
# ) if context.get("auto_execute", True) else None
|
| 320 |
+
|
| 321 |
+
# Call LLMClient for single round
|
| 322 |
+
# LLM will decide whether to call tools or finish with <COMPLETE>
|
| 323 |
+
llm_response = await self._llm_client.complete(
|
| 324 |
+
messages=messages,
|
| 325 |
+
tools=tools if context.get("auto_execute", True) else None,
|
| 326 |
+
execute_tools=context.get("auto_execute", True),
|
| 327 |
+
summary_prompt=None, # Disabled
|
| 328 |
+
tool_result_callback=self._visual_analysis_callback
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
# Update messages with LLM response
|
| 332 |
+
messages = llm_response["messages"]
|
| 333 |
+
|
| 334 |
+
# Collect tool results
|
| 335 |
+
tool_results_this_iteration = llm_response.get("tool_results", [])
|
| 336 |
+
if tool_results_this_iteration:
|
| 337 |
+
all_tool_results.extend(tool_results_this_iteration)
|
| 338 |
+
|
| 339 |
+
# [DISABLED] Iteration summary logging
|
| 340 |
+
# llm_summary = llm_response.get("iteration_summary")
|
| 341 |
+
# if llm_summary:
|
| 342 |
+
# logger.info(f"Iteration {current_iteration} summary: {llm_summary[:150]}...")
|
| 343 |
+
|
| 344 |
+
assistant_message = llm_response.get("message", {})
|
| 345 |
+
assistant_content = assistant_message.get("content", "")
|
| 346 |
+
|
| 347 |
+
has_tool_calls = llm_response.get('has_tool_calls', False)
|
| 348 |
+
logger.info(f"Iteration {current_iteration} - Has tool calls: {has_tool_calls}, "
|
| 349 |
+
f"Tool results: {len(tool_results_this_iteration)}, "
|
| 350 |
+
f"Content length: {len(assistant_content)} chars")
|
| 351 |
+
|
| 352 |
+
if len(assistant_content) > 0:
|
| 353 |
+
logger.info(f"Iteration {current_iteration} - Assistant content preview: {repr(assistant_content[:300])}")
|
| 354 |
+
consecutive_empty_responses = 0 # Reset counter on valid response
|
| 355 |
+
else:
|
| 356 |
+
if not has_tool_calls:
|
| 357 |
+
consecutive_empty_responses += 1
|
| 358 |
+
logger.warning(f"Iteration {current_iteration} - NO tool calls and NO content "
|
| 359 |
+
f"(empty response {consecutive_empty_responses}/{MAX_CONSECUTIVE_EMPTY})")
|
| 360 |
+
|
| 361 |
+
if consecutive_empty_responses >= MAX_CONSECUTIVE_EMPTY:
|
| 362 |
+
logger.error(f"Exiting due to {MAX_CONSECUTIVE_EMPTY} consecutive empty LLM responses. "
|
| 363 |
+
"This may indicate API issues, rate limiting, or context too long.")
|
| 364 |
+
break
|
| 365 |
+
else:
|
| 366 |
+
consecutive_empty_responses = 0 # Reset if we have tool calls
|
| 367 |
+
|
| 368 |
+
# Snapshot messages after LLM call (accumulated context)
|
| 369 |
+
messages_output_snapshot = copy.deepcopy(messages)
|
| 370 |
+
|
| 371 |
+
# Delta messages: only the messages produced in this iteration
|
| 372 |
+
# (avoids repeating system prompts / initial user instruction each time)
|
| 373 |
+
delta_messages = messages[len(messages_input_snapshot):]
|
| 374 |
+
|
| 375 |
+
# Response metadata (lightweight; full content lives in delta_messages)
|
| 376 |
+
response_metadata = {
|
| 377 |
+
"has_tool_calls": has_tool_calls,
|
| 378 |
+
"tool_calls_count": len(tool_results_this_iteration),
|
| 379 |
+
}
|
| 380 |
+
iteration_context = {
|
| 381 |
+
"iteration": current_iteration,
|
| 382 |
+
"messages_input": messages_input_snapshot,
|
| 383 |
+
"messages_output": messages_output_snapshot,
|
| 384 |
+
"response_metadata": response_metadata,
|
| 385 |
+
}
|
| 386 |
+
iteration_contexts.append(iteration_context)
|
| 387 |
+
|
| 388 |
+
# Real-time save to conversations.jsonl (delta only, no redundancy)
|
| 389 |
+
await RecordingManager.record_iteration_context(
|
| 390 |
+
iteration=current_iteration,
|
| 391 |
+
delta_messages=copy.deepcopy(delta_messages),
|
| 392 |
+
response_metadata=response_metadata,
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
# Check for completion token in assistant content
|
| 396 |
+
# [DISABLED] Also check in iteration summary when enabled
|
| 397 |
+
# is_complete = (
|
| 398 |
+
# GroundingAgentPrompts.TASK_COMPLETE in assistant_content or
|
| 399 |
+
# (llm_summary and GroundingAgentPrompts.TASK_COMPLETE in llm_summary)
|
| 400 |
+
# )
|
| 401 |
+
is_complete = GroundingAgentPrompts.TASK_COMPLETE in assistant_content
|
| 402 |
+
|
| 403 |
+
if is_complete:
|
| 404 |
+
# Task is complete - LLM generated completion token
|
| 405 |
+
logger.info(f"Task completed at iteration {current_iteration} (found {GroundingAgentPrompts.TASK_COMPLETE})")
|
| 406 |
+
break
|
| 407 |
+
|
| 408 |
+
else:
|
| 409 |
+
# LLM didn't generate <COMPLETE>, continue to next iteration
|
| 410 |
+
if tool_results_this_iteration:
|
| 411 |
+
logger.debug(f"Task in progress, LLM called {len(tool_results_this_iteration)} tools")
|
| 412 |
+
else:
|
| 413 |
+
logger.debug(f"Task in progress, LLM did not generate <COMPLETE>")
|
| 414 |
+
|
| 415 |
+
# Remove previous iteration guidance to avoid accumulation
|
| 416 |
+
messages = [
|
| 417 |
+
msg for msg in messages
|
| 418 |
+
if not (msg.get("role") == "system" and "Iteration" in msg.get("content", "") and "complete" in msg.get("content", ""))
|
| 419 |
+
]
|
| 420 |
+
|
| 421 |
+
guidance_msg = {
|
| 422 |
+
"role": "system",
|
| 423 |
+
"content": f"Iteration {current_iteration} complete. "
|
| 424 |
+
f"Check if task is finished - if yes, output {GroundingAgentPrompts.TASK_COMPLETE}. "
|
| 425 |
+
f"If not, continue with next action."
|
| 426 |
+
}
|
| 427 |
+
messages.append(guidance_msg)
|
| 428 |
+
|
| 429 |
+
# [DISABLED] Full iteration feedback with summary
|
| 430 |
+
# self._remove_previous_guidance(messages)
|
| 431 |
+
# feedback_msg = self._build_iteration_feedback(
|
| 432 |
+
# iteration=current_iteration,
|
| 433 |
+
# llm_summary=llm_summary,
|
| 434 |
+
# add_guidance=True
|
| 435 |
+
# )
|
| 436 |
+
# if feedback_msg:
|
| 437 |
+
# messages.append(feedback_msg)
|
| 438 |
+
# logger.debug(f"Added iteration {current_iteration} feedback with guidance")
|
| 439 |
+
|
| 440 |
+
continue
|
| 441 |
+
|
| 442 |
+
# Build final result
|
| 443 |
+
result = await self._build_final_result(
|
| 444 |
+
instruction=instruction,
|
| 445 |
+
messages=messages,
|
| 446 |
+
all_tool_results=all_tool_results,
|
| 447 |
+
iterations=current_iteration,
|
| 448 |
+
max_iterations=max_iterations,
|
| 449 |
+
iteration_contexts=iteration_contexts,
|
| 450 |
+
retrieved_tools_list=retrieved_tools_list,
|
| 451 |
+
search_debug_info=search_debug_info,
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
# Record agent action to recording manager
|
| 455 |
+
if self._recording_manager:
|
| 456 |
+
await self._record_agent_execution(result, instruction)
|
| 457 |
+
|
| 458 |
+
# Increment step
|
| 459 |
+
self.increment_step()
|
| 460 |
+
|
| 461 |
+
logger.info(f"Grounding Agent: Execution completed with status: {result.get('status')}")
|
| 462 |
+
return result
|
| 463 |
+
|
| 464 |
+
except Exception as e:
|
| 465 |
+
logger.error(f"Grounding Agent: Execution failed: {e}")
|
| 466 |
+
result = {
|
| 467 |
+
"error": str(e),
|
| 468 |
+
"status": "error",
|
| 469 |
+
"instruction": instruction,
|
| 470 |
+
"iteration": current_iteration
|
| 471 |
+
}
|
| 472 |
+
self.increment_step()
|
| 473 |
+
return result
|
| 474 |
+
|
| 475 |
+
def _default_system_prompt(self) -> str:
    """Default system prompt tailored to the agent's actual backend scope.

    Delegates to GroundingAgentPrompts so the prompt only advertises
    the backends this agent instance can actually reach.
    """
    return GroundingAgentPrompts.build_system_prompt(self._backend_scope)
|
| 478 |
+
|
| 479 |
+
def construct_messages(
    self,
    context: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Assemble the initial chat transcript for one execution run.

    Order: base system prompt, optional workspace-directory note,
    optional workspace-artifact note, optional active-skill context,
    then the user instruction.

    Args:
        context: Execution context; must contain a non-empty
            ``instruction``. May also carry ``workspace_dir`` and
            ``workspace_artifacts``.

    Raises:
        ValueError: if ``context`` lacks a non-empty ``instruction``.
    """
    def system_msg(text: str) -> Dict[str, Any]:
        # All framing messages share the "system" role.
        return {"role": "system", "content": text}

    task_text = context.get("instruction", "")
    if not task_text:
        raise ValueError("context must contain 'instruction' field")

    transcript: List[Dict[str, Any]] = [system_msg(self._system_prompt)]

    # Tell the model where its working directory is, when known.
    working_dir = context.get("workspace_dir")
    if working_dir:
        transcript.append(system_msg(GroundingAgentPrompts.workspace_directory(working_dir)))

    # Summarize pre-existing workspace files: prefer task-matching files,
    # then a recent-files digest, then the plain listing.
    artifacts = context.get("workspace_artifacts")
    if artifacts and artifacts.get("has_files"):
        all_files = artifacts.get("files", [])
        matches = artifacts.get("matching_files", [])
        fresh = artifacts.get("recent_files", [])

        if matches:
            note = GroundingAgentPrompts.workspace_matching_files(matches)
        elif len(fresh) >= 2:
            note = GroundingAgentPrompts.workspace_recent_files(
                total_files=len(all_files),
                recent_files=fresh
            )
        else:
            note = GroundingAgentPrompts.workspace_file_list(all_files)

        transcript.append(system_msg(note))

    # Skill injection — only active (selected) skills, full content.
    if self._skill_context:
        transcript.append(system_msg(self._skill_context))
        logger.info(f"Injected active skill context ({len(self._active_skill_ids)} skill(s))")

    # User instruction comes last.
    transcript.append({"role": "user", "content": task_text})

    return transcript
|
| 532 |
+
|
| 533 |
+
async def _get_available_tools(self, task_description: Optional[str]) -> List:
    """
    Retrieve tools for the current execution phase.

    Both skill-augmented and normal modes use the same
    ``get_tools_with_auto_search`` pipeline:
    - Non-MCP tools (shell, gui, web, system) are always included.
    - MCP tools are filtered by relevance only when their count
      exceeds ``max_tools``.

    When skills are active, the shell backend is guaranteed to be in
    scope (skills commonly reference ``shell_agent``).

    Falls back to returning all tools if anything fails.

    Args:
        task_description: Natural-language task used to rank MCP tools;
            may be None.

    Returns:
        List of tool objects, possibly including a ``retrieve_skill``
        tool when a populated skill registry is configured.
    """
    grounding_client = self.grounding_client
    if not grounding_client:
        # No client configured — nothing to enumerate.
        return []

    backends = [BackendType(name) for name in self._backend_scope]

    # Ensure shell backend is available when skills are active
    # (skills commonly reference shell_agent, read_file, etc.)
    if self.has_skill_context:
        shell_bt = BackendType.SHELL
        if shell_bt not in backends:
            # Copy before appending so the configured scope list is not mutated.
            backends = list(backends) + [shell_bt]
            logger.info("Added Shell backend to scope for skill file I/O")

    try:
        # A dedicated (cheaper) retrieval LLM may be configured; otherwise
        # the main client doubles as the ranking model.
        retrieval_llm = self._tool_retrieval_llm or self._llm_client
        tools = await grounding_client.get_tools_with_auto_search(
            task_description=task_description,
            backend=backends,
            use_cache=True,
            llm_callable=retrieval_llm,
        )
        logger.info(
            f"GroundingAgent selected {len(tools)} tools (auto-search) "
            f"from {len(backends)} backends"
            + (f" [skill-augmented]" if self.has_skill_context else "")
        )
    except Exception as e:
        # Search failure is non-fatal: degrade to the unfiltered tool list.
        logger.warning(f"Auto-search tools failed, falling back to full list: {e}")
        tools = await self._load_all_tools(grounding_client)

    # Append retrieve_skill tool when skill registry is available
    if self._skill_registry and self._skill_registry.list_skills():
        # Local import avoids a hard dependency when skills are unused.
        from openspace.skill_engine.retrieve_tool import RetrieveSkillTool
        retrieve_llm = self._tool_retrieval_llm or self._llm_client
        retrieve_tool = RetrieveSkillTool(
            self._skill_registry,
            backends=[b.value for b in backends],
            llm_client=retrieve_llm,
            skill_store=getattr(self, "_skill_store", None),
        )
        # The retrieve tool runs in-process on the SYSTEM backend.
        retrieve_tool.bind_runtime_info(
            backend=BackendType.SYSTEM,
            session_name="internal",
        )
        tools.append(retrieve_tool)
        logger.info("Added retrieve_skill tool for mid-iteration skill retrieval")

    return tools
|
| 597 |
+
|
| 598 |
+
async def _load_all_tools(self, grounding_client: "GroundingClient") -> List:
    """Fallback path: gather every tool from every in-scope backend.

    Backends that fail to enumerate are skipped (logged at debug level)
    rather than aborting the whole collection.
    """
    collected: List = []
    for scope_name in self._backend_scope:
        try:
            bt = BackendType(scope_name)
            backend_tools = await grounding_client.list_tools(backend=bt)
            collected.extend(backend_tools)
            logger.debug(f"Retrieved {len(backend_tools)} tools from backend: {scope_name}")
        except Exception as e:
            # Best-effort: an unreachable backend should not block the rest.
            logger.debug(f"Could not get tools from {scope_name}: {e}")

    logger.info(
        f"GroundingAgent fallback retrieved {len(collected)} tools "
        f"from {len(self._backend_scope)} backends"
    )
    return collected
|
| 615 |
+
|
| 616 |
+
async def _visual_analysis_callback(
    self,
    result: ToolResult,
    tool_name: str,
    tool_call: Dict,
    backend: str
) -> ToolResult:
    """
    Callback for LLMClient to handle visual analysis after tool execution.

    Args:
        result: The tool's raw execution result.
        tool_name: Name of the tool that just ran (for logging/prompts).
        tool_call: The originating tool call (litellm object with a
            ``.function`` attribute is expected here).
        backend: Backend the tool ran on; only "gui" triggers analysis.

    Returns:
        The original result unchanged, or a copy enhanced with a visual
        analysis of the captured screenshot(s).
    """
    # 1. Check if LLM requested to skip visual analysis
    skip_visual_analysis = False
    try:
        arguments = tool_call.function.arguments
        if isinstance(arguments, str):
            # Arguments may arrive as a JSON string; empty → empty dict.
            args = json.loads(arguments.strip() or "{}")
        else:
            args = arguments

        # "skip_visual_analysis" is a meta-parameter the LLM can set on a
        # tool call to opt out of the (slow, costly) screenshot analysis.
        if isinstance(args, dict) and args.get("skip_visual_analysis"):
            skip_visual_analysis = True
            logger.info(f"Visual analysis skipped for {tool_name} (meta-parameter set by LLM)")
    except Exception as e:
        # Malformed/missing arguments are non-fatal; fall through to analysis.
        logger.debug(f"Could not parse tool arguments: {e}")

    # 2. If skip requested, return original result
    if skip_visual_analysis:
        return result

    # 3. Check if this backend needs visual analysis
    if backend != "gui":
        return result

    # 4. Check if tool has visual data
    metadata = getattr(result, 'metadata', None)
    has_screenshots = metadata and (metadata.get("screenshot") or metadata.get("screenshots"))

    # 5. If no visual data, try to capture a screenshot
    if not has_screenshots:
        try:
            logger.info(f"No visual data from {tool_name}, capturing screenshot...")
            screenshot_client = ScreenshotClient()
            screenshot_bytes = await screenshot_client.capture()

            if screenshot_bytes:
                # Add screenshot to result metadata
                if metadata is None:
                    result.metadata = {}
                    metadata = result.metadata
                metadata["screenshot"] = screenshot_bytes
                has_screenshots = True
                logger.info(f"Screenshot captured for visual analysis")
            else:
                logger.warning("Failed to capture screenshot")
        except Exception as e:
            # Capture failure degrades to "no visual data" below.
            logger.warning(f"Error capturing screenshot: {e}")

    # 6. If still no screenshots, return original result
    if not has_screenshots:
        logger.debug(f"No visual data available for {tool_name}")
        return result

    # 7. Perform visual analysis
    return await self._enhance_result_with_visual_context(result, tool_name)
|
| 680 |
+
|
| 681 |
+
async def _enhance_result_with_visual_context(
    self,
    result: ToolResult,
    tool_name: str
) -> ToolResult:
    """
    Enhance tool result with visual analysis for grounding agent workflows.

    Sends up to three representative screenshots from the result's
    metadata to a vision-capable model and appends the model's analysis
    to the result text. On timeout or any other failure the original
    result is returned unchanged.

    Args:
        result: Tool result whose metadata may hold screenshot bytes
            under "screenshot" (single) or "screenshots" (list).
        tool_name: Name of the tool, used for prompting and logging.

    Returns:
        A new ToolResult with the analysis appended and metadata flags
        set, or the original result when analysis is impossible/fails.
    """
    import asyncio
    import base64
    import litellm

    try:
        metadata = getattr(result, 'metadata', None)
        if not metadata:
            return result

        # Collect all screenshots
        screenshots_bytes = []

        # Check for multiple screenshots first
        if metadata.get("screenshots"):
            screenshots_list = metadata["screenshots"]
            if isinstance(screenshots_list, list):
                # Drop falsy entries (None / empty bytes).
                screenshots_bytes = [s for s in screenshots_list if s]
        # Fall back to single screenshot
        elif metadata.get("screenshot"):
            screenshots_bytes = [metadata["screenshot"]]

        if not screenshots_bytes:
            return result

        # Select key screenshots if there are too many (first/last/middle).
        selected_screenshots = self._select_key_screenshots(screenshots_bytes, max_count=3)

        # Convert to base64 for the data-URL image payload.
        visual_b64_list = []
        for visual_data in selected_screenshots:
            if isinstance(visual_data, bytes):
                visual_b64_list.append(base64.b64encode(visual_data).decode('utf-8'))
            else:
                visual_b64_list.append(visual_data)  # Already base64

        # Build prompt based on number of screenshots
        num_screenshots = len(visual_b64_list)

        prompt = GroundingAgentPrompts.visual_analysis(
            tool_name=tool_name,
            num_screenshots=num_screenshots,
            task_description=getattr(self, '_current_instruction', '')
        )

        # Build content with text prompt + all images
        content = [{"type": "text", "text": prompt}]
        for visual_b64 in visual_b64_list:
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{visual_b64}"
                }
            })

        # Use dedicated visual analysis model if configured, otherwise use main LLM model
        visual_model = self._visual_analysis_model or (self._llm_client.model if self._llm_client else "openrouter/anthropic/claude-sonnet-4.5")
        # Outer wait_for gives litellm's own timeout 5s of slack to fire first.
        response = await asyncio.wait_for(
            litellm.acompletion(
                model=visual_model,
                messages=[{
                    "role": "user",
                    "content": content
                }],
                timeout=self._visual_analysis_timeout
            ),
            timeout=self._visual_analysis_timeout + 5
        )

        analysis = response.choices[0].message.content.strip()

        # Inject visual analysis into content
        original_content = result.content or "(no text output)"
        enhanced_content = f"{original_content}\n\n**Visual content**: {analysis}"

        # Create enhanced result (original result object is left untouched).
        enhanced_result = ToolResult(
            status=result.status,
            content=enhanced_content,
            error=result.error,
            metadata={**metadata, "visual_analyzed": True, "visual_analysis": analysis},
            execution_time=result.execution_time
        )

        logger.info(f"Enhanced {tool_name} result with visual analysis ({num_screenshots} screenshot(s))")
        return enhanced_result

    except asyncio.TimeoutError:
        logger.warning(f"Visual analysis timed out for {tool_name}, returning original result")
        return result
    except Exception as e:
        # Analysis is best-effort; never let it break tool execution.
        logger.warning(f"Failed to analyze visual content for {tool_name}: {e}")
        return result
|
| 781 |
+
|
| 782 |
+
def _select_key_screenshots(
|
| 783 |
+
self,
|
| 784 |
+
screenshots: List[bytes],
|
| 785 |
+
max_count: int = 3
|
| 786 |
+
) -> List[bytes]:
|
| 787 |
+
"""
|
| 788 |
+
Select key screenshots if there are too many.
|
| 789 |
+
"""
|
| 790 |
+
if len(screenshots) <= max_count:
|
| 791 |
+
return screenshots
|
| 792 |
+
|
| 793 |
+
selected_indices = set()
|
| 794 |
+
|
| 795 |
+
# Always include last (final state)
|
| 796 |
+
selected_indices.add(len(screenshots) - 1)
|
| 797 |
+
|
| 798 |
+
# If room, include first (initial state)
|
| 799 |
+
if max_count >= 2:
|
| 800 |
+
selected_indices.add(0)
|
| 801 |
+
|
| 802 |
+
# Fill remaining slots with evenly spaced middle screenshots
|
| 803 |
+
remaining_slots = max_count - len(selected_indices)
|
| 804 |
+
if remaining_slots > 0:
|
| 805 |
+
# Calculate spacing
|
| 806 |
+
available_indices = [
|
| 807 |
+
i for i in range(1, len(screenshots) - 1)
|
| 808 |
+
if i not in selected_indices
|
| 809 |
+
]
|
| 810 |
+
|
| 811 |
+
if available_indices:
|
| 812 |
+
step = max(1, len(available_indices) // (remaining_slots + 1))
|
| 813 |
+
for i in range(remaining_slots):
|
| 814 |
+
idx = min((i + 1) * step, len(available_indices) - 1)
|
| 815 |
+
if idx < len(available_indices):
|
| 816 |
+
selected_indices.add(available_indices[idx])
|
| 817 |
+
|
| 818 |
+
# Return screenshots in original order
|
| 819 |
+
selected = [screenshots[i] for i in sorted(selected_indices)]
|
| 820 |
+
|
| 821 |
+
logger.debug(
|
| 822 |
+
f"Selected {len(selected)} screenshots at indices {sorted(selected_indices)} "
|
| 823 |
+
f"from total of {len(screenshots)}"
|
| 824 |
+
)
|
| 825 |
+
|
| 826 |
+
return selected
|
| 827 |
+
|
| 828 |
+
def _get_workspace_path(self, context: Dict[str, Any]) -> Optional[str]:
|
| 829 |
+
"""
|
| 830 |
+
Get workspace directory path from context.
|
| 831 |
+
"""
|
| 832 |
+
return context.get("workspace_dir")
|
| 833 |
+
|
| 834 |
+
def _scan_workspace_files(
|
| 835 |
+
self,
|
| 836 |
+
workspace_path: str,
|
| 837 |
+
recent_threshold: int = 600 # seconds
|
| 838 |
+
) -> Dict[str, Any]:
|
| 839 |
+
"""
|
| 840 |
+
Scan workspace directory and collect file information.
|
| 841 |
+
|
| 842 |
+
Args:
|
| 843 |
+
workspace_path: Path to workspace directory
|
| 844 |
+
recent_threshold: Threshold in seconds for recent files
|
| 845 |
+
|
| 846 |
+
Returns:
|
| 847 |
+
Dictionary with file information:
|
| 848 |
+
- files: List of all filenames
|
| 849 |
+
- file_details: Dict mapping filename to file info (size, modified, age_seconds)
|
| 850 |
+
- recent_files: List of recently modified filenames
|
| 851 |
+
"""
|
| 852 |
+
import os
|
| 853 |
+
import time
|
| 854 |
+
|
| 855 |
+
result = {
|
| 856 |
+
"files": [],
|
| 857 |
+
"file_details": {},
|
| 858 |
+
"recent_files": []
|
| 859 |
+
}
|
| 860 |
+
|
| 861 |
+
if not workspace_path or not os.path.exists(workspace_path):
|
| 862 |
+
return result
|
| 863 |
+
|
| 864 |
+
# Recording system files to exclude from workspace scanning
|
| 865 |
+
excluded_files = {"metadata.json", "traj.jsonl"}
|
| 866 |
+
|
| 867 |
+
try:
|
| 868 |
+
current_time = time.time()
|
| 869 |
+
|
| 870 |
+
for filename in os.listdir(workspace_path):
|
| 871 |
+
filepath = os.path.join(workspace_path, filename)
|
| 872 |
+
if os.path.isfile(filepath) and filename not in excluded_files:
|
| 873 |
+
result["files"].append(filename)
|
| 874 |
+
|
| 875 |
+
# Get file stats
|
| 876 |
+
stat = os.stat(filepath)
|
| 877 |
+
file_info = {
|
| 878 |
+
"size": stat.st_size,
|
| 879 |
+
"modified": stat.st_mtime,
|
| 880 |
+
"age_seconds": current_time - stat.st_mtime
|
| 881 |
+
}
|
| 882 |
+
result["file_details"][filename] = file_info
|
| 883 |
+
|
| 884 |
+
# Track recently created/modified files
|
| 885 |
+
if file_info["age_seconds"] < recent_threshold:
|
| 886 |
+
result["recent_files"].append(filename)
|
| 887 |
+
|
| 888 |
+
result["files"] = sorted(result["files"])
|
| 889 |
+
|
| 890 |
+
except Exception as e:
|
| 891 |
+
logger.debug(f"Error scanning workspace files: {e}")
|
| 892 |
+
|
| 893 |
+
return result
|
| 894 |
+
|
| 895 |
+
async def _check_workspace_artifacts(self, context: Dict[str, Any]) -> Dict[str, Any]:
    """
    Check workspace directory for existing artifacts that might be relevant to the task.
    Enhanced to detect if task might already be completed.

    Args:
        context: Execution context; "workspace_dir" and "instruction"
            are consulted when present.

    Returns:
        Dict with has_files, files, file_details, recent_files, and —
        when the instruction names files already on disk — matching_files.
    """
    import re

    workspace_info = {"has_files": False, "files": [], "file_details": {}, "recent_files": []}

    try:
        # Get workspace path
        workspace_path = self._get_workspace_path(context)

        # Scan workspace files (10-minute recency window).
        scan_result = self._scan_workspace_files(workspace_path, recent_threshold=600)

        if scan_result["files"]:
            workspace_info["has_files"] = True
            workspace_info["files"] = scan_result["files"]
            workspace_info["file_details"] = scan_result["file_details"]
            workspace_info["recent_files"] = scan_result["recent_files"]

            logger.info(f"Grounding Agent: Found {len(scan_result['files'])} existing files in workspace "
                        f"({len(scan_result['recent_files'])} recent)")

            # Check if instruction mentions specific filenames
            instruction = context.get("instruction", "")
            if instruction:
                # Look for potential file references in instruction
                potential_outputs = []
                # Match common file patterns: filename.ext, "filename", 'filename'
                file_patterns = re.findall(r'["\']?([a-zA-Z0-9_\-]+\.[a-zA-Z0-9]+)["\']?', instruction)
                for pattern in file_patterns:
                    # Only count names that actually exist in the workspace.
                    if pattern in scan_result["files"]:
                        potential_outputs.append(pattern)

                if potential_outputs:
                    workspace_info["matching_files"] = potential_outputs
                    logger.info(f"Grounding Agent: Found {len(potential_outputs)} files matching task: {potential_outputs}")

    except Exception as e:
        # Workspace inspection is advisory; never let it break execution.
        logger.debug(f"Could not check workspace artifacts: {e}")

    return workspace_info
|
| 939 |
+
|
| 940 |
+
def _build_iteration_feedback(
|
| 941 |
+
self,
|
| 942 |
+
iteration: int,
|
| 943 |
+
llm_summary: Optional[str] = None,
|
| 944 |
+
add_guidance: bool = True
|
| 945 |
+
) -> Optional[Dict[str, str]]:
|
| 946 |
+
"""
|
| 947 |
+
Build feedback message to add to next iteration.
|
| 948 |
+
"""
|
| 949 |
+
if not llm_summary:
|
| 950 |
+
return None
|
| 951 |
+
|
| 952 |
+
feedback_content = GroundingAgentPrompts.iteration_feedback(
|
| 953 |
+
iteration=iteration,
|
| 954 |
+
llm_summary=llm_summary,
|
| 955 |
+
add_guidance=add_guidance
|
| 956 |
+
)
|
| 957 |
+
|
| 958 |
+
return {
|
| 959 |
+
"role": "system",
|
| 960 |
+
"content": feedback_content
|
| 961 |
+
}
|
| 962 |
+
|
| 963 |
+
def _remove_previous_guidance(self, messages: List[Dict[str, Any]]) -> None:
|
| 964 |
+
"""
|
| 965 |
+
Remove guidance section from previous iteration feedback messages.
|
| 966 |
+
"""
|
| 967 |
+
for msg in messages:
|
| 968 |
+
if msg.get("role") == "system":
|
| 969 |
+
content = msg.get("content", "")
|
| 970 |
+
# Check if this is an iteration feedback message with guidance
|
| 971 |
+
if "## Iteration" in content and "Summary" in content and "---" in content:
|
| 972 |
+
# Remove everything from "---" onwards (the guidance part)
|
| 973 |
+
summary_only = content.split("---")[0].strip()
|
| 974 |
+
msg["content"] = summary_only
|
| 975 |
+
|
| 976 |
+
async def _generate_final_summary(
    self,
    instruction: str,
    messages: List[Dict],
    iterations: int
) -> tuple[str, bool, List[Dict]]:
    """
    Generate final summary across all iterations for reporting to upper layer.

    Tool messages and tool_calls are stripped before the summary call so
    the model sees only the conversational narrative (and the prompt
    stays within context limits).

    Returns:
        tuple[str, bool, List[Dict]]: (summary_text, success_flag, context_used)
        - summary_text: The generated summary or error message
        - success_flag: True if summary was generated successfully, False otherwise
        - context_used: The cleaned messages used for generating summary
    """
    final_summary_prompt = {
        "role": "user",
        "content": GroundingAgentPrompts.final_summary(
            instruction=instruction,
            iterations=iterations
        )
    }

    clean_messages = []
    for msg in messages:
        # Skip tool result messages
        if msg.get("role") == "tool":
            continue
        # Copy message and remove tool_calls if present
        clean_msg = msg.copy()
        if "tool_calls" in clean_msg:
            del clean_msg["tool_calls"]
        clean_messages.append(clean_msg)

    clean_messages.append(final_summary_prompt)

    # Save context for return. Deep copy so callers receive a snapshot
    # that is unaffected by any later mutation of clean_messages.
    context_for_return = copy.deepcopy(clean_messages)

    try:
        # Call LLMClient to generate final summary (without tools)
        summary_response = await self._llm_client.complete(
            messages=clean_messages,
            tools=None,
            execute_tools=False
        )

        final_summary = summary_response.get("message", {}).get("content", "")

        if final_summary:
            logger.info(f"Generated final summary: {final_summary[:200]}...")
            return final_summary, True, context_for_return
        else:
            # Empty summary still counts as success; provide a generic text.
            logger.warning("LLM returned empty final summary")
            return f"Task completed after {iterations} iteration(s). Check execution history for details.", True, context_for_return

    except Exception as e:
        # Summary failure must not mask a completed task: report it in-band.
        logger.error(f"Error generating final summary: {e}")
        return f"Task completed after {iterations} iteration(s), but failed to generate summary: {str(e)}", False, context_for_return
|
| 1035 |
+
|
| 1036 |
+
|
| 1037 |
+
async def _build_final_result(
    self,
    instruction: str,
    messages: List[Dict],
    all_tool_results: List[Dict],
    iterations: int,
    max_iterations: int,
    iteration_contexts: List[Dict] = None,
    retrieved_tools_list: List[Dict] = None,
    search_debug_info: Dict[str, Any] = None,
) -> Dict[str, Any]:
    """
    Build final execution result.

    Args:
        instruction: Original instruction
        messages: Complete conversation history (including all iteration summaries)
        all_tool_results: All tool execution results
        iterations: Number of iterations performed
        max_iterations: Maximum allowed iterations
        iteration_contexts: Context snapshots for each iteration
        retrieved_tools_list: List of tools retrieved for this task
        search_debug_info: Debug info from tool search (similarity scores, LLM selections)

    Returns:
        Result dict whose "status" is "success" when the completion
        marker was seen, else "incomplete" with a warning attached.
    """
    # Completion is signaled by the completion marker in the most recent
    # assistant message.
    is_complete = self._check_task_completion(messages)

    tool_executions = self._format_tool_executions(all_tool_results)

    result = {
        "instruction": instruction,
        "step": self.step,
        "iterations": iterations,
        "tool_executions": tool_executions,
        "messages": messages,
        "iteration_contexts": iteration_contexts or [],
        "retrieved_tools_list": retrieved_tools_list or [],
        "search_debug_info": search_debug_info,
        "active_skills": list(self._active_skill_ids),
        "keep_session": True
    }

    if is_complete:
        logger.info("Task completed with <COMPLETE> marker")
        # Use LLM's own completion response directly (no extra LLM call needed)
        # LLM already generates a summary before outputting <COMPLETE>
        last_response = self._extract_last_assistant_message(messages)
        # Remove the <COMPLETE> token from response for cleaner output
        result["response"] = last_response.replace(GroundingAgentPrompts.TASK_COMPLETE, "").strip()
        result["status"] = "success"

        # [DISABLED] Extra LLM call to generate final summary
        # final_summary, summary_success, final_summary_context = await self._generate_final_summary(
        #     instruction=instruction,
        #     messages=messages,
        #     iterations=iterations
        # )
        # result["response"] = final_summary
        # result["final_summary_context"] = final_summary_context
    else:
        result["response"] = self._extract_last_assistant_message(messages)
        result["status"] = "incomplete"
        result["warning"] = (
            f"Task reached max iterations ({max_iterations}) without completion. "
            f"This may indicate the task needs more steps or clarification."
        )

    return result
|
| 1104 |
+
|
| 1105 |
+
def _format_tool_executions(self, all_tool_results: List[Dict]) -> List[Dict]:
|
| 1106 |
+
executions = []
|
| 1107 |
+
for tr in all_tool_results:
|
| 1108 |
+
tool_result_obj = tr.get("result")
|
| 1109 |
+
tool_call = tr.get("tool_call")
|
| 1110 |
+
|
| 1111 |
+
status = "unknown"
|
| 1112 |
+
if hasattr(tool_result_obj, 'status'):
|
| 1113 |
+
status_obj = tool_result_obj.status
|
| 1114 |
+
status = getattr(status_obj, 'value', status_obj)
|
| 1115 |
+
|
| 1116 |
+
# Extract tool_name and arguments from tool_call object (litellm format)
|
| 1117 |
+
tool_name = "unknown"
|
| 1118 |
+
arguments = {}
|
| 1119 |
+
if tool_call is not None:
|
| 1120 |
+
if hasattr(tool_call, 'function'):
|
| 1121 |
+
# tool_call is an object with .function attribute
|
| 1122 |
+
tool_name = getattr(tool_call.function, 'name', 'unknown')
|
| 1123 |
+
args_raw = getattr(tool_call.function, 'arguments', '{}')
|
| 1124 |
+
if isinstance(args_raw, str):
|
| 1125 |
+
try:
|
| 1126 |
+
arguments = json.loads(args_raw) if args_raw.strip() else {}
|
| 1127 |
+
except json.JSONDecodeError:
|
| 1128 |
+
arguments = {}
|
| 1129 |
+
else:
|
| 1130 |
+
arguments = args_raw if isinstance(args_raw, dict) else {}
|
| 1131 |
+
elif isinstance(tool_call, dict):
|
| 1132 |
+
# Fallback: tool_call is a dict
|
| 1133 |
+
func = tool_call.get("function", {})
|
| 1134 |
+
tool_name = func.get("name", "unknown")
|
| 1135 |
+
args_raw = func.get("arguments", "{}")
|
| 1136 |
+
if isinstance(args_raw, str):
|
| 1137 |
+
try:
|
| 1138 |
+
arguments = json.loads(args_raw) if args_raw.strip() else {}
|
| 1139 |
+
except json.JSONDecodeError:
|
| 1140 |
+
arguments = {}
|
| 1141 |
+
else:
|
| 1142 |
+
arguments = args_raw if isinstance(args_raw, dict) else {}
|
| 1143 |
+
|
| 1144 |
+
executions.append({
|
| 1145 |
+
"tool_name": tool_name,
|
| 1146 |
+
"arguments": arguments,
|
| 1147 |
+
"backend": tr.get("backend"),
|
| 1148 |
+
"server_name": tr.get("server_name"),
|
| 1149 |
+
"status": status,
|
| 1150 |
+
"content": tool_result_obj.content if hasattr(tool_result_obj, 'content') else None,
|
| 1151 |
+
"error": tool_result_obj.error if hasattr(tool_result_obj, 'error') else None,
|
| 1152 |
+
"execution_time": tool_result_obj.execution_time if hasattr(tool_result_obj, 'execution_time') else None,
|
| 1153 |
+
"metadata": tool_result_obj.metadata if hasattr(tool_result_obj, 'metadata') else {},
|
| 1154 |
+
})
|
| 1155 |
+
return executions
|
| 1156 |
+
|
| 1157 |
+
def _check_task_completion(self, messages: List[Dict]) -> bool:
    """Return True if the most recent assistant message contains the
    completion marker.

    Only the last assistant message is inspected: earlier iterations may
    legitimately mention the marker without the task being done now.
    """
    for msg in reversed(messages):
        if msg.get("role") == "assistant":
            # Assistant messages that carry only tool calls can have
            # content=None; guard so the membership test cannot raise
            # TypeError on None.
            content = msg.get("content") or ""
            return GroundingAgentPrompts.TASK_COMPLETE in content
    return False
|
| 1163 |
+
|
| 1164 |
+
def _extract_last_assistant_message(self, messages: List[Dict]) -> str:
|
| 1165 |
+
for msg in reversed(messages):
|
| 1166 |
+
if msg.get("role") == "assistant":
|
| 1167 |
+
return msg.get("content", "")
|
| 1168 |
+
return ""
|
| 1169 |
+
|
| 1170 |
+
async def _record_agent_execution(
|
| 1171 |
+
self,
|
| 1172 |
+
result: Dict[str, Any],
|
| 1173 |
+
instruction: str
|
| 1174 |
+
) -> None:
|
| 1175 |
+
"""
|
| 1176 |
+
Record agent execution to recording manager.
|
| 1177 |
+
|
| 1178 |
+
Args:
|
| 1179 |
+
result: Execution result
|
| 1180 |
+
instruction: Original instruction
|
| 1181 |
+
"""
|
| 1182 |
+
if not self._recording_manager:
|
| 1183 |
+
return
|
| 1184 |
+
|
| 1185 |
+
# Extract tool execution summary
|
| 1186 |
+
tool_summary = []
|
| 1187 |
+
if result.get("tool_executions"):
|
| 1188 |
+
for exec_info in result["tool_executions"]:
|
| 1189 |
+
tool_summary.append({
|
| 1190 |
+
"tool": exec_info.get("tool_name", "unknown"),
|
| 1191 |
+
"backend": exec_info.get("backend", "unknown"),
|
| 1192 |
+
"status": exec_info.get("status", "unknown"),
|
| 1193 |
+
})
|
| 1194 |
+
|
| 1195 |
+
await self._recording_manager.record_agent_action(
|
| 1196 |
+
agent_name=self.name,
|
| 1197 |
+
action_type="execute",
|
| 1198 |
+
input_data={"instruction": instruction},
|
| 1199 |
+
reasoning={
|
| 1200 |
+
"response": result.get("response", ""),
|
| 1201 |
+
"tools_selected": tool_summary,
|
| 1202 |
+
},
|
| 1203 |
+
output_data={
|
| 1204 |
+
"status": result.get("status", "unknown"),
|
| 1205 |
+
"iterations": result.get("iterations", 0),
|
| 1206 |
+
"num_tool_executions": len(result.get("tool_executions", [])),
|
| 1207 |
+
},
|
| 1208 |
+
metadata={
|
| 1209 |
+
"step": self.step,
|
| 1210 |
+
"instruction": instruction,
|
| 1211 |
+
}
|
| 1212 |
+
)
|
openspace/cloud/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cloud platform integration.
|
| 2 |
+
|
| 3 |
+
Provides:
|
| 4 |
+
- ``OpenSpaceClient`` — HTTP client for the cloud API
|
| 5 |
+
- ``get_openspace_auth`` — credential resolution
|
| 6 |
+
- ``SkillSearchEngine`` — hybrid BM25 + embedding search
|
| 7 |
+
- ``generate_embedding`` — OpenAI embedding generation
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from openspace.cloud.auth import get_openspace_auth
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def __getattr__(name: str):
|
| 14 |
+
if name == "OpenSpaceClient":
|
| 15 |
+
from openspace.cloud.client import OpenSpaceClient
|
| 16 |
+
return OpenSpaceClient
|
| 17 |
+
if name == "SkillSearchEngine":
|
| 18 |
+
from openspace.cloud.search import SkillSearchEngine
|
| 19 |
+
return SkillSearchEngine
|
| 20 |
+
if name == "generate_embedding":
|
| 21 |
+
from openspace.cloud.embedding import generate_embedding
|
| 22 |
+
return generate_embedding
|
| 23 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Explicit public API surface; the heavy entries above are resolved lazily
# by the module-level __getattr__ (PEP 562), not imported eagerly.
__all__ = [
    "OpenSpaceClient",
    "get_openspace_auth",
    "SkillSearchEngine",
    "generate_embedding",
]
|
openspace/cloud/auth.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenSpace cloud platform authentication.
|
| 2 |
+
|
| 3 |
+
Resolution order for OPENSPACE_API_KEY:
|
| 4 |
+
1. ``OPENSPACE_API_KEY`` env var
|
| 5 |
+
2. Auto-detect from host agent config (MCP env block)
|
| 6 |
+
3. Empty (caller treats as "not configured").
|
| 7 |
+
|
| 8 |
+
Base URL resolution:
|
| 9 |
+
1. ``OPENSPACE_API_BASE`` env var
|
| 10 |
+
2. Default: ``https://open-space.cloud/api/v1``
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
from typing import Dict, Optional
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger("openspace.cloud")
|
| 20 |
+
|
| 21 |
+
OPENSPACE_DEFAULT_BASE = "https://open-space.cloud/api/v1"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_openspace_auth() -> tuple[Dict[str, str], str]:
    """Resolve OpenSpace credentials and base URL.

    Resolution order for the API key:
      1. ``OPENSPACE_API_KEY`` env var
      2. Host agent config (MCP env block)
      3. Not configured — ``auth_headers`` returned empty.

    Returns:
        ``(auth_headers, api_base)`` — headers dict ready for HTTP requests
        and the API base URL. If no credentials are found, ``auth_headers``
        is empty and ``api_base`` falls back to the default.

    NOTE(review): a base-URL override is only honored when a key is found in
    the *same* tier — OPENSPACE_API_BASE alone (no key) is ignored here.
    Confirm this is intentional; ``get_api_base`` resolves the base
    independently of the key.
    """
    auth_headers: Dict[str, str] = {}
    api_base = OPENSPACE_DEFAULT_BASE

    # Tier 1: environment variables.
    env_key = os.environ.get("OPENSPACE_API_KEY", "").strip()
    env_base = os.environ.get("OPENSPACE_API_BASE", "").strip()

    if env_key:
        auth_headers["X-API-Key"] = env_key
        if env_base:
            api_base = env_base.rstrip("/")
        logger.info("OpenSpace auth: using OPENSPACE_API_KEY env var")
        return auth_headers, api_base

    # Tier 2: host agent config MCP env block. Imported lazily so the common
    # env-var fast path above does not pay the host-detection import cost.
    from openspace.host_detection import read_host_mcp_env

    mcp_env = read_host_mcp_env()
    cfg_key = str(mcp_env.get("OPENSPACE_API_KEY", "")).strip()
    cfg_base = str(mcp_env.get("OPENSPACE_API_BASE", "")).strip()

    if cfg_key:
        auth_headers["X-API-Key"] = cfg_key
        if cfg_base:
            api_base = cfg_base.rstrip("/")
        logger.info("OpenSpace auth: using OPENSPACE_API_KEY from host agent MCP env config")
        return auth_headers, api_base

    # Tier 3: not configured — caller decides how to handle empty headers.
    return auth_headers, api_base
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_api_base(cli_override: Optional[str] = None) -> str:
    """Resolve the OpenSpace API base URL (for CLI scripts).

    Priority: ``cli_override`` → ``OPENSPACE_API_BASE`` env var →
    host agent config → built-in default. Trailing slashes are stripped.
    """
    if cli_override:
        return cli_override.rstrip("/")
    env_base = os.environ.get("OPENSPACE_API_BASE", "").strip()
    if env_base:
        return env_base.rstrip("/")
    # Imported lazily: only needed when neither the CLI flag nor the env var
    # resolved the base URL (avoids the host-detection import on fast paths).
    from openspace.host_detection import read_host_mcp_env

    mcp_env = read_host_mcp_env()
    cfg_base = str(mcp_env.get("OPENSPACE_API_BASE", "")).strip()
    if cfg_base:
        return cfg_base.rstrip("/")
    return OPENSPACE_DEFAULT_BASE
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_auth_headers_or_exit() -> Dict[str, str]:
    """Resolve auth headers for CLI scripts. Exits with status 1 on failure."""
    # Tier 1: environment variable.
    env_key = os.environ.get("OPENSPACE_API_KEY", "").strip()
    if env_key:
        return {"X-API-Key": env_key}

    # Tier 2: host agent config MCP env block. Imported lazily so the env-var
    # fast path above needs neither this module nor sys.
    from openspace.host_detection import read_host_mcp_env

    mcp_env = read_host_mcp_env()
    cfg_key = str(mcp_env.get("OPENSPACE_API_KEY", "")).strip()
    if cfg_key:
        return {"X-API-Key": cfg_key}

    # No key anywhere: print guidance and exit (sys only needed here).
    import sys

    print(
        "ERROR: No OPENSPACE_API_KEY configured.\n"
        "  Register at https://open-space.cloud to obtain a key, then add it to\n"
        "  your host agent config in the OpenSpace MCP env block.",
        file=sys.stderr,
    )
    sys.exit(1)
|
openspace/cloud/cli/__init__.py
ADDED
|
File without changes
|
openspace/cloud/cli/download_skill.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Download a skill from the OpenSpace cloud platform.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
openspace-download-skill --skill-id "weather__imp_abc12345" --output-dir ./skills/
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import json
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
from openspace.cloud.auth import get_api_base, get_auth_headers_or_exit
|
| 16 |
+
from openspace.cloud.client import OpenSpaceClient, CloudError
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def main() -> None:
    """CLI entry point: fetch a cloud skill record and extract it locally."""
    parser = argparse.ArgumentParser(
        prog="openspace-download-skill",
        description="Download a skill from OpenSpace's cloud community",
    )
    parser.add_argument("--skill-id", required=True, help="Cloud skill record ID")
    parser.add_argument("--output-dir", required=True, help="Target directory for extraction")
    parser.add_argument("--api-base", default=None, help="Override API base URL")
    parser.add_argument("--force", action="store_true", help="Overwrite existing skill directory")

    args = parser.parse_args()

    api_base = get_api_base(args.api_base)
    headers = get_auth_headers_or_exit()
    output_base = Path(args.output_dir).resolve()

    print(f"Fetching skill: {args.skill_id} ...", file=sys.stderr)

    try:
        client = OpenSpaceClient(headers, api_base)
        result = client.import_skill(args.skill_id, output_base)
        if result.get("status") == "already_exists" and args.force:
            # BUG FIX: --force previously had no effect — import_skill had
            # already short-circuited without extracting anything. Remove the
            # stale directory and re-import so the flag actually overwrites.
            import shutil

            shutil.rmtree(result.get("local_path", ""), ignore_errors=True)
            result = client.import_skill(args.skill_id, output_base)
    except CloudError as e:
        print(f"ERROR: {e}", file=sys.stderr)
        sys.exit(1)

    if result.get("status") == "already_exists" and not args.force:
        print(
            f"ERROR: Skill directory already exists: {result.get('local_path')}\n"
            f"  Use --force to overwrite.",
            file=sys.stderr,
        )
        sys.exit(1)

    files = result.get("files", [])
    local_path = result.get("local_path", "")
    print(f"  Extracted {len(files)} file(s) to {local_path}", file=sys.stderr)
    for f in files:
        print(f"    {f}", file=sys.stderr)

    # Machine-readable result on stdout; human status on stderr.
    print(json.dumps(result, indent=2, ensure_ascii=False))
    print(f"\nSkill downloaded to: {local_path}", file=sys.stderr)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Script entry point (the console script openspace-download-skill also calls main()).
if __name__ == "__main__":
    main()
|
openspace/cloud/cli/upload_skill.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Upload a skill to the OpenSpace cloud platform.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
openspace-upload-skill --skill-dir ./my-skill --visibility public --origin imported
|
| 6 |
+
openspace-upload-skill --skill-dir ./my-skill --visibility private --origin fixed --parent-ids "parent_id"
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import json
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
from openspace.cloud.auth import get_api_base, get_auth_headers_or_exit
|
| 17 |
+
from openspace.cloud.client import OpenSpaceClient, CloudError
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def main() -> None:
    """CLI entry point: validate args, then stage and register the skill."""
    parser = argparse.ArgumentParser(
        prog="openspace-upload-skill",
        description="Upload a skill to OpenSpace's cloud community",
    )
    parser.add_argument("--skill-dir", required=True, help="Path to skill directory (must contain SKILL.md)")
    parser.add_argument("--visibility", required=True, choices=["public", "private"])
    parser.add_argument("--origin", default="imported", choices=["imported", "captured", "derived", "fixed"])
    parser.add_argument("--parent-ids", default="", help="Comma-separated parent skill IDs")
    parser.add_argument("--tags", default="", help="Comma-separated tags")
    parser.add_argument("--created-by", default="", help="Creator display name")
    parser.add_argument("--change-summary", default="", help="Change summary text")
    parser.add_argument("--api-base", default=None, help="Override API base URL")
    parser.add_argument("--dry-run", action="store_true", help="List files without uploading")

    args = parser.parse_args()

    skill_dir = Path(args.skill_dir).resolve()
    if not skill_dir.is_dir():
        print(f"ERROR: Not a directory: {skill_dir}", file=sys.stderr)
        sys.exit(1)

    api_base = get_api_base(args.api_base)

    if args.dry_run:
        # Dry run needs no credentials: just enumerate what would be staged.
        files = OpenSpaceClient._collect_files(skill_dir)
        print(f"Dry run — would upload {len(files)} file(s):", file=sys.stderr)
        for f in files:
            print(f"  {f.relative_to(skill_dir)}", file=sys.stderr)
        sys.exit(0)

    headers = get_auth_headers_or_exit()

    # Comma-separated CLI lists → cleaned Python lists (empty entries dropped).
    parent_ids = [p.strip() for p in args.parent_ids.split(",") if p.strip()]
    tags = [t.strip() for t in args.tags.split(",") if t.strip()]

    print(f"\n{'='*60}", file=sys.stderr)
    print(f"Upload Skill: {skill_dir.name}", file=sys.stderr)
    print(f"  Visibility: {args.visibility}", file=sys.stderr)
    print(f"  Origin: {args.origin}", file=sys.stderr)
    print(f"  API Base: {api_base}", file=sys.stderr)
    print(f"{'='*60}\n", file=sys.stderr)

    try:
        client = OpenSpaceClient(headers, api_base)
        result = client.upload_skill(
            skill_dir,
            visibility=args.visibility,
            origin=args.origin,
            parent_skill_ids=parent_ids,
            tags=tags,
            created_by=args.created_by,
            change_summary=args.change_summary,
        )
    except CloudError as e:
        print(f"ERROR: {e}", file=sys.stderr)
        sys.exit(1)

    # FIX: was an f-string with no placeholders (lint F541).
    print("\nUpload complete!", file=sys.stderr)
    print(json.dumps(result, indent=2, ensure_ascii=False))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Script entry point (the console script openspace-upload-skill also calls main()).
if __name__ == "__main__":
    main()
|
openspace/cloud/client.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenSpace cloud platform HTTP client.
|
| 2 |
+
|
| 3 |
+
All methods are **synchronous** (use ``urllib``). In async contexts
|
| 4 |
+
(MCP server), wrap calls with ``asyncio.to_thread()``.
|
| 5 |
+
|
| 6 |
+
Provides both low-level HTTP operations and higher-level workflows:
|
| 7 |
+
- ``fetch_record`` / ``download_artifact`` / ``fetch_metadata``
|
| 8 |
+
- ``stage_artifact`` / ``create_record``
|
| 9 |
+
- ``upload_skill`` (stage → diff → create — full workflow)
|
| 10 |
+
- ``import_skill`` (fetch → download → extract — full workflow)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import difflib
|
| 16 |
+
import io
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
import uuid
|
| 21 |
+
import urllib.error
|
| 22 |
+
import urllib.parse
|
| 23 |
+
import urllib.request
|
| 24 |
+
import zipfile
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from typing import Any, Dict, List, Optional
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger("openspace.cloud")
|
| 29 |
+
|
| 30 |
+
SKILL_FILENAME = "SKILL.md"
|
| 31 |
+
SKILL_ID_FILENAME = ".skill_id"
|
| 32 |
+
|
| 33 |
+
_TEXT_EXTENSIONS = frozenset({
|
| 34 |
+
".md", ".txt", ".yaml", ".yml", ".json", ".py", ".sh", ".toml",
|
| 35 |
+
})
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class CloudError(Exception):
    """Raised when a cloud API call fails.

    Attributes:
        status_code: HTTP status of the failed call (0 when no response
            was received, e.g. a connection error).
        body: Raw response body text, when one was available.
    """

    def __init__(self, message: str, status_code: int = 0, body: str = ""):
        self.status_code = status_code
        self.body = body
        super().__init__(message)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class OpenSpaceClient:
|
| 48 |
+
"""HTTP client for the OpenSpace cloud API.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
auth_headers: Pre-resolved auth headers (from ``get_openspace_auth``).
|
| 52 |
+
api_base: API base URL (e.g. ``https://open-space.cloud/api/v1``).
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
_DEFAULT_UA = "OpenSpace-Client/1.0"
|
| 56 |
+
|
| 57 |
+
def __init__(self, auth_headers: Dict[str, str], api_base: str):
|
| 58 |
+
if not auth_headers:
|
| 59 |
+
raise CloudError(
|
| 60 |
+
"No OPENSPACE_API_KEY configured. "
|
| 61 |
+
"Register at https://open-space.cloud to obtain a key."
|
| 62 |
+
)
|
| 63 |
+
self._headers = {
|
| 64 |
+
"User-Agent": self._DEFAULT_UA,
|
| 65 |
+
**auth_headers,
|
| 66 |
+
}
|
| 67 |
+
self._base = api_base.rstrip("/")
|
| 68 |
+
|
| 69 |
+
def _request(
|
| 70 |
+
self,
|
| 71 |
+
method: str,
|
| 72 |
+
path: str,
|
| 73 |
+
*,
|
| 74 |
+
body: Optional[bytes] = None,
|
| 75 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
| 76 |
+
timeout: int = 30,
|
| 77 |
+
) -> tuple[int, bytes]:
|
| 78 |
+
"""Execute HTTP request. Returns ``(status_code, response_body)``."""
|
| 79 |
+
url = f"{self._base}{path}"
|
| 80 |
+
headers = {**self._headers}
|
| 81 |
+
if extra_headers:
|
| 82 |
+
headers.update(extra_headers)
|
| 83 |
+
|
| 84 |
+
req = urllib.request.Request(url, data=body, headers=headers, method=method)
|
| 85 |
+
try:
|
| 86 |
+
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
| 87 |
+
return resp.status, resp.read()
|
| 88 |
+
except urllib.error.HTTPError as e:
|
| 89 |
+
resp_body = e.read().decode("utf-8", errors="replace")
|
| 90 |
+
raise CloudError(
|
| 91 |
+
f"HTTP {e.code}: {resp_body[:500]}",
|
| 92 |
+
status_code=e.code,
|
| 93 |
+
body=resp_body,
|
| 94 |
+
)
|
| 95 |
+
except urllib.error.URLError as e:
|
| 96 |
+
raise CloudError(f"Connection failed: {e.reason}")
|
| 97 |
+
|
| 98 |
+
def _get_json(self, path: str, timeout: int = 30) -> Dict[str, Any]:
|
| 99 |
+
_, data = self._request("GET", path, timeout=timeout)
|
| 100 |
+
return json.loads(data.decode("utf-8"))
|
| 101 |
+
|
| 102 |
+
def fetch_record(self, record_id: str) -> Dict[str, Any]:
|
| 103 |
+
"""GET /records/{record_id} — fetch record metadata."""
|
| 104 |
+
return self._get_json(f"/records/{urllib.parse.quote(record_id)}")
|
| 105 |
+
|
| 106 |
+
def download_artifact(self, record_id: str) -> bytes:
|
| 107 |
+
"""GET /records/{record_id}/download — download artifact zip bytes."""
|
| 108 |
+
_, data = self._request(
|
| 109 |
+
"GET",
|
| 110 |
+
f"/records/{urllib.parse.quote(record_id)}/download",
|
| 111 |
+
timeout=120,
|
| 112 |
+
)
|
| 113 |
+
return data
|
| 114 |
+
|
| 115 |
+
def fetch_metadata(
|
| 116 |
+
self,
|
| 117 |
+
*,
|
| 118 |
+
include_embedding: bool = False,
|
| 119 |
+
limit: int = 200,
|
| 120 |
+
) -> List[Dict[str, Any]]:
|
| 121 |
+
"""GET /records/metadata — fetch all visible records with pagination."""
|
| 122 |
+
all_items: List[Dict[str, Any]] = []
|
| 123 |
+
cursor: Optional[str] = None
|
| 124 |
+
|
| 125 |
+
while True:
|
| 126 |
+
params: Dict[str, str] = {"limit": str(limit)}
|
| 127 |
+
if include_embedding:
|
| 128 |
+
params["include_embedding"] = "true"
|
| 129 |
+
if cursor:
|
| 130 |
+
params["cursor"] = cursor
|
| 131 |
+
|
| 132 |
+
path = f"/records/metadata?{urllib.parse.urlencode(params)}"
|
| 133 |
+
data = self._get_json(path, timeout=15)
|
| 134 |
+
|
| 135 |
+
all_items.extend(data.get("items", []))
|
| 136 |
+
|
| 137 |
+
if not data.get("has_more"):
|
| 138 |
+
break
|
| 139 |
+
cursor = data.get("next_cursor")
|
| 140 |
+
if not cursor:
|
| 141 |
+
break
|
| 142 |
+
|
| 143 |
+
return all_items
|
| 144 |
+
|
| 145 |
+
    def stage_artifact(self, skill_dir: Path) -> tuple[str, int]:
        """POST /artifacts/stage — upload skill files as one multipart request.

        Builds a multipart/form-data body by hand (urllib has no native
        multipart support) with one "files" part per file, then parses the
        staging response.

        Returns ``(artifact_id, file_count)``.

        Raises:
            CloudError: when the directory has no files or the response
                lacks an artifact_id.
        """
        file_paths = self._collect_files(skill_dir)
        if not file_paths:
            raise CloudError("No files found in skill directory")

        # Random boundary so file contents cannot collide with the delimiter.
        boundary = f"----OpenSpaceUpload{os.urandom(8).hex()}"
        body_parts: list[bytes] = []
        for fp in file_paths:
            # Relative path is used as the upload filename so the server can
            # reconstruct the directory layout.
            rel_path = str(fp.relative_to(skill_dir))
            body_parts.append(f"--{boundary}\r\n".encode())
            # NOTE(review): rel_path is embedded unescaped; a filename
            # containing '"' or CR/LF would corrupt this header — confirm
            # upstream constraints on skill file names.
            body_parts.append(
                f'Content-Disposition: form-data; name="files"; '
                f'filename="{rel_path}"\r\n'.encode()
            )
            ctype = "text/plain" if fp.suffix in _TEXT_EXTENSIONS else "application/octet-stream"
            body_parts.append(f"Content-Type: {ctype}\r\n\r\n".encode())
            body_parts.append(fp.read_bytes())
            body_parts.append(b"\r\n")
        # Closing boundary terminates the multipart body.
        body_parts.append(f"--{boundary}--\r\n".encode())

        _, resp_data = self._request(
            "POST",
            "/artifacts/stage",
            body=b"".join(body_parts),
            extra_headers={"Content-Type": f"multipart/form-data; boundary={boundary}"},
            timeout=60,
        )
        stage = json.loads(resp_data.decode("utf-8"))
        artifact_id = stage.get("artifact_id")
        if not artifact_id:
            raise CloudError("No artifact_id in stage response")
        file_count = stage.get("stats", {}).get("file_count", 0)
        return artifact_id, file_count
|
| 182 |
+
|
| 183 |
+
def create_record(self, payload: Dict[str, Any]) -> tuple[Dict[str, Any], int]:
|
| 184 |
+
"""POST /records — create skill record with 409 conflict handling.
|
| 185 |
+
|
| 186 |
+
Returns ``(response_data, status_code)``.
|
| 187 |
+
"""
|
| 188 |
+
body = json.dumps(payload).encode("utf-8")
|
| 189 |
+
try:
|
| 190 |
+
status, resp_data = self._request(
|
| 191 |
+
"POST",
|
| 192 |
+
"/records",
|
| 193 |
+
body=body,
|
| 194 |
+
extra_headers={"Content-Type": "application/json"},
|
| 195 |
+
)
|
| 196 |
+
return json.loads(resp_data.decode("utf-8")), status
|
| 197 |
+
except CloudError as e:
|
| 198 |
+
if e.status_code == 409:
|
| 199 |
+
return self._handle_409(e.body, payload)
|
| 200 |
+
raise
|
| 201 |
+
|
| 202 |
+
def _handle_409(
|
| 203 |
+
self, body_text: str, payload: Dict[str, Any],
|
| 204 |
+
) -> tuple[Dict[str, Any], int]:
|
| 205 |
+
"""Handle 409 conflict responses."""
|
| 206 |
+
try:
|
| 207 |
+
err_data = json.loads(body_text)
|
| 208 |
+
except json.JSONDecodeError:
|
| 209 |
+
raise CloudError(f"409 conflict: {body_text}", status_code=409, body=body_text)
|
| 210 |
+
|
| 211 |
+
err_type = err_data.get("error", "")
|
| 212 |
+
|
| 213 |
+
if err_type == "fingerprint_record_id_conflict":
|
| 214 |
+
existing_id = err_data.get("existing_record_id", "")
|
| 215 |
+
return {
|
| 216 |
+
"record_id": existing_id,
|
| 217 |
+
"status": "duplicate",
|
| 218 |
+
"existing_record_id": existing_id,
|
| 219 |
+
}, 409
|
| 220 |
+
|
| 221 |
+
if err_type == "record_id_fingerprint_conflict":
|
| 222 |
+
# Retry with a new UUID
|
| 223 |
+
name = payload.get("name", "skill")
|
| 224 |
+
payload["record_id"] = f"{name}__clo_{uuid.uuid4().hex[:8]}"
|
| 225 |
+
retry_body = json.dumps(payload).encode("utf-8")
|
| 226 |
+
status, resp_data = self._request(
|
| 227 |
+
"POST",
|
| 228 |
+
"/records",
|
| 229 |
+
body=retry_body,
|
| 230 |
+
extra_headers={"Content-Type": "application/json"},
|
| 231 |
+
)
|
| 232 |
+
return json.loads(resp_data.decode("utf-8")), status
|
| 233 |
+
|
| 234 |
+
raise CloudError(f"409 conflict: {body_text}", status_code=409, body=body_text)
|
| 235 |
+
|
| 236 |
+
    def upload_skill(
        self,
        skill_dir: Path,
        *,
        visibility: str = "public",
        origin: str = "imported",
        parent_skill_ids: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
        created_by: str = "",
        change_summary: str = "",
    ) -> Dict[str, Any]:
        """Upload a local skill to the cloud (stage → diff → create record).

        Args:
            skill_dir: Directory containing SKILL.md plus supporting files.
            visibility: "public" or "private" ("private" maps to the API's
                "group_only").
            origin: Provenance tag; validated against ``parent_skill_ids``.
            parent_skill_ids: Ancestor record IDs (required/forbidden per origin).
            tags: Free-form tags.
            created_by: Optional creator display name.
            change_summary: Optional change summary text.

        Returns a result dict with status, record_id, etc.

        Raises:
            CloudError: on missing SKILL.md, invalid origin/parent combo,
                or any API failure.
        """
        from openspace.skill_engine.skill_utils import parse_frontmatter

        skill_path = Path(skill_dir)
        skill_file = skill_path / SKILL_FILENAME
        if not skill_file.exists():
            raise CloudError(f"SKILL.md not found in {skill_dir}")

        # Name/description come from SKILL.md frontmatter, with the
        # directory name as fallback for the name.
        content = skill_file.read_text(encoding="utf-8")
        fm = parse_frontmatter(content)
        name = fm.get("name", skill_path.name)
        description = fm.get("description", "")

        if not name:
            raise CloudError("SKILL.md frontmatter missing 'name' field")

        parents = parent_skill_ids or []
        self._validate_origin_parents(origin, parents)

        # "private" is modeled as group-scoped visibility on the server.
        api_visibility = "group_only" if visibility == "private" else "public"

        # Step 1: Stage
        logger.info(f"upload_skill: staging files for '{name}'")
        artifact_id, file_count = self.stage_artifact(skill_path)
        logger.info(f"upload_skill: staged {file_count} file(s), artifact_id={artifact_id}")

        # Step 2: Content diff
        content_diff = self._compute_content_diff(skill_path, api_visibility, parents)

        # Step 3: Create record
        record_id = f"{name}__clo_{uuid.uuid4().hex[:8]}"
        payload: Dict[str, Any] = {
            "record_id": record_id,
            "artifact_id": artifact_id,
            # name/description are NOT sent — the server extracts them
            # from SKILL.md YAML frontmatter (Task 4+F4 change).
            "origin": origin,
            "visibility": api_visibility,
            "parent_skill_ids": parents,
            "tags": tags or [],
            "level": "workflow",
        }
        if created_by:
            payload["created_by"] = created_by
        if change_summary:
            payload["change_summary"] = change_summary
        if content_diff is not None:
            payload["content_diff"] = content_diff

        record_data, status_code = self.create_record(payload)
        action = "created" if status_code == 201 else "exists (idempotent)"
        # Server may have substituted a record_id (e.g. on conflict retry).
        final_record_id = record_data.get("record_id", record_id)

        logger.info(
            f"upload_skill: {name} [{final_record_id}] — {action} "
            f"(visibility={api_visibility}, origin={origin})"
        )

        # Check for duplicate status from 409 handling
        if record_data.get("status") == "duplicate":
            return {
                "status": "duplicate",
                "message": f"Same content already exists as record '{record_data.get('existing_record_id', '')}'",
                "existing_record_id": record_data.get("existing_record_id", ""),
            }

        return {
            "status": "success",
            "action": action,
            "record_id": final_record_id,
            "name": name,
            "description": description,
            "visibility": api_visibility,
            "origin": origin,
            "parent_skill_ids": parents,
            "artifact_id": artifact_id,
            "file_count": file_count,
        }
|
| 328 |
+
|
| 329 |
+
def import_skill(
|
| 330 |
+
self,
|
| 331 |
+
skill_id: str,
|
| 332 |
+
target_dir: Path,
|
| 333 |
+
) -> Dict[str, Any]:
|
| 334 |
+
"""Download a cloud skill and extract to a local directory.
|
| 335 |
+
|
| 336 |
+
Returns a result dict with status, local_path, files, etc.
|
| 337 |
+
"""
|
| 338 |
+
# 1. Fetch metadata
|
| 339 |
+
logger.info(f"import_skill: fetching metadata for {skill_id}")
|
| 340 |
+
record_data = self.fetch_record(skill_id)
|
| 341 |
+
skill_name = record_data.get("name", skill_id)
|
| 342 |
+
|
| 343 |
+
skill_dir = target_dir / skill_name
|
| 344 |
+
|
| 345 |
+
# Check if already exists locally
|
| 346 |
+
if skill_dir.exists() and (skill_dir / SKILL_FILENAME).exists():
|
| 347 |
+
return {
|
| 348 |
+
"status": "already_exists",
|
| 349 |
+
"skill_id": skill_id,
|
| 350 |
+
"name": skill_name,
|
| 351 |
+
"local_path": str(skill_dir),
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
# 2. Download artifact
|
| 355 |
+
logger.info(f"import_skill: downloading artifact for {skill_id}")
|
| 356 |
+
zip_data = self.download_artifact(skill_id)
|
| 357 |
+
|
| 358 |
+
# 3. Extract
|
| 359 |
+
skill_dir.mkdir(parents=True, exist_ok=True)
|
| 360 |
+
extracted = self._extract_zip(zip_data, skill_dir)
|
| 361 |
+
|
| 362 |
+
# 4. Write .skill_id sidecar
|
| 363 |
+
(skill_dir / SKILL_ID_FILENAME).write_text(skill_id + "\n", encoding="utf-8")
|
| 364 |
+
|
| 365 |
+
logger.info(
|
| 366 |
+
f"import_skill: {skill_name} [{skill_id}] → {skill_dir} "
|
| 367 |
+
f"({len(extracted)} files)"
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
return {
|
| 371 |
+
"status": "success",
|
| 372 |
+
"skill_id": skill_id,
|
| 373 |
+
"name": skill_name,
|
| 374 |
+
"description": record_data.get("description", ""),
|
| 375 |
+
"local_path": str(skill_dir),
|
| 376 |
+
"files": extracted,
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
@staticmethod
|
| 380 |
+
def _collect_files(skill_dir: Path) -> List[Path]:
|
| 381 |
+
"""Collect all files in skill directory (skip .skill_id sidecar)."""
|
| 382 |
+
return [
|
| 383 |
+
p for p in sorted(skill_dir.rglob("*"))
|
| 384 |
+
if p.is_file() and p.name != SKILL_ID_FILENAME
|
| 385 |
+
]
|
| 386 |
+
|
| 387 |
+
@staticmethod
|
| 388 |
+
def _collect_text_files(skill_dir: Path) -> Dict[str, str]:
|
| 389 |
+
"""Collect text files as ``{relative_path: content}``."""
|
| 390 |
+
files: Dict[str, str] = {}
|
| 391 |
+
for p in sorted(skill_dir.rglob("*")):
|
| 392 |
+
if p.is_file() and p.name != SKILL_ID_FILENAME:
|
| 393 |
+
rel = str(p.relative_to(skill_dir))
|
| 394 |
+
try:
|
| 395 |
+
files[rel] = p.read_text(encoding="utf-8")
|
| 396 |
+
except (UnicodeDecodeError, OSError):
|
| 397 |
+
pass
|
| 398 |
+
return files
|
| 399 |
+
|
| 400 |
+
@staticmethod
|
| 401 |
+
def _extract_zip(zip_data: bytes, target_dir: Path) -> List[str]:
|
| 402 |
+
"""Extract zip bytes to target directory with path traversal protection."""
|
| 403 |
+
extracted: List[str] = []
|
| 404 |
+
try:
|
| 405 |
+
with zipfile.ZipFile(io.BytesIO(zip_data)) as zf:
|
| 406 |
+
for info in zf.infolist():
|
| 407 |
+
if info.is_dir():
|
| 408 |
+
continue
|
| 409 |
+
clean_name = Path(info.filename).as_posix()
|
| 410 |
+
if clean_name.startswith("..") or clean_name.startswith("/"):
|
| 411 |
+
continue
|
| 412 |
+
target_path = target_dir / clean_name
|
| 413 |
+
target_path.parent.mkdir(parents=True, exist_ok=True)
|
| 414 |
+
target_path.write_bytes(zf.read(info))
|
| 415 |
+
extracted.append(clean_name)
|
| 416 |
+
except zipfile.BadZipFile:
|
| 417 |
+
raise CloudError("Downloaded artifact is not a valid zip file")
|
| 418 |
+
return extracted
|
| 419 |
+
|
| 420 |
+
@staticmethod
def _extract_zip_text_files(zip_data: bytes) -> Dict[str, str]:
    """Decode the UTF-8 text entries of a zip as ``{filename: content}``.

    Directories, the ``.skill_id`` sidecar, and entries that fail to
    decode are skipped. An invalid archive yields whatever was collected
    before the failure (usually an empty dict) rather than raising.
    """
    out: Dict[str, str] = {}
    try:
        with zipfile.ZipFile(io.BytesIO(zip_data)) as zf:
            for info in zf.infolist():
                if info.is_dir() or info.filename == SKILL_ID_FILENAME:
                    continue
                try:
                    out[info.filename] = zf.read(info).decode("utf-8")
                except (UnicodeDecodeError, KeyError):
                    continue
    except zipfile.BadZipFile:
        pass
    return out
|
| 436 |
+
|
| 437 |
+
@staticmethod
|
| 438 |
+
def _validate_origin_parents(origin: str, parents: List[str]) -> None:
|
| 439 |
+
if origin in ("imported", "captured") and parents:
|
| 440 |
+
raise CloudError(f"origin='{origin}' must not have parent_skill_ids")
|
| 441 |
+
if origin == "derived" and not parents:
|
| 442 |
+
raise CloudError("origin='derived' requires at least 1 parent_skill_id")
|
| 443 |
+
if origin == "fixed" and len(parents) != 1:
|
| 444 |
+
raise CloudError("origin='fixed' requires exactly 1 parent_skill_id")
|
| 445 |
+
|
| 446 |
+
def _compute_content_diff(
|
| 447 |
+
self,
|
| 448 |
+
skill_dir: Path,
|
| 449 |
+
api_visibility: str,
|
| 450 |
+
parents: List[str],
|
| 451 |
+
) -> Optional[str]:
|
| 452 |
+
"""Compute content_diff for the upload.
|
| 453 |
+
|
| 454 |
+
- public + single parent → diff vs ancestor
|
| 455 |
+
- public + no parent → add-all diff
|
| 456 |
+
- else → None
|
| 457 |
+
"""
|
| 458 |
+
if api_visibility != "public":
|
| 459 |
+
return None
|
| 460 |
+
|
| 461 |
+
cur_files = self._collect_text_files(skill_dir)
|
| 462 |
+
|
| 463 |
+
if len(parents) == 1:
|
| 464 |
+
try:
|
| 465 |
+
anc_zip = self.download_artifact(parents[0])
|
| 466 |
+
anc_files = self._extract_zip_text_files(anc_zip)
|
| 467 |
+
diff = self._unified_diff(anc_files, cur_files)
|
| 468 |
+
if diff:
|
| 469 |
+
logger.info(f"Computed diff vs ancestor {parents[0]}")
|
| 470 |
+
return diff
|
| 471 |
+
except Exception as e:
|
| 472 |
+
logger.warning(f"Diff computation failed: {e}")
|
| 473 |
+
return None
|
| 474 |
+
|
| 475 |
+
if not parents:
|
| 476 |
+
return self._unified_diff({}, cur_files)
|
| 477 |
+
|
| 478 |
+
return None # multiple parents
|
| 479 |
+
|
| 480 |
+
@staticmethod
|
| 481 |
+
def _unified_diff(old_files: Dict[str, str], new_files: Dict[str, str]) -> Optional[str]:
|
| 482 |
+
"""Compute combined unified diff between two file snapshots."""
|
| 483 |
+
all_names = sorted(set(old_files) | set(new_files))
|
| 484 |
+
parts: List[str] = []
|
| 485 |
+
for fname in all_names:
|
| 486 |
+
old = old_files.get(fname, "")
|
| 487 |
+
new = new_files.get(fname, "")
|
| 488 |
+
d = "".join(difflib.unified_diff(
|
| 489 |
+
old.splitlines(keepends=True),
|
| 490 |
+
new.splitlines(keepends=True),
|
| 491 |
+
fromfile=f"a/{fname}",
|
| 492 |
+
tofile=f"b/{fname}",
|
| 493 |
+
n=3,
|
| 494 |
+
))
|
| 495 |
+
if d:
|
| 496 |
+
parts.append(d)
|
| 497 |
+
return "\n".join(parts) if parts else None
|
openspace/cloud/embedding.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Embedding generation via OpenAI-compatible API."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
import urllib.request
|
| 10 |
+
from typing import List, Optional, Tuple
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger("openspace.cloud")
|
| 13 |
+
|
| 14 |
+
# Constants (duplicated here to avoid top-level import of skill_ranker)
|
| 15 |
+
SKILL_EMBEDDING_MODEL = "openai/text-embedding-3-small"
|
| 16 |
+
SKILL_EMBEDDING_MAX_CHARS = 12_000
|
| 17 |
+
SKILL_EMBEDDING_DIMENSIONS = 1536
|
| 18 |
+
|
| 19 |
+
_OPENROUTER_BASE = "https://openrouter.ai/api/v1"
|
| 20 |
+
_OPENAI_BASE = "https://api.openai.com/v1"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def resolve_embedding_api() -> Tuple[Optional[str], str]:
    """Resolve the API key and base URL to use for embedding requests.

    Lookup order:
        1. ``OPENROUTER_API_KEY`` → OpenRouter endpoint
        2. ``OPENAI_API_KEY`` (honouring ``OPENAI_BASE_URL``)
        3. host-agent config via ``openspace.host_detection``

    Returns:
        ``(api_key, base_url)``; *api_key* is ``None`` when no key is found.
    """
    env = os.environ

    openrouter_key = env.get("OPENROUTER_API_KEY")
    if openrouter_key:
        return openrouter_key, _OPENROUTER_BASE

    openai_base = env.get("OPENAI_BASE_URL", _OPENAI_BASE).rstrip("/")
    openai_key = env.get("OPENAI_API_KEY")
    if openai_key:
        return openai_key, openai_base

    # Best-effort host-agent lookup; any failure means "no key".
    try:
        from openspace.host_detection import get_openai_api_key
        key_from_host = get_openai_api_key()
    except Exception:
        key_from_host = None
    if key_from_host:
        return key_from_host, openai_base

    return None, _OPENAI_BASE
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine similarity of two equal-length vectors.

    Yields ``0.0`` for empty input, mismatched lengths, or a zero-norm
    operand (instead of dividing by zero).
    """
    if not a or len(a) != len(b):
        return 0.0
    dot = 0.0
    sq_a = 0.0
    sq_b = 0.0
    for x, y in zip(a, b):
        dot += x * y
        sq_a += x * x
        sq_b += y * y
    if sq_a == 0 or sq_b == 0:
        return 0.0
    return dot / (math.sqrt(sq_a) * math.sqrt(sq_b))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def build_skill_embedding_text(
    name: str,
    description: str,
    readme_body: str,
    max_chars: int = SKILL_EMBEDDING_MAX_CHARS,
) -> str:
    """Assemble the text a skill embedding is computed from.

    ``name`` and ``description`` are joined with a newline into a header;
    the header and the SKILL.md body are joined with a blank line. Empty
    parts are skipped and the result is truncated to *max_chars*.
    Unified strategy matching MCP search_skills and clawhub platform.
    """
    header_parts = [part for part in (name, description) if part]
    header = "\n".join(header_parts)
    sections = [part for part in (header, readme_body) if part]
    combined = "\n\n".join(sections)
    return combined if len(combined) <= max_chars else combined[:max_chars]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def generate_embedding(text: str, api_key: Optional[str] = None) -> Optional[List[float]]:
    """Request an embedding vector from an OpenAI-compatible endpoint.

    When *api_key* is ``None``, credentials are resolved via
    :func:`resolve_embedding_api` (``OPENROUTER_API_KEY`` →
    ``OPENAI_API_KEY`` → host-agent config). An explicit *api_key*
    overrides the key, but the base URL is still resolved from the
    environment (``OPENROUTER_API_KEY`` presence picks the endpoint).

    This is a **synchronous** call (urllib); wrap with
    ``asyncio.to_thread()`` in async contexts.

    Args:
        text: The text to embed.
        api_key: Optional explicit bearer token.

    Returns:
        The embedding vector, or ``None`` when no key is available or the
        request fails (failures are logged at WARNING).
    """
    resolved_key, base_url = resolve_embedding_api()
    key = api_key if api_key is not None else resolved_key
    if not key:
        return None

    payload = {"model": SKILL_EMBEDDING_MODEL, "input": text}
    request = urllib.request.Request(
        f"{base_url}/embeddings",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        },
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=15) as resp:
            parsed = json.loads(resp.read().decode("utf-8"))
            return parsed.get("data", [{}])[0].get("embedding")
    except Exception as e:
        logger.warning("Embedding generation failed: %s", e)
        return None
|
openspace/cloud/search.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hybrid skill search engine (BM25 + embedding + lexical boost).
|
| 2 |
+
|
| 3 |
+
Implements the search pipeline:
|
| 4 |
+
Phase 1: BM25 rough-rank over all candidates
|
| 5 |
+
Phase 2: Vector scoring (embedding cosine similarity)
|
| 6 |
+
Phase 3: Hybrid score = vector_score + lexical_boost
|
| 7 |
+
Phase 4: Deduplication + limit
|
| 8 |
+
|
| 9 |
+
Used by MCP ``search_skills`` tool, ``retrieve_skill`` agent tool,
|
| 10 |
+
and potentially other search interfaces.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import asyncio
|
| 16 |
+
import logging
|
| 17 |
+
import re
|
| 18 |
+
from typing import Any, Dict, List, Optional
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger("openspace.cloud")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _check_safety(text: str) -> list[str]:
    """Run the skill-safety checker on *text*.

    Imported lazily so that loading this module does not pull in
    skill_engine.
    """
    from openspace.skill_engine import skill_utils

    return skill_utils.check_skill_safety(text)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _is_safe(flags: list[str]) -> bool:
    """Return whether the given safety *flags* leave the skill usable
    (lazy import, same reasoning as :func:`_check_safety`)."""
    from openspace.skill_engine import skill_utils

    return skill_utils.is_skill_safe(flags)
|
| 32 |
+
|
| 33 |
+
_WORD_RE = re.compile(r"[a-z0-9]+")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _tokenize(value: str) -> list[str]:
|
| 37 |
+
return _WORD_RE.findall(value.lower()) if value else []
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _lexical_boost(query_tokens: list[str], name: str, slug: str) -> float:
    """Score how well the query tokens cover the slug and name lexically.

    For each field: if every query token equals some field token the
    larger bonus applies; else if every query token is a prefix of some
    field token the smaller bonus applies. Slug matches outweigh name
    matches (1.4/0.8 vs 1.1/0.6); the two bonuses are additive.
    """

    def field_bonus(field_tokens: list[str], exact_w: float, prefix_w: float) -> float:
        # Empty field → no bonus at all.
        if not field_tokens:
            return 0.0
        if all(qt in field_tokens for qt in query_tokens):
            return exact_w
        if all(any(ft.startswith(qt) for ft in field_tokens) for qt in query_tokens):
            return prefix_w
        return 0.0

    return (
        field_bonus(_tokenize(slug), 1.4, 0.8)
        + field_bonus(_tokenize(name), 1.1, 0.6)
    )
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class SkillSearchEngine:
    """Hybrid BM25 + embedding search over skill candidates.

    Pipeline:
        1. BM25 rough-rank (keeps roughly ``3 * limit`` candidates)
        2. vector score (embedding cosine similarity)
        3. hybrid score = vector score + lexical boost
        4. dedup by name + truncate to ``limit``

    Usage::

        engine = SkillSearchEngine()
        results = engine.search(
            query="weather forecast",
            candidates=candidates,
            query_embedding=[...],  # optional
            limit=20,
        )
    """

    def search(
        self,
        query: str,
        candidates: List[Dict[str, Any]],
        *,
        query_embedding: Optional[List[float]] = None,
        limit: int = 20,
    ) -> List[Dict[str, Any]]:
        """Run the full pipeline and return ranked results.

        Each candidate dict should carry at least ``skill_id``, ``name``,
        ``description``, a ``source`` tag ("openspace-local" | "cloud"),
        and optionally ``_embedding`` (pre-computed vector).

        Args:
            query: Search query text.
            candidates: Candidate dicts to rank.
            query_embedding: Pre-computed query embedding, if available.
            limit: Maximum number of results.

        Returns:
            Result dicts sorted by descending score.
        """
        stripped = query.strip()
        if not stripped or not candidates:
            return []

        tokens = _tokenize(stripped)
        if not tokens:
            return []

        shortlist = self._bm25_phase(stripped, candidates, limit)      # Phase 1
        ranked = self._score_phase(shortlist, tokens, query_embedding)  # Phase 2+3
        return self._dedup_and_limit(ranked, limit)                     # Phase 4

    def _bm25_phase(
        self,
        query: str,
        candidates: List[Dict[str, Any]],
        limit: int,
    ) -> List[Dict[str, Any]]:
        """BM25 rough-rank; falls back to all candidates on an empty result."""
        from openspace.skill_engine.skill_ranker import SkillRanker, SkillCandidate

        wrapped = [
            SkillCandidate(
                skill_id=cand.get("skill_id", ""),
                name=cand.get("name", ""),
                description=cand.get("description", ""),
                body="",
                metadata=cand,
            )
            for cand in candidates
        ]
        keep = min(limit * 3, len(candidates))
        top = SkillRanker(enable_cache=True).bm25_only(query, wrapped, top_k=keep)

        keep_ids = {sc.skill_id for sc in top}
        survivors = [cand for cand in candidates if cand.get("skill_id") in keep_ids]
        return survivors or candidates

    def _score_phase(
        self,
        candidates: List[Dict[str, Any]],
        query_tokens: list[str],
        query_embedding: Optional[List[float]],
    ) -> List[Dict[str, Any]]:
        """Build result rows with hybrid score = cosine similarity + lexical boost."""
        from openspace.cloud.embedding import cosine_similarity

        rows: List[Dict[str, Any]] = []
        for cand in candidates:
            cand_name = cand.get("name", "")
            # Slug = first "__" segment of the id, with ":" normalized to "-".
            cand_slug = cand.get("skill_id", cand_name).split("__")[0].replace(":", "-")

            vec = 0.0
            if query_embedding:
                emb = cand.get("_embedding")
                if emb and isinstance(emb, list):
                    vec = cosine_similarity(query_embedding, emb)

            total = vec + _lexical_boost(query_tokens, cand_name, cand_slug)

            row: Dict[str, Any] = {
                "skill_id": cand.get("skill_id", ""),
                "name": cand_name,
                "description": cand.get("description", ""),
                "source": cand.get("source", ""),
                "score": round(total, 4),
            }
            if vec > 0:
                row["vector_score"] = round(vec, 4)
            # Carry over optional metadata fields when present and truthy.
            for extra in (
                "path", "visibility", "created_by", "origin",
                "tags", "quality", "safety_flags",
            ):
                if cand.get(extra):
                    row[extra] = cand[extra]
            rows.append(row)

        rows.sort(key=lambda row: row["score"], reverse=True)
        return rows

    @staticmethod
    def _dedup_and_limit(
        scored: List[Dict[str, Any]],
        limit: int,
    ) -> List[Dict[str, Any]]:
        """Drop later duplicates (by name) and keep at most *limit* rows."""
        unique: List[Dict[str, Any]] = []
        seen_names: set = set()
        for row in scored:
            row_name = row["name"]
            if row_name in seen_names:
                continue
            seen_names.add(row_name)
            unique.append(row)
        return unique[:limit]
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def build_local_candidates(
    skills: list,
    store: Any = None,
) -> List[Dict[str, Any]]:
    """Turn SkillRegistry ``SkillMeta`` entries into search candidate dicts.

    Reads each skill's SKILL.md (YAML front matter stripped), builds the
    embedding text, drops skills that fail the safety check, and — when a
    *store* is given — enriches surviving entries with quality statistics.

    Args:
        skills: ``SkillMeta`` list from ``registry.list_skills()``.
        store: Optional ``SkillStore`` instance for quality enrichment.

    Returns:
        Candidate dicts ready for ``SkillSearchEngine.search()``.
    """
    from openspace.cloud.embedding import build_skill_embedding_text

    front_matter = re.compile(r"^---\n.*?\n---\n?", re.DOTALL)
    out: List[Dict[str, Any]] = []
    for skill in skills:
        # Extract SKILL.md body; unreadable files degrade to an empty body.
        body = ""
        try:
            raw_text = skill.path.read_text(encoding="utf-8")
        except Exception:
            raw_text = None
        if raw_text is not None:
            fm = front_matter.match(raw_text)
            body = raw_text[fm.end():].strip() if fm else raw_text

        emb_text = build_skill_embedding_text(skill.name, skill.description, body)

        flags = _check_safety(emb_text)
        if not _is_safe(flags):
            logger.info(f"BLOCKED local skill {skill.skill_id} — {flags}")
            continue

        out.append({
            "skill_id": skill.skill_id,
            "name": skill.name,
            "description": skill.description,
            "source": "openspace-local",
            "path": str(skill.path),
            "is_local": True,
            "safety_flags": flags if flags else None,
            "_embedding_text": emb_text,
        })

    # Best-effort quality enrichment — a failing store never breaks search.
    if store and out:
        try:
            records = store.load_all(active_only=True)
            for cand in out:
                record = records.get(cand["skill_id"])
                if record is None:
                    continue
                cand["quality"] = {
                    "total_selections": record.total_selections,
                    "completion_rate": round(record.completion_rate, 3),
                    "effective_rate": round(record.effective_rate, 3),
                }
                cand["tags"] = record.tags
        except Exception as e:
            logger.warning(f"Quality lookup failed: {e}")

    return out
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def build_cloud_candidates(
    items: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Turn cloud metadata items into search candidate dicts.

    Runs the safety check over name/description/tags and carries any
    platform-computed embedding through as ``_embedding``.

    Args:
        items: Items from ``OpenSpaceClient.fetch_metadata()``.

    Returns:
        Safety-filtered candidate dicts.
    """
    out: List[Dict[str, Any]] = []
    for item in items:
        item_name = item.get("name", "")
        item_desc = item.get("description", "")
        item_tags = item.get("tags", [])

        flags = _check_safety(f"{item_name}\n{item_desc}\n{' '.join(item_tags)}")
        if not _is_safe(flags):
            continue

        cand: Dict[str, Any] = {
            "skill_id": item.get("record_id", ""),
            "name": item_name,
            "description": item_desc,
            "source": "cloud",
            "visibility": item.get("visibility", "public"),
            "is_local": False,
            "created_by": item.get("created_by", ""),
            "origin": item.get("origin", ""),
            "tags": item_tags,
            "safety_flags": flags if flags else None,
        }
        # Reuse the platform's pre-computed embedding when present.
        emb = item.get("embedding")
        if emb and isinstance(emb, list):
            cand["_embedding"] = emb
        out.append(cand)

    return out
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
async def hybrid_search_skills(
    query: str,
    local_skills: Optional[list] = None,
    store: Any = None,
    source: str = "all",
    limit: int = 20,
) -> List[Dict[str, Any]]:
    """Shared cloud+local skill search with graceful fallback.

    Builds candidates, generates embeddings, and runs
    ``SkillSearchEngine``. Cloud is attempted when *source* includes it;
    failures are logged and skipped so the caller always gets local
    results at minimum.

    Args:
        query: Free-text search query.
        local_skills: ``SkillMeta`` list (from ``registry.list_skills()``),
            or ``None`` to skip local candidates.
        store: Optional ``SkillStore`` for quality enrichment.
        source: ``"all"`` | ``"local"`` | ``"cloud"``.
        limit: Maximum results.

    Returns:
        Ranked result dicts (same format as ``SkillSearchEngine.search()``).
    """
    q = query.strip()
    if not q:
        return []

    # Lazy import after the early return: an empty query never touches
    # the embedding module.
    from openspace.cloud.embedding import generate_embedding

    candidates: List[Dict[str, Any]] = []

    if source in ("all", "local") and local_skills:
        candidates.extend(build_local_candidates(local_skills, store))

    if source in ("all", "cloud"):
        try:
            from openspace.cloud.auth import get_openspace_auth
            from openspace.cloud.client import OpenSpaceClient

            auth_headers, api_base = get_openspace_auth()
            if auth_headers:
                client = OpenSpaceClient(auth_headers, api_base)
                # Only ask the platform for embeddings when we can embed
                # the query locally too.
                try:
                    from openspace.cloud.embedding import resolve_embedding_api
                    has_emb = bool(resolve_embedding_api()[0])
                except Exception:
                    has_emb = False
                items = await asyncio.to_thread(
                    client.fetch_metadata, include_embedding=has_emb, limit=200,
                )
                candidates.extend(build_cloud_candidates(items))
        except Exception as e:
            logger.warning(f"hybrid_search_skills: cloud unavailable: {e}")

    if not candidates:
        return []

    # Query embedding is best-effort (key/URL resolved inside
    # generate_embedding); any failure leaves the search lexical-only.
    query_embedding: Optional[List[float]] = None
    try:
        query_embedding = await asyncio.to_thread(generate_embedding, q)
        if query_embedding:
            for c in candidates:
                if not c.get("_embedding") and c.get("_embedding_text"):
                    emb = await asyncio.to_thread(
                        generate_embedding, c["_embedding_text"],
                    )
                    if emb:
                        c["_embedding"] = emb
    except Exception:
        pass

    engine = SkillSearchEngine()
    return engine.search(q, candidates, query_embedding=query_embedding, limit=limit)
|
| 393 |
+
|
openspace/config/README.md
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🔧 Configuration Guide
|
| 2 |
+
|
| 3 |
+
All configuration applies to both Path A (host agent) and Path B (standalone). Configure once before the first run.
|
| 4 |
+
|
| 5 |
+
## 1. API Keys (`.env`)
|
| 6 |
+
|
| 7 |
+
> [!NOTE]
|
| 8 |
+
> Create a `.env` file and add your API keys (refer to [`.env.example`](../../.env.example)). When used via host agent (Path A), LLM keys are auto-detected from your agent's config — `.env` is mainly needed for standalone mode.
|
| 9 |
+
|
| 10 |
+
## 2. Environment Variables
|
| 11 |
+
|
| 12 |
+
Set via `.env`, MCP config `env` block, or system environment. OpenSpace reads these at startup.
|
| 13 |
+
|
| 14 |
+
| Variable | Required | Description |
|
| 15 |
+
|----------|----------|-------------|
|
| 16 |
+
| `OPENSPACE_HOST_SKILL_DIRS` | Path A only | Your agent's skill directories (comma-separated). Auto-registered on startup. |
|
| 17 |
+
| `OPENSPACE_WORKSPACE` | Recommended | OpenSpace project root. Used for recording logs and workspace resolution. |
|
| 18 |
+
| `OPENSPACE_API_KEY` | No | Cloud API key (`sk-xxx`). Register at https://open-space.cloud. |
|
| 19 |
+
| `OPENSPACE_MODEL` | No | LLM model override (default: auto-detected or `openrouter/anthropic/claude-sonnet-4.5`). |
|
| 20 |
+
| `OPENSPACE_MAX_ITERATIONS` | No | Max agent iterations per task (default: `20`). |
|
| 21 |
+
| `OPENSPACE_BACKEND_SCOPE` | No | Enabled backends, comma-separated (default: all — `shell,gui,mcp,web,system`). |
|
| 22 |
+
|
| 23 |
+
### Advanced env overrides (rarely needed)
|
| 24 |
+
|
| 25 |
+
| Variable | Description |
|
| 26 |
+
|----------|-------------|
|
| 27 |
+
| `OPENSPACE_LLM_API_KEY` | LLM API key (auto-detected from host agent in Path A) |
|
| 28 |
+
| `OPENSPACE_LLM_API_BASE` | LLM API base URL |
|
| 29 |
+
| `OPENSPACE_LLM_EXTRA_HEADERS` | Extra HTTP headers for LLM requests (JSON string) |
|
| 30 |
+
| `OPENSPACE_LLM_CONFIG` | Arbitrary litellm kwargs (JSON string) |
|
| 31 |
+
| `OPENSPACE_API_BASE` | Cloud API base URL (default `https://open-space.cloud/api/v1`) |
|
| 32 |
+
| `OPENSPACE_CONFIG_PATH` | Custom grounding config JSON (deep-merged with defaults) |
|
| 33 |
+
| `OPENSPACE_SHELL_CONDA_ENV` | Conda environment for shell backend |
|
| 34 |
+
| `OPENSPACE_SHELL_WORKING_DIR` | Working directory for shell backend |
|
| 35 |
+
| `OPENSPACE_MCP_SERVERS_JSON` | MCP server definitions (JSON string, merged into `mcpServers`) |
|
| 36 |
+
| `OPENSPACE_ENABLE_RECORDING` | Record execution traces (default: `true`) |
|
| 37 |
+
| `OPENSPACE_LOG_LEVEL` | `DEBUG` / `INFO` / `WARNING` / `ERROR` |
|
| 38 |
+
|
| 39 |
+
## 3. MCP Servers (`config_mcp.json`)
|
| 40 |
+
|
| 41 |
+
Register external MCP servers that OpenSpace connects to as a **client** (e.g. GitHub, Slack, databases):
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
cp openspace/config/config_mcp.json.example openspace/config/config_mcp.json
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
```json
|
| 48 |
+
{
|
| 49 |
+
"mcpServers": {
|
| 50 |
+
"github": {
|
| 51 |
+
"command": "npx",
|
| 52 |
+
"args": ["-y", "@modelcontextprotocol/server-github"],
|
| 53 |
+
"env": { "GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_TOKEN}" }
|
| 54 |
+
}
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
## 4. Execution Mode: Local vs Server
|
| 60 |
+
|
| 61 |
+
Shell and GUI backends support two execution modes, set via `"mode"` in `config_grounding.json`:
|
| 62 |
+
|
| 63 |
+
| | Local Mode (`"local"`, default) | Server Mode (`"server"`) |
|
| 64 |
+
|---|---|---|
|
| 65 |
+
| **Setup** | Zero config | Start `local_server` first |
|
| 66 |
+
| **Use case** | Same-machine development | Remote VMs, sandboxing, multi-machine |
|
| 67 |
+
| **How** | `asyncio.subprocess` in-process | HTTP → Flask → subprocess |
|
| 68 |
+
|
| 69 |
+
> [!TIP]
|
| 70 |
+
> **Use local mode** for most use cases. For server mode setup (how to enable, platform-specific deps, remote VM control), see [`../local_server/README.md`](../local_server/README.md).
|
| 71 |
+
|
| 72 |
+
## 5. Config Files (`openspace/config/`)
|
| 73 |
+
|
| 74 |
+
Layered system — later files override earlier ones:
|
| 75 |
+
|
| 76 |
+
| File | Purpose |
|
| 77 |
+
|------|---------|
|
| 78 |
+
| `config_grounding.json` | Backend settings, smart tool retrieval, tool quality, skill discovery |
|
| 79 |
+
| `config_agents.json` | Agent definitions, backend scope, max iterations |
|
| 80 |
+
| `config_mcp.json` | MCP servers OpenSpace connects to as a client |
|
| 81 |
+
| `config_security.json` | Security policies, blocked commands, sandboxing |
|
| 82 |
+
| `config_dev.json` | Dev overrides — copy from `config_dev.json.example` (highest priority) |
|
| 83 |
+
|
| 84 |
+
### Agent config (`config_agents.json`)
|
| 85 |
+
|
| 86 |
+
```json
|
| 87 |
+
{ "agents": [{ "name": "GroundingAgent", "backend_scope": ["shell", "mcp", "web"], "max_iterations": 30 }] }
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
| Field | Description | Default |
|
| 91 |
+
|-------|-------------|---------|
|
| 92 |
+
| `backend_scope` | Enabled backends | `["gui", "shell", "mcp", "system", "web"]` |
|
| 93 |
+
| `max_iterations` | Max execution cycles | `20` |
|
| 94 |
+
| `visual_analysis_timeout` | Timeout for visual analysis (seconds) | `30.0` |
|
| 95 |
+
|
| 96 |
+
### Backend & tool config (`config_grounding.json`)
|
| 97 |
+
|
| 98 |
+
| Section | Key Fields | Description |
|
| 99 |
+
|---------|-----------|-------------|
|
| 100 |
+
| `shell` | `mode`, `timeout`, `conda_env`, `working_dir` | `"local"` (default) or `"server"`, command timeout (default: `60`s) |
|
| 101 |
+
| `gui` | `mode`, `timeout`, `driver_type`, `screenshot_on_error` | Local/server mode, automation driver (default: `pyautogui`) |
|
| 102 |
+
| `mcp` | `timeout`, `sandbox`, `eager_sessions` | Request timeout (`30`s), E2B sandbox, lazy/eager server init |
|
| 103 |
+
| `tool_search` | `search_mode`, `max_tools`, `enable_llm_filter` | `"hybrid"` (semantic + LLM), max tools to return (`40`), embedding cache |
|
| 104 |
+
| `tool_quality` | `enabled`, `enable_persistence`, `evolve_interval` | Quality tracking, self-evolution every N calls (default: `5`) |
|
| 105 |
+
| `skills` | `enabled`, `skill_dirs`, `max_select` | Directories to scan, max skills injected per task (default: `2`) |
|
| 106 |
+
|
| 107 |
+
### Security config (`config_security.json`)
|
| 108 |
+
|
| 109 |
+
| Field | Description | Default |
|
| 110 |
+
|-------|-------------|---------|
|
| 111 |
+
| `allow_shell_commands` | Enable shell execution | `true` |
|
| 112 |
+
| `blocked_commands` | Platform-specific blacklists (common/linux/darwin/windows) | `rm -rf`, `shutdown`, `dd`, etc. |
|
| 113 |
+
| `sandbox_enabled` | Enable sandboxing for all operations | `false` |
|
| 114 |
+
| Per-backend overrides | Shell, MCP, GUI, Web each have independent security policies | Inherit global |
|
| 115 |
+
|
openspace/config/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .grounding import *
from .loader import *
from .constants import *
from .utils import *
from . import constants

# FIX: SessionConfig and SecurityPolicy are listed in __all__ below but are
# NOT part of .grounding's __all__, so the star import above never binds them;
# `from openspace.config import SessionConfig` (or `import *`) would then fail
# with AttributeError. Import them explicitly — they are bound as regular
# module attributes of .grounding (re-exported from grounding.core.types).
from .grounding import SessionConfig, SecurityPolicy

# Public API of the config package: grounding models, loader helpers, JSON
# utilities, plus every name re-exported from .constants.
__all__ = [
    # Grounding Config
    "BackendConfig",
    "ShellConfig",
    "WebConfig",
    "MCPConfig",
    "GUIConfig",
    "ToolSearchConfig",
    "SessionConfig",
    "SecurityPolicy",
    "GroundingConfig",

    # Loader
    "CONFIG_DIR",
    "load_config",
    "get_config",
    "reset_config",
    "save_config",
    "load_agents_config",
    "get_agent_config",

    # Utils
    "get_config_value",
    "load_json_file",
    "save_json_file",
] + constants.__all__
|
openspace/config/config_agents.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"agents": [
|
| 3 |
+
{
|
| 4 |
+
"name": "GroundingAgent",
|
| 5 |
+
"class_name": "GroundingAgent",
|
| 6 |
+
"backend_scope": ["shell", "mcp", "system"],
|
| 7 |
+
"max_iterations": 30,
|
| 8 |
+
"visual_analysis_timeout": 60.0
|
| 9 |
+
}
|
| 10 |
+
]
|
| 11 |
+
}
|
openspace/config/config_dev.json.example
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"comment": "[Optional] Loading grounding.json → security.json → dev.json (dev.json overrides the former ones)",
|
| 3 |
+
|
| 4 |
+
"debug": true,
|
| 5 |
+
"log_level": "DEBUG",
|
| 6 |
+
|
| 7 |
+
"security_policies": {
|
| 8 |
+
"global": {
|
| 9 |
+
"blocked_commands": []
|
| 10 |
+
}
|
| 11 |
+
}
|
| 12 |
+
}
|
openspace/config/config_grounding.json
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"shell": {
|
| 3 |
+
"mode": "local",
|
| 4 |
+
"timeout": 60,
|
| 5 |
+
"max_retries": 3,
|
| 6 |
+
"retry_interval": 3.0,
|
| 7 |
+
"default_shell": "/bin/bash",
|
| 8 |
+
"working_dir": null,
|
| 9 |
+
"env": {},
|
| 10 |
+
"conda_env": null,
|
| 11 |
+
"default_port": 5000
|
| 12 |
+
},
|
| 13 |
+
"mcp": {
|
| 14 |
+
"timeout": 30,
|
| 15 |
+
"max_retries": 3,
|
| 16 |
+
"retry_interval": 2.0,
|
| 17 |
+
"sandbox": false,
|
| 18 |
+
"auto_initialize": true,
|
| 19 |
+
"eager_sessions": false,
|
| 20 |
+
"sse_read_timeout": 300.0,
|
| 21 |
+
"check_dependencies": true,
|
| 22 |
+
"auto_install": true
|
| 23 |
+
},
|
| 24 |
+
"gui": {
|
| 25 |
+
"mode": "local",
|
| 26 |
+
"timeout": 90,
|
| 27 |
+
"max_retries": 3,
|
| 28 |
+
"retry_interval": 5.0,
|
| 29 |
+
"driver_type": "pyautogui",
|
| 30 |
+
"failsafe": false,
|
| 31 |
+
"screenshot_on_error": true,
|
| 32 |
+
"pkgs_prefix": "import pyautogui; import time; pyautogui.FAILSAFE = {failsafe}; {command}"
|
| 33 |
+
},
|
| 34 |
+
"tool_search": {
|
| 35 |
+
"embedding_model": "BAAI/bge-small-en-v1.5",
|
| 36 |
+
"max_tools": 40,
|
| 37 |
+
"search_mode": "hybrid",
|
| 38 |
+
"enable_llm_filter": true,
|
| 39 |
+
"llm_filter_threshold": 50,
|
| 40 |
+
"enable_cache_persistence": true,
|
| 41 |
+
"cache_dir": null
|
| 42 |
+
},
|
| 43 |
+
"tool_quality": {
|
| 44 |
+
"enabled": true,
|
| 45 |
+
"enable_persistence": true,
|
| 46 |
+
"cache_dir": null,
|
| 47 |
+
"auto_evaluate_descriptions": true,
|
| 48 |
+
"enable_quality_ranking": true,
|
| 49 |
+
"evolve_interval": 5
|
| 50 |
+
},
|
| 51 |
+
"skills": {
|
| 52 |
+
"enabled": true,
|
| 53 |
+
"skill_dirs": [],
|
| 54 |
+
"max_select": 2
|
| 55 |
+
},
|
| 56 |
+
|
| 57 |
+
"tool_cache_ttl": 600,
|
| 58 |
+
"tool_cache_maxsize": 500,
|
| 59 |
+
|
| 60 |
+
"debug": false,
|
| 61 |
+
"log_level": "INFO",
|
| 62 |
+
"enabled_backends": [
|
| 63 |
+
{
|
| 64 |
+
"name": "shell",
|
| 65 |
+
"provider_cls": "openspace.grounding.backends.shell.ShellProvider"
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"name": "web",
|
| 69 |
+
"provider_cls": "openspace.grounding.backends.web.WebProvider"
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
"name": "mcp",
|
| 73 |
+
"provider_cls": "openspace.grounding.backends.mcp.MCPProvider"
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"name": "gui",
|
| 77 |
+
"provider_cls": "openspace.grounding.backends.gui.GUIProvider"
|
| 78 |
+
}
|
| 79 |
+
],
|
| 80 |
+
|
| 81 |
+
"_comment_system_backend": "Note: 'system' backend is automatically registered and always available. It provides meta-level tools for querying system state. Do not add it to enabled_backends as it requires special initialization."
|
| 82 |
+
}
|
openspace/config/config_mcp.json.example
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"mcpServers": {
|
| 3 |
+
"github": {
|
| 4 |
+
"command": "npx",
|
| 5 |
+
"args": ["-y", "@modelcontextprotocol/server-github"],
|
| 6 |
+
"env": {
|
| 7 |
+
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_TOKEN}"
|
| 8 |
+
}
|
| 9 |
+
}
|
| 10 |
+
}
|
| 11 |
+
}
|
openspace/config/config_security.json
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"security_policies": {
|
| 3 |
+
"global": {
|
| 4 |
+
"allow_shell_commands": true,
|
| 5 |
+
"allow_network_access": true,
|
| 6 |
+
"allow_file_access": true,
|
| 7 |
+
"blocked_commands": {
|
| 8 |
+
"common": ["rm", "-rf", "shutdown", "reboot", "poweroff", "halt"],
|
| 9 |
+
"linux": ["mkfs", "dd", "iptables", "systemctl", "init", "kill", "-9", "pkill"],
|
| 10 |
+
"darwin": ["diskutil", "dd", "pfctl", "launchctl", "killall"],
|
| 11 |
+
"windows": ["del", "format", "rd", "rmdir", "/s", "/q", "taskkill", "/f"]
|
| 12 |
+
},
|
| 13 |
+
"sandbox_enabled": false
|
| 14 |
+
},
|
| 15 |
+
"backend": {
|
| 16 |
+
"shell": {
|
| 17 |
+
"allow_shell_commands": true,
|
| 18 |
+
"allow_file_access": true,
|
| 19 |
+
"blocked_commands": {
|
| 20 |
+
"common": ["rm", "-rf", "shutdown", "reboot", "poweroff", "halt"],
|
| 21 |
+
"linux": [
|
| 22 |
+
"mkfs", "mkfs.ext4", "mkfs.xfs",
|
| 23 |
+
"dd",
|
| 24 |
+
"iptables", "ip6tables", "nftables",
|
| 25 |
+
"systemctl", "service",
|
| 26 |
+
"fdisk", "parted", "gdisk",
|
| 27 |
+
"mount", "umount",
|
| 28 |
+
"chmod", "777",
|
| 29 |
+
"chown", "root",
|
| 30 |
+
"passwd",
|
| 31 |
+
"useradd", "userdel", "usermod",
|
| 32 |
+
"kill", "-9", "pkill", "killall"
|
| 33 |
+
],
|
| 34 |
+
"darwin": [
|
| 35 |
+
"diskutil",
|
| 36 |
+
"dd",
|
| 37 |
+
"pfctl",
|
| 38 |
+
"launchctl",
|
| 39 |
+
"dscl",
|
| 40 |
+
"chmod", "777",
|
| 41 |
+
"chown", "root",
|
| 42 |
+
"passwd",
|
| 43 |
+
"killall",
|
| 44 |
+
"pmset"
|
| 45 |
+
],
|
| 46 |
+
"windows": [
|
| 47 |
+
"del", "erase",
|
| 48 |
+
"format",
|
| 49 |
+
"rd", "rmdir", "/s", "/q",
|
| 50 |
+
"diskpart",
|
| 51 |
+
"reg", "delete",
|
| 52 |
+
"net", "user",
|
| 53 |
+
"taskkill", "/f",
|
| 54 |
+
"wmic"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
"sandbox_enabled": false
|
| 58 |
+
},
|
| 59 |
+
"mcp": {
|
| 60 |
+
"sandbox_enabled": false
|
| 61 |
+
},
|
| 62 |
+
"web": {
|
| 63 |
+
"allow_network_access": true,
|
| 64 |
+
"allowed_domains": []
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
}
|
openspace/config/constants.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path

# File names of the JSON configuration files bundled in this package
# (resolved against loader.CONFIG_DIR).
CONFIG_GROUNDING = "config_grounding.json"
CONFIG_SECURITY = "config_security.json"
CONFIG_MCP = "config_mcp.json"
CONFIG_DEV = "config_dev.json"
CONFIG_AGENTS = "config_agents.json"

# Accepted values for the `log_level` configuration field (see
# GroundingConfig.validate_log_level).
LOG_LEVELS = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]

# Project root directory (OpenSpace/)
# constants.py -> config/ -> openspace/ -> project root (three parents up).
PROJECT_ROOT = Path(__file__).parent.parent.parent


__all__ = [
    "CONFIG_GROUNDING",
    "CONFIG_SECURITY",
    "CONFIG_MCP",
    "CONFIG_DEV",
    "CONFIG_AGENTS",
    "LOG_LEVELS",
    "PROJECT_ROOT",
]
|
openspace/config/grounding.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional, Any, List, Literal
# Support both pydantic v2 and v1: when v2's `field_validator` is not
# importable, alias v1's `validator` under the same name so the decorator
# usage below works on either major version.
try:
    from pydantic import BaseModel, Field, field_validator
    PYDANTIC_V2 = True
except ImportError:
    from pydantic import BaseModel, Field, validator as field_validator
    PYDANTIC_V2 = False

from openspace.grounding.core.types import (
    SessionConfig,
    SecurityPolicy,
    BackendType
)
from .constants import LOG_LEVELS
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ConfigMixin:
    """Mixin giving config objects a uniform, safe read accessor.

    Works whether the concrete object behaves like a mapping or exposes
    its settings as attributes (e.g. a Pydantic model).
    """

    def get_value(self, key: str, default=None):
        """Return the value stored under *key*, or *default* when absent.

        Args:
            key: Configuration key (dict key or attribute name).
            default: Fallback returned when the key is not present.
        """
        return self.get(key, default) if isinstance(self, dict) else getattr(self, key, default)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class BackendConfig(BaseModel, ConfigMixin):
    """Base backend configuration.

    Shared by every backend config (shell/web/mcp/gui/system); subclasses add
    backend-specific fields on top of these three. ConfigMixin supplies the
    uniform get_value() accessor.
    """
    # Field constraints: timeout bounded to 1-300 seconds, retries to 0-10.
    enabled: bool = Field(True, description="Whether the backend is enabled")
    timeout: int = Field(30, ge=1, le=300, description="Timeout in seconds")
    max_retries: int = Field(3, ge=0, le=10, description="Maximum retry attempts")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ShellConfig(BackendConfig):
    """
    Shell backend configuration

    Attributes:
        enabled: Whether shell backend is enabled
        mode: Execution mode - "local" runs scripts in-process via subprocess,
              "server" connects to a running local_server via HTTP
        timeout: Default timeout for shell operations (seconds)
        max_retries: Maximum number of retry attempts for failed operations
        retry_interval: Wait time between retries (seconds)
        default_shell: Path to default shell executable
        working_dir: Default working directory for bash scripts
        env: Default environment variables for shell operations
        conda_env: Conda environment name to activate before execution (optional)
        default_port: Default port for shell server connection (only used in server mode)
    """
    mode: Literal["local", "server"] = Field("local", description="Execution mode: 'local' (in-process subprocess) or 'server' (HTTP local_server)")
    retry_interval: float = Field(3.0, ge=0.1, le=60.0, description="Wait time between retries in seconds")
    default_shell: str = Field("/bin/bash", description="Default shell path")
    working_dir: Optional[str] = Field(None, description="Default working directory for bash scripts")
    env: Dict[str, str] = Field(default_factory=dict, description="Default environment variables")
    conda_env: Optional[str] = Field(None, description="Conda environment name to activate (e.g., 'myenv')")
    default_port: int = Field(5000, ge=1, le=65535, description="Default port for shell server")
    use_clawwork_productivity: bool = Field(
        False,
        description="If True and livebench is installed, add ClawWork productivity tools (search_web, read_webpage, create_file, read_file, execute_code_sandbox, create_video) for fair comparison with ClawWork."
    )
    productivity_date: str = Field(
        "default",
        description="Date segment for productivity sandbox paths (e.g. 'default' or 'YYYY-MM-DD'). Used when use_clawwork_productivity is True."
    )

    # NOTE: `field_validator` is aliased to pydantic v1's `validator` when v2
    # is unavailable (see the module-level import fallback).
    @field_validator('default_shell')
    @classmethod
    def validate_shell(cls, v):
        # Reject empty or non-string shell paths up front.
        if not v or not isinstance(v, str):
            raise ValueError("Shell path must be a non-empty string")
        return v

    @field_validator('working_dir')
    @classmethod
    def validate_working_dir(cls, v):
        # None is allowed (no explicit working dir); any other value must be a string.
        if v is not None and not isinstance(v, str):
            raise ValueError("Working directory must be a string")
        return v
|
| 87 |
+
|
| 88 |
+
class WebConfig(BackendConfig):
    """
    Web backend configuration - AI Deep Research

    Attributes:
        enabled: Whether web backend is enabled
        timeout: Default timeout for web operations (seconds)
        max_retries: Maximum number of retry attempts

    Note:
        All web-specific parameters (API key, base URL) are loaded from
        environment variables or use default values in WebSession:
        - OPENROUTER_API_KEY: API key for deep research (required)
        - Deep research base URL defaults to "https://openrouter.ai/api/v1"
    """
    # Intentionally no extra fields: web-specific settings come from the
    # environment (see docstring), so the BackendConfig base fields suffice.
    pass
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class MCPConfig(BackendConfig):
    """MCP backend configuration.

    Note: `servers` is normally populated by the config loader from the
    "mcpServers" section of config_mcp.json (or a custom config), not set here.
    """
    sandbox: bool = Field(False, description="Whether to enable sandbox")
    auto_initialize: bool = Field(True, description="Whether to auto initialize")
    eager_sessions: bool = Field(False, description="Whether to eagerly create sessions for all servers on initialization")
    retry_interval: float = Field(2.0, ge=0.1, le=60.0, description="Wait time between retries in seconds")
    servers: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="MCP servers configuration, loaded from config_mcp.json")
    sse_read_timeout: float = Field(300.0, ge=1.0, le=3600.0, description="SSE read timeout in seconds for HTTP/Sandbox connectors")
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class GUIConfig(BackendConfig):
    """
    GUI backend configuration

    Attributes:
        mode: Execution mode - "local" runs GUI operations in-process,
              "server" connects to a running local_server via HTTP
        retry_interval: Wait time between retries (seconds)
        driver_type: Automation driver name (default "pyautogui")
        failsafe: Whether pyautogui's failsafe abort is enabled
        screenshot_on_error: Whether to capture a screenshot when an operation fails
        pkgs_prefix: Python snippet prepended before GUI commands
    """
    mode: Literal["local", "server"] = Field("local", description="Execution mode: 'local' (in-process) or 'server' (HTTP local_server)")
    retry_interval: float = Field(5.0, ge=0.1, le=60.0, description="Wait time between retries in seconds")
    driver_type: str = Field("pyautogui", description="GUI driver type")
    failsafe: bool = Field(False, description="Whether to enable pyautogui failsafe mode")
    screenshot_on_error: bool = Field(True, description="Whether to capture screenshot on error")
    # {failsafe} and {command} are placeholders substituted by the GUI backend
    # before execution — presumably via str.format; confirm at the call site.
    pkgs_prefix: str = Field(
        "import pyautogui; import time; pyautogui.FAILSAFE = {failsafe}; {command}",
        description="Python command prefix for pyautogui setup"
    )
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class ToolSearchConfig(BaseModel):
    """Tool search and ranking configuration.

    NOTE: the in-code default for max_tools is 20 while the shipped
    config_grounding.json sets 40; the JSON value takes effect once loaded.
    """
    embedding_model: str = Field(
        "BAAI/bge-small-en-v1.5",
        description="Embedding model name for semantic search"
    )
    max_tools: int = Field(
        20,
        ge=1,
        le=1000,
        description="Maximum number of tools to return from search"
    )
    search_mode: str = Field(
        "hybrid",
        description="Default search mode: semantic, keyword, or hybrid"
    )
    enable_llm_filter: bool = Field(
        True,
        description="Whether to use LLM for backend/server filtering"
    )
    llm_filter_threshold: int = Field(
        50,
        ge=1,
        le=1000,
        description="Only apply LLM filter when tool count exceeds this threshold"
    )
    enable_cache_persistence: bool = Field(
        False,
        description="Whether to persist embeddings to disk"
    )
    cache_dir: Optional[str] = Field(
        None,
        description="Directory for embedding cache. None means use default <project_root>/.openspace/embedding_cache"
    )

    @field_validator('search_mode')
    @classmethod
    def validate_search_mode(cls, v):
        # Normalizes to lowercase so "Hybrid"/"SEMANTIC" etc. are accepted.
        valid_modes = ['semantic', 'keyword', 'hybrid']
        if v.lower() not in valid_modes:
            raise ValueError(f"Search mode must be one of {valid_modes}, got: {v}")
        return v.lower()
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class ToolQualityConfig(BaseModel):
    """Tool quality tracking configuration.

    Governs quality scoring of tools and the periodic self-evolution cycle
    (triggered every `evolve_interval` tool executions).
    """
    enabled: bool = Field(
        True,
        description="Whether to enable tool quality tracking"
    )
    enable_persistence: bool = Field(
        True,
        description="Whether to persist quality data to disk"
    )
    cache_dir: Optional[str] = Field(
        None,
        description="Directory for quality cache. None means use default <project_root>/.openspace/tool_quality"
    )
    auto_evaluate_descriptions: bool = Field(
        True,
        description="Whether to automatically evaluate tool descriptions using LLM"
    )
    enable_quality_ranking: bool = Field(
        True,
        description="Whether to incorporate quality scores in tool ranking"
    )
    evolve_interval: int = Field(
        5,
        ge=1,
        le=100,
        description="Trigger quality evolution every N tool executions"
    )
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class SkillConfig(BaseModel):
    """Skill engine configuration

    Controls how skills are discovered, selected and injected.
    Built-in skills (``openspace/skills/``) are always auto-discovered.
    """
    enabled: bool = Field(True, description="Enable skill matching and injection")
    # These directories are scanned in addition to the always-included
    # built-in openspace/skills/ directory.
    skill_dirs: List[str] = Field(
        default_factory=list,
        description="Extra skill directories. Built-in openspace/skills/ is always included."
    )
    max_select: int = Field(
        2, ge=1, le=20,
        description="Maximum number of skills to inject per task"
    )
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class GroundingConfig(BaseModel):
    """
    Main configuration for Grounding module.

    Contains configuration for all grounding backends and grounding-level settings.
    Note: Local server connection uses defaults or environment variables (LOCAL_SERVER_URL).
    """
    # Backend configurations
    shell: ShellConfig = Field(default_factory=ShellConfig)
    web: WebConfig = Field(default_factory=WebConfig)
    mcp: MCPConfig = Field(default_factory=MCPConfig)
    gui: GUIConfig = Field(default_factory=GUIConfig)
    # The "system" backend has no dedicated config class; base fields suffice.
    system: BackendConfig = Field(default_factory=BackendConfig)

    # Grounding-level settings
    tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig)
    tool_quality: ToolQualityConfig = Field(default_factory=ToolQualityConfig)
    skills: SkillConfig = Field(default_factory=SkillConfig)

    enabled_backends: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of enabled backends, each item: {'name': str, 'provider_cls': str}"
    )

    # Template applied when a session is created without explicit settings.
    session_defaults: SessionConfig = Field(
        default_factory=lambda: SessionConfig(
            session_name="",
            backend_type=BackendType.SHELL,
            timeout=30,
            auto_reconnect=True,
            health_check_interval=30
        )
    )

    tool_cache_ttl: int = Field(
        300,
        ge=1,
        le=3600,
        description="Tool cache time-to-live in seconds"
    )
    tool_cache_maxsize: int = Field(
        300,
        ge=1,
        le=10000,
        description="Maximum number of tool cache entries"
    )

    debug: bool = Field(False, description="Debug mode")
    log_level: str = Field("INFO", description="Log level")
    # Raw security policy data (from config_security.json); merged per-backend
    # by get_security_policy().
    security_policies: Dict[str, Any] = Field(default_factory=dict)

    @field_validator('log_level')
    @classmethod
    def validate_log_level(cls, v):
        # Normalize to uppercase so "info"/"Info" etc. are accepted.
        if v.upper() not in LOG_LEVELS:
            raise ValueError(f"Log level must be one of {LOG_LEVELS}, got: {v}")
        return v.upper()

    def get_backend_config(self, backend_type: str) -> BackendConfig:
        """Get configuration for specified backend.

        Falls back to a default BackendConfig (with a warning) when the name
        does not match any attribute of this model.
        """
        name = backend_type.lower()
        if not hasattr(self, name):
            # Import locally to avoid a module-level dependency cycle with
            # openspace.utils — TODO confirm that is the reason for the local import.
            from openspace.utils.logging import Logger
            logger = Logger.get_logger(__name__)
            logger.warning(f"Unknown backend type: {backend_type}")
            return BackendConfig()
        return getattr(self, name)

    def get_security_policy(self, backend_type: str) -> SecurityPolicy:
        """Build the effective SecurityPolicy for a backend.

        Merges the "global" policy with the backend-specific overrides under
        "backend"; backend keys win on conflict.
        """
        global_policy = self.security_policies.get("global", {})
        backend_policy = self.security_policies.get("backend", {}).get(backend_type.lower(), {})
        merged_policy = {**global_policy, **backend_policy}
        return SecurityPolicy.from_dict(merged_policy)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
# Names exported by `from .grounding import *`.
__all__ = [
    "BackendConfig",
    "ShellConfig",
    "WebConfig",
    "MCPConfig",
    "GUIConfig",
    "ToolSearchConfig",
    "ToolQualityConfig",
    "SkillConfig",
    "GroundingConfig",
]
|
openspace/config/loader.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Union, Iterable, Dict, Any, Optional
|
| 4 |
+
|
| 5 |
+
from .grounding import GroundingConfig
|
| 6 |
+
from .constants import (
|
| 7 |
+
CONFIG_GROUNDING,
|
| 8 |
+
CONFIG_SECURITY,
|
| 9 |
+
CONFIG_DEV,
|
| 10 |
+
CONFIG_MCP,
|
| 11 |
+
CONFIG_AGENTS
|
| 12 |
+
)
|
| 13 |
+
from openspace.utils.logging import Logger
|
| 14 |
+
from .utils import load_json_file, save_json_file as save_json
|
| 15 |
+
|
| 16 |
+
logger = Logger.get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Directory containing the bundled JSON config files (openspace/config/).
CONFIG_DIR = Path(__file__).parent

# Global configuration singleton
_config: GroundingConfig | None = None
_config_lock = threading.RLock()  # Use RLock to support recursive locking
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _deep_merge_dict(base: dict, update: dict) -> dict:
|
| 27 |
+
"""Deep merge two dictionaries, update's values will override base's values"""
|
| 28 |
+
result = base.copy()
|
| 29 |
+
for key, value in update.items():
|
| 30 |
+
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
| 31 |
+
result[key] = _deep_merge_dict(result[key], value)
|
| 32 |
+
else:
|
| 33 |
+
result[key] = value
|
| 34 |
+
return result
|
| 35 |
+
|
| 36 |
+
def _load_json_file(path: Path) -> Dict[str, Any]:
    """Read one JSON configuration file, returning {} when missing or unreadable.

    Thin wrapper around the generic load_json_file that adds the logging and
    never-raise semantics the global config loader relies on.
    """
    if not path.exists():
        logger.debug(f"Configuration file does not exist, skipping: {path}")
        return {}

    try:
        parsed = load_json_file(path)
    except Exception as e:
        logger.warning(f"Failed to load configuration file {path}: {e}")
        return {}
    logger.info(f"Loaded configuration file: {path}")
    return parsed
|
| 52 |
+
|
| 53 |
+
def _load_multiple_files(paths: Iterable[Path]) -> Dict[str, Any]:
    """Deep-merge the contents of several JSON config files (later files win)."""
    combined: Dict[str, Any] = {}
    for config_path in paths:
        loaded = _load_json_file(config_path)
        if loaded:
            combined = _deep_merge_dict(combined, loaded)
    return combined
|
| 61 |
+
|
| 62 |
+
def load_config(*config_paths: Union[str, Path]) -> GroundingConfig:
    """
    Load configuration files.

    Merges the given JSON files (or, when none are given, the default
    grounding → security → dev files, later files overriding earlier ones)
    into a single GroundingConfig, stores it as the module-level singleton,
    and adjusts the logger according to the loaded debug/log_level settings.

    Args:
        config_paths: Optional explicit configuration file paths.

    Returns:
        The validated GroundingConfig; a default instance if validation fails.
    """
    global _config

    with _config_lock:
        if config_paths:
            paths = [Path(p) for p in config_paths]
        else:
            paths = [
                CONFIG_DIR / CONFIG_GROUNDING,
                CONFIG_DIR / CONFIG_SECURITY,
                CONFIG_DIR / CONFIG_DEV,  # Optional: development environment configuration
            ]

        # Load and merge configuration
        raw_data = _load_multiple_files(paths)

        # Load MCP configuration (separate processing)
        # Check if mcpServers already provided in merged custom configs
        has_custom_mcp_servers = "mcpServers" in raw_data

        if has_custom_mcp_servers:
            # Use mcpServers from custom config: move the top-level "mcpServers"
            # section under mcp.servers where MCPConfig expects it.
            if "mcp" not in raw_data:
                raw_data["mcp"] = {}
            raw_data["mcp"]["servers"] = raw_data.pop("mcpServers")
            logger.debug(f"Using custom MCP servers from provided config ({len(raw_data['mcp']['servers'])} servers)")
        else:
            # Load default MCP servers from config_mcp.json
            mcp_data = _load_json_file(CONFIG_DIR / CONFIG_MCP)
            if mcp_data and "mcpServers" in mcp_data:
                if "mcp" not in raw_data:
                    raw_data["mcp"] = {}
                raw_data["mcp"]["servers"] = mcp_data["mcpServers"]
                logger.debug(f"Loaded MCP servers from default config_mcp.json ({len(raw_data['mcp']['servers'])} servers)")

        # Validate and create configuration object; fall back to the built-in
        # defaults rather than raising, so startup never fails on a bad file.
        try:
            _config = GroundingConfig.model_validate(raw_data)
        except Exception as e:
            logger.error(f"Validation failed, using default configuration: {e}")
            _config = GroundingConfig()

        # Adjust log level according to configuration (debug takes precedence).
        if _config.debug:
            Logger.set_debug(2)
        elif _config.log_level:
            try:
                Logger.configure(level=_config.log_level)
            except Exception as e:
                logger.warning(f"Failed to set log level {_config.log_level}: {e}")

        return _config
|
| 117 |
+
|
| 118 |
+
def get_config() -> GroundingConfig:
    """
    Get global configuration instance, loading it on first use.

    Usage:
        - Get configuration in Provider: get_config().get_backend_config('shell')
        - Get security policy in Tool: get_config().get_security_policy('shell')
    """
    global _config

    # Double-checked locking: skip the lock on the common already-loaded path,
    # then re-check under the lock so concurrent first calls load only once.
    if _config is None:
        with _config_lock:
            if _config is None:
                load_config()

    return _config
|
| 134 |
+
|
| 135 |
+
def reset_config() -> None:
    """Reset configuration (for testing).

    Clears the module-level singleton so the next get_config()/load_config()
    call re-reads the configuration files.
    """
    global _config
    with _config_lock:
        _config = None
|
| 140 |
+
|
| 141 |
+
def save_config(config: GroundingConfig, path: Union[str, Path]) -> None:
    """Serialize *config* to JSON at *path* (parent directories created as needed)."""
    save_json(config.model_dump(), path)
    logger.info(f"Configuration saved to: {path}")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def load_agents_config() -> Dict[str, Any]:
    """Load config_agents.json from the package config dir; {} when missing/invalid."""
    agents_config_path = CONFIG_DIR / CONFIG_AGENTS
    return _load_json_file(agents_config_path)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def get_agent_config(agent_name: str) -> Optional[Dict[str, Any]]:
    """
    Get the configuration of the specified agent.

    Returns the first entry under the "agents" list whose "name" matches
    *agent_name*, or None (with a warning) when absent.
    """
    agents_config = load_agents_config()

    if "agents" not in agents_config:
        logger.warning(f"No 'agents' key found in {CONFIG_AGENTS}")
        return None

    # First matching entry wins, mirroring a linear scan of the list.
    matched = next(
        (
            entry
            for entry in agents_config.get("agents", [])
            if entry.get("name") == agent_name
        ),
        None,
    )
    if matched is not None:
        return matched

    logger.warning(f"Agent '{agent_name}' not found in {CONFIG_AGENTS}")
    return None
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
# Public API of the configuration module.
__all__ = [
    "CONFIG_DIR",
    "load_config",
    "get_config",
    "reset_config",
    "save_config",
    "load_agents_config",
    "get_agent_config"
]
|
openspace/config/utils.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_config_value(config: Any, key: str, default=None):
    """Read *key* from a config object, dict-style or attribute-style.

    Dicts are read with ``.get``; any other object falls back to
    ``getattr``. *default* is returned when the key/attribute is missing.
    """
    if not isinstance(config, dict):
        return getattr(config, key, default)
    return config.get(key, default)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def load_json_file(filepath: str | Path) -> dict[str, Any]:
|
| 14 |
+
filepath = Path(filepath) if isinstance(filepath, str) else filepath
|
| 15 |
+
|
| 16 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 17 |
+
return json.load(f)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def save_json_file(data: dict[str, Any], filepath: str | Path, indent: int = 2) -> None:
|
| 21 |
+
filepath = Path(filepath) if isinstance(filepath, str) else filepath
|
| 22 |
+
|
| 23 |
+
# Ensure directory exists
|
| 24 |
+
filepath.parent.mkdir(parents=True, exist_ok=True)
|
| 25 |
+
|
| 26 |
+
with open(filepath, 'w', encoding='utf-8') as f:
|
| 27 |
+
json.dump(data, f, indent=indent, ensure_ascii=False)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Public API of this small JSON/config utility module.
__all__ = ["get_config_value", "load_json_file", "save_json_file"]
|
openspace/dashboard_server.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, Iterable, List, Optional
|
| 8 |
+
|
| 9 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 10 |
+
|
| 11 |
+
from flask import Flask, abort, jsonify, send_from_directory, url_for, request
|
| 12 |
+
|
| 13 |
+
from openspace.recording.action_recorder import analyze_agent_actions, load_agent_actions
|
| 14 |
+
from openspace.recording.utils import load_recording_session
|
| 15 |
+
from openspace.skill_engine import SkillStore
|
| 16 |
+
from openspace.skill_engine.types import SkillRecord
|
| 17 |
+
|
| 18 |
+
# URL prefix shared by every JSON API endpoint.
API_PREFIX = "/api/v1"
# Built frontend bundle; served as static files when present.
FRONTEND_DIST_DIR = PROJECT_ROOT / "frontend" / "dist"
# Roots scanned (recursively) for recorded workflow run directories.
WORKFLOW_ROOTS = [
    PROJECT_ROOT / "logs" / "recordings",
    PROJECT_ROOT / "logs" / "trajectories",
    PROJECT_ROOT / "gdpval_bench" / "results",
]

# Static description of the execution pipeline; exposed verbatim by /overview.
PIPELINE_STAGES = [
    {
        "id": "initialize",
        "title": "Initialize",
        "description": "Load LLM, grounding backends, recording, registry, analyzer, and evolver.",
    },
    {
        "id": "select-skills",
        "title": "Skill Selection",
        "description": "Select candidate skills and write selection metadata before execution.",
    },
    {
        "id": "phase-1-skill",
        "title": "Skill Phase",
        "description": "Run the task with injected skill context whenever matching skills exist.",
    },
    {
        "id": "phase-2-fallback",
        "title": "Tool Fallback",
        "description": "Fallback to tool-only execution when the skill-guided phase fails or no skills match.",
    },
    {
        "id": "analysis",
        "title": "Execution Analysis",
        "description": "Persist metadata, trajectory, and post-run execution judgments.",
    },
    {
        "id": "evolution",
        "title": "Skill Evolution",
        "description": "Trigger fix / derived / captured evolution and periodic quality checks.",
    },
]

# Lazily-created process-wide SkillStore singleton (see _get_store()).
_STORE: SkillStore | None = None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def create_app() -> Flask:
    """Build and configure the dashboard Flask application.

    Registers: an optional bearer-token check (OPENSPACE_API_KEY), the
    JSON API under API_PREFIX, a per-workflow artifact file server, and
    an SPA catch-all that serves the built frontend when available.
    """
    app = Flask(__name__, static_folder=None)

    @app.before_request
    def check_api_key():
        """Reject requests lacking the configured bearer token, if one is set."""
        # Allow preflight requests (CORS)
        if request.method == "OPTIONS":
            return

        # No OPENSPACE_API_KEY in the environment means auth is disabled.
        expected_key = os.environ.get("OPENSPACE_API_KEY")
        if expected_key:
            auth_header = request.headers.get("Authorization")
            if not auth_header or auth_header != f"Bearer {expected_key}":
                abort(401, description="Unauthorized: Invalid or missing API Key")

    @app.route(f"{API_PREFIX}/health", methods=["GET"])
    def health() -> Any:
        """Liveness/diagnostics endpoint: paths, DB presence, workflow count."""
        workflows = _discover_workflow_dirs()
        store = _get_store()
        return jsonify(
            {
                "status": "ok",
                "project_root": str(PROJECT_ROOT),
                "db_path": str(store.db_path),
                "db_exists": store.db_path.exists(),
                "frontend_dist_exists": FRONTEND_DIST_DIR.exists(),
                "workflow_roots": [str(path) for path in WORKFLOW_ROOTS],
                "workflow_count": len(workflows),
            }
        )

    @app.route(f"{API_PREFIX}/overview", methods=["GET"])
    def overview() -> Any:
        """Aggregate landing-page payload: health, pipeline, skills, workflows."""
        store = _get_store()
        skills = list(store.load_all(active_only=False).values())
        workflows = [_build_workflow_summary(path) for path in _discover_workflow_dirs()]
        top_skills = _sort_skills(skills, sort_key="score")[:5]
        recent_skills = _sort_skills(skills, sort_key="updated")[:5]
        # Mean skill score (0-100); 0.0 when there are no skills yet.
        average_score = round(
            sum(_skill_score(record) for record in skills) / len(skills), 1
        ) if skills else 0.0
        # Mean workflow success rate expressed as a percentage.
        average_workflow_success = round(
            (sum((item.get("success_rate") or 0.0) for item in workflows) / len(workflows)) * 100,
            1,
        ) if workflows else 0.0

        return jsonify(
            {
                "health": {
                    "status": "ok",
                    "db_path": str(store.db_path),
                    "workflow_count": len(workflows),
                    "frontend_dist_exists": FRONTEND_DIST_DIR.exists(),
                },
                "pipeline": PIPELINE_STAGES,
                "skills": {
                    "summary": _build_skill_stats(store, skills),
                    "average_score": average_score,
                    "top": [_serialize_skill(item) for item in top_skills],
                    "recent": [_serialize_skill(item) for item in recent_skills],
                },
                "workflows": {
                    "total": len(workflows),
                    "average_success_rate": average_workflow_success,
                    "recent": workflows[:5],
                },
            }
        )

    @app.route(f"{API_PREFIX}/skills", methods=["GET"])
    def list_skills() -> Any:
        """List skills with optional text filter, sorting, and limit."""
        store = _get_store()
        active_only = _bool_arg("active_only", True)
        limit = _int_arg("limit", 100)
        sort_key = (_str_arg("sort", "score") or "score").lower()
        skills = list(store.load_all(active_only=active_only).values())
        query = (_str_arg("query", "") or "").strip().lower()
        if query:
            # Case-insensitive substring match across name/id/description/tags.
            skills = [
                record
                for record in skills
                if query in record.name.lower()
                or query in record.skill_id.lower()
                or query in record.description.lower()
                or any(query in tag.lower() for tag in record.tags)
            ]
        items = [_serialize_skill(item) for item in _sort_skills(skills, sort_key=sort_key)[:limit]]
        return jsonify({"items": items, "count": len(items), "active_only": active_only})

    @app.route(f"{API_PREFIX}/skills/stats", methods=["GET"])
    def skill_stats() -> Any:
        """Aggregate skill statistics across all (including inactive) skills."""
        store = _get_store()
        skills = list(store.load_all(active_only=False).values())
        return jsonify(_build_skill_stats(store, skills))

    @app.route(f"{API_PREFIX}/skills/<skill_id>", methods=["GET"])
    def skill_detail(skill_id: str) -> Any:
        """Full detail for one skill: record, lineage graph, analyses, source."""
        store = _get_store()
        record = store.load_record(skill_id)
        if not record:
            abort(404, description=f"Unknown skill_id: {skill_id}")

        detail = _serialize_skill(record, include_recent_analyses=True)
        detail["lineage_graph"] = _build_lineage_payload(skill_id, store)
        detail["recent_analyses"] = [analysis.to_dict() for analysis in store.load_analyses(skill_id=skill_id, limit=10)]
        detail["source"] = _load_skill_source(record)
        return jsonify(detail)

    @app.route(f"{API_PREFIX}/skills/<skill_id>/lineage", methods=["GET"])
    def skill_lineage(skill_id: str) -> Any:
        """Lineage graph (nodes/edges) for one skill."""
        store = _get_store()
        if not store.load_record(skill_id):
            abort(404, description=f"Unknown skill_id: {skill_id}")
        return jsonify(_build_lineage_payload(skill_id, store))

    @app.route(f"{API_PREFIX}/skills/<skill_id>/source", methods=["GET"])
    def skill_source(skill_id: str) -> Any:
        """Raw source file contents for one skill."""
        store = _get_store()
        record = store.load_record(skill_id)
        if not record:
            abort(404, description=f"Unknown skill_id: {skill_id}")
        return jsonify(_load_skill_source(record))

    @app.route(f"{API_PREFIX}/workflows", methods=["GET"])
    def list_workflows() -> Any:
        """List discovered workflow run summaries, newest first."""
        items = [_build_workflow_summary(path) for path in _discover_workflow_dirs()]
        return jsonify({"items": items, "count": len(items)})

    @app.route(f"{API_PREFIX}/workflows/<workflow_id>", methods=["GET"])
    def workflow_detail(workflow_id: str) -> Any:
        """Full detail for one workflow: metadata, trajectory, timeline, artifacts."""
        workflow_dir = _get_workflow_dir(workflow_id)
        if not workflow_dir:
            abort(404, description=f"Unknown workflow: {workflow_id}")

        session = load_recording_session(str(workflow_dir))
        actions = load_agent_actions(str(workflow_dir))
        metadata = session.get("metadata") or {}
        trajectory = session.get("trajectory") or []
        plans = session.get("plans") or []
        decisions = session.get("decisions") or []
        action_stats = analyze_agent_actions(actions)

        # Attach a servable URL to each trajectory step that has a screenshot.
        enriched_trajectory = []
        for step in trajectory:
            step_copy = dict(step)
            screenshot_rel = step_copy.get("screenshot")
            if screenshot_rel:
                step_copy["screenshot_url"] = url_for(
                    "workflow_artifact",
                    workflow_id=workflow_id,
                    artifact_path=screenshot_rel,
                )
            enriched_trajectory.append(step_copy)

        timeline = _build_timeline(actions, enriched_trajectory)
        artifacts = _build_workflow_artifacts(workflow_dir, workflow_id, metadata)

        return jsonify(
            {
                **_build_workflow_summary(workflow_dir),
                "metadata": metadata,
                "statistics": session.get("statistics") or {},
                "trajectory": enriched_trajectory,
                "plans": plans,
                "decisions": decisions,
                "agent_actions": actions,
                "agent_statistics": action_stats,
                "timeline": timeline,
                "artifacts": artifacts,
            }
        )

    @app.route(f"{API_PREFIX}/workflows/<workflow_id>/artifacts/<path:artifact_path>", methods=["GET"])
    def workflow_artifact(workflow_id: str, artifact_path: str) -> Any:
        """Serve a file from inside a workflow directory (screenshots, video)."""
        workflow_dir = _get_workflow_dir(workflow_id)
        if not workflow_dir:
            abort(404, description=f"Unknown workflow: {workflow_id}")

        # Path-traversal guard: resolved target must stay inside workflow_dir.
        target = (workflow_dir / artifact_path).resolve()
        root = workflow_dir.resolve()
        if root not in target.parents and target != root:
            abort(404)
        if not target.exists() or not target.is_file():
            abort(404)
        return send_from_directory(str(target.parent), target.name)

    @app.route("/", defaults={"path": ""})
    @app.route("/<path:path>")
    def serve_frontend(path: str) -> Any:
        """SPA catch-all: serve built frontend files, falling back to index.html."""
        if path.startswith("api/"):
            abort(404)

        if FRONTEND_DIST_DIR.exists():
            requested = FRONTEND_DIST_DIR / path if path else FRONTEND_DIST_DIR / "index.html"
            if path and requested.exists() and requested.is_file():
                return send_from_directory(str(FRONTEND_DIST_DIR), path)
            # Unknown client-side route: let the SPA router handle it.
            return send_from_directory(str(FRONTEND_DIST_DIR), "index.html")

        return jsonify(
            {
                "message": "OpenSpace dashboard API is running.",
                "frontend": "Build frontend/ first or run the Vite dev server.",
            }
        )

    return app
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _get_store() -> SkillStore:
    """Return the module-wide SkillStore, creating it on first use.

    NOTE(review): creation is not lock-protected; presumably duplicate
    construction under concurrent first access is harmless — confirm.
    """
    global _STORE
    if _STORE is None:
        _STORE = SkillStore()
    return _STORE
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _bool_arg(name: str, default: bool) -> bool:
    """Read query parameter *name* as a boolean.

    Missing parameter yields *default*; otherwise anything except the
    literal strings "0"/"false"/"no"/"off" (case-insensitive) is True.
    """
    # `request` is already imported at module level; the previous
    # function-local `from flask import request` was redundant.
    raw = request.args.get(name)
    if raw is None:
        return default
    return raw.lower() not in {"0", "false", "no", "off"}
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def _int_arg(name: str, default: int) -> int:
    """Read query parameter *name* as an int, falling back to *default*.

    Missing or non-integer values (e.g. "abc", "1.5") yield *default*.
    """
    # `request` is already imported at module level; the previous
    # function-local `from flask import request` was redundant.
    raw = request.args.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        return default
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def _str_arg(name: str, default: str) -> str:
    """Read query parameter *name* as a string, with *default* when absent."""
    # `request` is already imported at module level; the previous
    # function-local `from flask import request` was redundant.
    return request.args.get(name, default)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def _skill_score(record: SkillRecord) -> float:
|
| 305 |
+
return round(record.effective_rate * 100, 1)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _serialize_skill(record: SkillRecord, *, include_recent_analyses: bool = False) -> Dict[str, Any]:
    """Convert a SkillRecord into the flat JSON payload used by the API.

    Starts from record.to_dict(), optionally drops the recent_analyses
    list, then flattens lineage fields and adds rounded rate metrics
    plus the derived 0-100 score.
    """
    payload = record.to_dict()
    if not include_recent_analyses:
        payload.pop("recent_analyses", None)

    path = payload.get("path", "")
    lineage = payload.get("lineage") or {}
    payload.update(
        {
            # Directory containing the skill file; empty when path is unset.
            "skill_dir": str(Path(path).parent) if path else "",
            "origin": lineage.get("origin", ""),
            "generation": lineage.get("generation", 0),
            "parent_skill_ids": lineage.get("parent_skill_ids", []),
            "applied_rate": round(record.applied_rate, 4),
            "completion_rate": round(record.completion_rate, 4),
            "effective_rate": round(record.effective_rate, 4),
            "fallback_rate": round(record.fallback_rate, 4),
            "score": _skill_score(record),
        }
    )
    return payload
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def _naive_dt(dt: datetime) -> datetime:
|
| 332 |
+
"""Strip tzinfo so naive/aware datetimes can be compared safely."""
|
| 333 |
+
return dt.replace(tzinfo=None) if dt.tzinfo else dt
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def _sort_skills(records: Iterable[SkillRecord], *, sort_key: str) -> List[SkillRecord]:
    """Order skill records by 'updated' (newest first), 'name', or the
    default composite score (score, then activity, then recency)."""
    if sort_key == "name":
        return sorted(records, key=lambda record: record.name.lower())
    if sort_key == "updated":
        return sorted(records, key=lambda record: _naive_dt(record.last_updated), reverse=True)

    def _rank(record: SkillRecord):
        # Composite descending rank for the default "score" ordering.
        return (
            _skill_score(record),
            record.total_selections,
            _naive_dt(record.last_updated).timestamp(),
        )

    return sorted(records, key=_rank, reverse=True)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def _build_skill_stats(store: SkillStore, skills: List[SkillRecord]) -> Dict[str, Any]:
    """Combine store-level stats with aggregates derived from *skills*."""
    summary: Dict[str, Any] = {**store.get_stats(active_only=False)}
    if skills:
        average = round(sum(_skill_score(item) for item in skills) / len(skills), 1)
    else:
        average = 0.0
    summary["average_score"] = average
    summary["skills_with_activity"] = sum(1 for item in skills if item.total_selections > 0)
    summary["skills_with_recent_analysis"] = sum(1 for item in skills if item.recent_analyses)
    summary["top_by_effective_rate"] = [
        _serialize_skill(item) for item in _sort_skills(skills, sort_key="score")[:5]
    ]
    return summary
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def _load_skill_source(record: SkillRecord) -> Dict[str, Any]:
|
| 362 |
+
skill_path = Path(record.path)
|
| 363 |
+
if not skill_path.exists() or not skill_path.is_file():
|
| 364 |
+
return {"exists": False, "path": record.path, "content": None}
|
| 365 |
+
try:
|
| 366 |
+
return {
|
| 367 |
+
"exists": True,
|
| 368 |
+
"path": str(skill_path),
|
| 369 |
+
"content": skill_path.read_text(encoding="utf-8"),
|
| 370 |
+
}
|
| 371 |
+
except OSError:
|
| 372 |
+
return {"exists": False, "path": str(skill_path), "content": None}
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def _build_lineage_payload(skill_id: str, store: SkillStore) -> Dict[str, Any]:
    """Compute the connected lineage subgraph around *skill_id*.

    Traverses both directions (parents via lineage.parent_skill_ids and
    children via an inverted index) and returns serializable node/edge
    lists for the frontend graph view.
    """
    records = store.load_all(active_only=False)
    if skill_id not in records:
        # Unknown seed: return an empty graph rather than erroring.
        return {"skill_id": skill_id, "nodes": [], "edges": [], "total_nodes": 0}

    # Invert parent links so children can be found from a parent id.
    children_by_parent: Dict[str, set[str]] = {}
    for item in records.values():
        for parent_id in item.lineage.parent_skill_ids:
            children_by_parent.setdefault(parent_id, set()).add(item.skill_id)

    # Flood fill over parent and child links starting from the seed.
    related_ids = {skill_id}
    frontier = [skill_id]
    while frontier:
        current = frontier.pop()
        record = records.get(current)
        if not record:
            # Parent ids may reference skills that no longer exist.
            continue
        for parent_id in record.lineage.parent_skill_ids:
            if parent_id not in related_ids:
                related_ids.add(parent_id)
                frontier.append(parent_id)
        for child_id in children_by_parent.get(current, set()):
            if child_id not in related_ids:
                related_ids.add(child_id)
                frontier.append(child_id)

    nodes = []
    edges = []
    for related_id in sorted(related_ids):
        record = records.get(related_id)
        if not record:
            continue
        nodes.append(
            {
                "skill_id": record.skill_id,
                "name": record.name,
                "description": record.description,
                "origin": record.lineage.origin.value,
                "generation": record.lineage.generation,
                "created_at": record.lineage.created_at.isoformat(),
                "visibility": record.visibility.value,
                "is_active": record.is_active,
                "tags": list(record.tags),
                "score": _skill_score(record),
                "effective_rate": round(record.effective_rate, 4),
                "total_selections": record.total_selections,
            }
        )
        # Only emit edges whose both endpoints are inside the subgraph.
        for parent_id in record.lineage.parent_skill_ids:
            if parent_id in related_ids:
                edges.append({"source": parent_id, "target": record.skill_id})

    return {
        "skill_id": skill_id,
        "nodes": nodes,
        "edges": edges,
        "total_nodes": len(nodes),
    }
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def _discover_workflow_dirs() -> List[Path]:
    """Collect workflow run directories from all configured roots,
    newest (by mtime) first."""
    discovered: Dict[str, Path] = {}
    for root in WORKFLOW_ROOTS:
        if root.exists():
            _scan_workflow_tree(root, discovered)
    return sorted(discovered.values(), key=lambda path: path.stat().st_mtime, reverse=True)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def _scan_workflow_tree(directory: Path, discovered: Dict[str, Path], *, _depth: int = 0, _max_depth: int = 6) -> None:
|
| 445 |
+
if _depth > _max_depth:
|
| 446 |
+
return
|
| 447 |
+
try:
|
| 448 |
+
children = list(directory.iterdir())
|
| 449 |
+
except OSError:
|
| 450 |
+
return
|
| 451 |
+
for child in children:
|
| 452 |
+
if not child.is_dir():
|
| 453 |
+
continue
|
| 454 |
+
if (child / "metadata.json").exists() or (child / "traj.jsonl").exists():
|
| 455 |
+
discovered.setdefault(child.name, child)
|
| 456 |
+
else:
|
| 457 |
+
_scan_workflow_tree(child, discovered, _depth=_depth + 1, _max_depth=_max_depth)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def _get_workflow_dir(workflow_id: str) -> Optional[Path]:
    """Resolve a workflow id (its directory name) to a path, or None."""
    matches = (path for path in _discover_workflow_dirs() if path.name == workflow_id)
    return next(matches, None)
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def _build_workflow_summary(workflow_dir: Path) -> Dict[str, Any]:
    """Build the list-level summary payload for one workflow directory.

    Loads the recording session and agent actions, then fills in fields
    missing from metadata (end time, execution time, status, iterations)
    with heuristics derived from the trajectory and statistics.
    """
    session = load_recording_session(str(workflow_dir))
    metadata = session.get("metadata") or {}
    statistics = session.get("statistics") or {}
    actions = load_agent_actions(str(workflow_dir))
    screenshots_dir = workflow_dir / "screenshots"
    screenshot_count = len(list(screenshots_dir.glob("*.png"))) if screenshots_dir.exists() else 0

    # First existing video file wins; URL is generated for the artifact route.
    video_candidates = [workflow_dir / "screen_recording.mp4", workflow_dir / "recording.mp4"]
    video_url = None
    for candidate in video_candidates:
        if candidate.exists():
            rel = candidate.relative_to(workflow_dir).as_posix()
            video_url = url_for("workflow_artifact", workflow_id=workflow_dir.name, artifact_path=rel)
            break

    outcome = metadata.get("execution_outcome") or {}
    # Instruction fallback chain: top-level → retrieved_tools.instruction → skill_selection.task
    instruction = (
        metadata.get("instruction")
        or (metadata.get("retrieved_tools") or {}).get("instruction")
        or (metadata.get("skill_selection") or {}).get("task")
        or ""
    )

    # Resolve start/end times with trajectory fallback
    start_time = metadata.get("start_time")
    end_time = metadata.get("end_time")
    trajectory = session.get("trajectory") or []

    # If end_time is missing, infer from last trajectory step
    if not end_time and trajectory:
        last_ts = trajectory[-1].get("timestamp")
        if last_ts:
            end_time = last_ts

    # Compute execution_time: prefer outcome, fallback to timestamp diff
    execution_time = outcome.get("execution_time", 0)
    if not execution_time and start_time and end_time:
        try:
            t0 = datetime.fromisoformat(start_time)
            t1 = datetime.fromisoformat(end_time)
            execution_time = round((t1 - t0).total_seconds(), 2)
        except (ValueError, TypeError):
            # Unparseable or mixed naive/aware timestamps: leave as-is.
            pass

    # Resolve status: prefer outcome, fallback heuristic
    status = outcome.get("status", "")
    if not status:
        total_steps = statistics.get("total_steps", 0)
        if total_steps > 0:
            status = "success"
        elif trajectory:
            status = "completed"
        else:
            status = "unknown"

    # Resolve iterations: prefer outcome, fallback to conversation count
    iterations = outcome.get("iterations", 0)
    if not iterations and trajectory:
        iterations = len(trajectory)

    return {
        "id": workflow_dir.name,
        "path": str(workflow_dir),
        "task_id": metadata.get("task_id") or metadata.get("task_name") or workflow_dir.name,
        "task_name": metadata.get("task_name") or metadata.get("task_id") or workflow_dir.name,
        "instruction": instruction,
        "status": status,
        "iterations": iterations,
        "execution_time": execution_time,
        "start_time": start_time,
        "end_time": end_time,
        "total_steps": statistics.get("total_steps", 0),
        "success_count": statistics.get("success_count", 0),
        "success_rate": statistics.get("success_rate", 0.0),
        "backend_counts": statistics.get("backends", {}),
        "tool_counts": statistics.get("tools", {}),
        "agent_action_count": len(actions),
        "has_video": bool(video_url),
        "video_url": video_url,
        "screenshot_count": screenshot_count,
        "selected_skills": (metadata.get("skill_selection") or {}).get("selected", []),
    }
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def _build_timeline(actions: List[Dict[str, Any]], trajectory: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 554 |
+
events: List[Dict[str, Any]] = []
|
| 555 |
+
for action in actions:
|
| 556 |
+
events.append(
|
| 557 |
+
{
|
| 558 |
+
"timestamp": action.get("timestamp", ""),
|
| 559 |
+
"type": "agent_action",
|
| 560 |
+
"step": action.get("step"),
|
| 561 |
+
"label": action.get("action_type", "agent_action"),
|
| 562 |
+
"agent_name": action.get("agent_name", ""),
|
| 563 |
+
"agent_type": action.get("agent_type", ""),
|
| 564 |
+
"details": action,
|
| 565 |
+
}
|
| 566 |
+
)
|
| 567 |
+
for step in trajectory:
|
| 568 |
+
events.append(
|
| 569 |
+
{
|
| 570 |
+
"timestamp": step.get("timestamp", ""),
|
| 571 |
+
"type": "tool_execution",
|
| 572 |
+
"step": step.get("step"),
|
| 573 |
+
"label": step.get("tool", "tool_execution"),
|
| 574 |
+
"backend": step.get("backend", ""),
|
| 575 |
+
"status": (step.get("result") or {}).get("status", "unknown"),
|
| 576 |
+
"details": step,
|
| 577 |
+
}
|
| 578 |
+
)
|
| 579 |
+
events.sort(key=lambda item: (item.get("timestamp", ""), item.get("step") or 0))
|
| 580 |
+
return events
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def _build_workflow_artifacts(workflow_dir: Path, workflow_id: str, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Assemble artifact URLs (screenshots, init screenshot, video) for a run."""
    screenshots_dir = workflow_dir / "screenshots"
    screenshots: List[Dict[str, Any]] = []
    if screenshots_dir.exists():
        for image in sorted(screenshots_dir.glob("*.png")):
            rel = image.relative_to(workflow_dir).as_posix()
            screenshots.append(
                {
                    "name": image.name,
                    "path": rel,
                    "url": url_for("workflow_artifact", workflow_id=workflow_id, artifact_path=rel),
                }
            )

    # Optional initial screenshot referenced from metadata (relative path).
    init_screenshot = metadata.get("init_screenshot")
    if isinstance(init_screenshot, str):
        init_screenshot_url = url_for(
            "workflow_artifact", workflow_id=workflow_id, artifact_path=init_screenshot
        )
    else:
        init_screenshot_url = None

    # First existing video file wins.
    video_url = None
    for rel in ("screen_recording.mp4", "recording.mp4"):
        if (workflow_dir / rel).exists():
            video_url = url_for("workflow_artifact", workflow_id=workflow_id, artifact_path=rel)
            break

    return {
        "init_screenshot_url": init_screenshot_url,
        "screenshots": screenshots,
        "video_url": video_url,
    }
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
def main() -> None:
    """CLI entry point: parse host/port/debug flags and serve the dashboard app."""
    parser = argparse.ArgumentParser(description="OpenSpace dashboard API server")
    parser.add_argument("--host", default="127.0.0.1", help="Dashboard API host")
    parser.add_argument("--port", type=int, default=7788, help="Dashboard API port")
    parser.add_argument("--debug", action="store_true", help="Enable Flask debug mode")
    options = parser.parse_args()

    flask_app = create_app()

    # werkzeug's run_simple gives us a threaded server plus an optional
    # debugger/reloader, both gated on the --debug flag.
    from werkzeug.serving import run_simple
    run_simple(
        options.host,
        options.port,
        flask_app,
        threaded=True,
        use_debugger=options.debug,
        use_reloader=options.debug,
    )
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
openspace/grounding/backends/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use lazy imports to avoid loading all backends unconditionally
|
| 2 |
+
|
| 3 |
+
def _lazy_import_provider(provider_name: str):
|
| 4 |
+
"""Lazy import provider class"""
|
| 5 |
+
if provider_name == 'mcp':
|
| 6 |
+
from .mcp.provider import MCPProvider
|
| 7 |
+
return MCPProvider
|
| 8 |
+
elif provider_name == 'shell':
|
| 9 |
+
from .shell.provider import ShellProvider
|
| 10 |
+
return ShellProvider
|
| 11 |
+
elif provider_name == 'web':
|
| 12 |
+
from .web.provider import WebProvider
|
| 13 |
+
return WebProvider
|
| 14 |
+
elif provider_name == 'gui':
|
| 15 |
+
from .gui.provider import GUIProvider
|
| 16 |
+
return GUIProvider
|
| 17 |
+
else:
|
| 18 |
+
raise ImportError(f"Unknown provider: {provider_name}")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class _ProviderRegistry:
|
| 22 |
+
"""Lazy provider registry"""
|
| 23 |
+
def __getitem__(self, key):
|
| 24 |
+
return _lazy_import_provider(key)
|
| 25 |
+
|
| 26 |
+
def __contains__(self, key):
|
| 27 |
+
return key in ['mcp', 'shell', 'web', 'gui']
|
| 28 |
+
|
| 29 |
+
# Singleton registry used by callers to look up backend provider classes by name.
BACKEND_PROVIDERS = _ProviderRegistry()

# Public API of this package.
__all__ = [
    'BACKEND_PROVIDERS',
    '_lazy_import_provider'
]
|
openspace/grounding/backends/gui/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .provider import GUIProvider
|
| 2 |
+
from .session import GUISession
|
| 3 |
+
from .transport.connector import GUIConnector
|
| 4 |
+
from .transport.local_connector import LocalGUIConnector
|
| 5 |
+
|
| 6 |
+
# The Anthropic-backed client is optional: it requires the `anthropic` SDK,
# so its import failure must not break the rest of the GUI backend package.
try:
    from .anthropic_client import AnthropicGUIClient
    from . import anthropic_utils
    _anthropic_available = True
except ImportError:
    _anthropic_available = False

__all__ = [
    # Core Provider and Session
    "GUIProvider",
    "GUISession",

    # Transport layer
    "GUIConnector",
    "LocalGUIConnector",
]

# Add Anthropic modules to exports if available
if _anthropic_available:
    __all__.extend(["AnthropicGUIClient", "anthropic_utils"])
|
openspace/grounding/backends/gui/anthropic_client.py
ADDED
|
@@ -0,0 +1,575 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
from typing import Any, Dict, Optional, Tuple, List
|
| 5 |
+
from openspace.utils.logging import Logger
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import io
|
| 8 |
+
|
| 9 |
+
logger = Logger.get_logger(__name__)
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from anthropic import (
|
| 13 |
+
Anthropic,
|
| 14 |
+
AnthropicBedrock,
|
| 15 |
+
AnthropicVertex,
|
| 16 |
+
APIError,
|
| 17 |
+
APIResponseValidationError,
|
| 18 |
+
APIStatusError,
|
| 19 |
+
)
|
| 20 |
+
from anthropic.types.beta import (
|
| 21 |
+
BetaMessageParam,
|
| 22 |
+
BetaTextBlockParam,
|
| 23 |
+
)
|
| 24 |
+
ANTHROPIC_AVAILABLE = True
|
| 25 |
+
except ImportError:
|
| 26 |
+
logger.warning("Anthropic SDK not available. Install with: pip install anthropic")
|
| 27 |
+
ANTHROPIC_AVAILABLE = False
|
| 28 |
+
|
| 29 |
+
# Import utility functions
|
| 30 |
+
from .anthropic_utils import (
|
| 31 |
+
APIProvider,
|
| 32 |
+
PROVIDER_TO_DEFAULT_MODEL_NAME,
|
| 33 |
+
COMPUTER_USE_BETA_FLAG,
|
| 34 |
+
PROMPT_CACHING_BETA_FLAG,
|
| 35 |
+
get_system_prompt,
|
| 36 |
+
inject_prompt_caching,
|
| 37 |
+
maybe_filter_to_n_most_recent_images,
|
| 38 |
+
response_to_params,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# API retry configuration used by the messages.create retry loop.
API_RETRY_TIMES = 10  # maximum attempts before giving up
API_RETRY_INTERVAL = 5  # seconds
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class AnthropicGUIClient:
|
| 47 |
+
"""
|
| 48 |
+
Anthropic LLM Client for GUI operations.
|
| 49 |
+
Uses Claude Sonnet 4.5 with computer-use-2025-01-24 API.
|
| 50 |
+
|
| 51 |
+
Features:
|
| 52 |
+
- Vision-based screen understanding
|
| 53 |
+
- Automatic screenshot resizing (configurable display size)
|
| 54 |
+
- Coordinate scaling between display and actual screen
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
    def __init__(
        self,
        model: str = "claude-sonnet-4-5",
        platform: str = "Ubuntu",
        api_key: Optional[str] = None,
        provider: str = "anthropic",
        max_tokens: int = 4096,
        screen_size: Tuple[int, int] = (1920, 1080),
        display_size: Tuple[int, int] = (1024, 768),  # Computer use display size
        pyautogui_size: Optional[Tuple[int, int]] = None,  # PyAutoGUI working size
        only_n_most_recent_images: int = 3,
        enable_prompt_caching: bool = True,
        backup_api_key: Optional[str] = None,
    ):
        """
        Initialize Anthropic GUI Client for Claude Sonnet 4.5.

        Args:
            model: Model name (only "claude-sonnet-4-5" supported; any other
                value is coerced to it with a warning)
            platform: Platform type (Ubuntu, Windows, or macOS)
            api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
            provider: API provider (only "anthropic" supported; coerced with a warning)
            max_tokens: Maximum tokens for response
            screen_size: Actual screenshot resolution (width, height) - physical pixels
            display_size: Display size for computer use tool (width, height)
                Screenshots will be resized to this size before sending to API
            pyautogui_size: PyAutoGUI working size (logical pixels). If None, assumed same as screen_size.
                On Retina/HiDPI displays, this may be screen_size / 2
            only_n_most_recent_images: Number of recent screenshots to keep in history
            enable_prompt_caching: Whether to enable prompt caching for cost optimization
            backup_api_key: Backup API key (defaults to ANTHROPIC_API_KEY_BACKUP env var)

        Raises:
            RuntimeError: If the `anthropic` SDK is not installed.
            ValueError: If no API key was passed and ANTHROPIC_API_KEY is unset.
        """
        if not ANTHROPIC_AVAILABLE:
            raise RuntimeError("Anthropic SDK not installed. Install with: pip install anthropic")

        # Only support claude-sonnet-4-5
        if model != "claude-sonnet-4-5":
            logger.warning(f"Model '{model}' not supported. Using 'claude-sonnet-4-5'")
            model = "claude-sonnet-4-5"

        self.model = model
        self.platform = platform
        self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        if not self.api_key:
            raise ValueError("Anthropic API key not provided. Set ANTHROPIC_API_KEY env var or pass api_key parameter")

        # Backup API key for failover when the primary key's request fails
        self.backup_api_key = backup_api_key or os.environ.get("ANTHROPIC_API_KEY_BACKUP")

        # Only support anthropic provider
        if provider != "anthropic":
            logger.warning(f"Provider '{provider}' not supported. Using 'anthropic'")
            provider = "anthropic"

        self.provider = APIProvider(provider)
        self.max_tokens = max_tokens
        self.screen_size = screen_size
        self.display_size = display_size
        self.pyautogui_size = pyautogui_size or screen_size  # Default to screen_size if not specified
        self.only_n_most_recent_images = only_n_most_recent_images
        self.enable_prompt_caching = enable_prompt_caching

        # Conversation history sent to the API on every plan_action call
        self.messages: List[BetaMessageParam] = []

        # Calculate resize factor for coordinate scaling
        # Step 1: LLM coordinates (display_size) -> Physical pixels (screen_size)
        # Step 2: Physical pixels -> PyAutoGUI logical pixels (pyautogui_size)
        # NOTE(review): screen_size cancels out of the two steps — the factor
        # below maps display coordinates straight to pyautogui coordinates.
        self.resize_factor = (
            self.pyautogui_size[0] / display_size[0],  # x scale factor
            self.pyautogui_size[1] / display_size[1]  # y scale factor
        )

        logger.info(
            f"Initialized AnthropicGUIClient:\n"
            f"  Model: {model}\n"
            f"  Platform: {platform}\n"
            f"  Screen Size (physical): {screen_size}\n"
            f"  PyAutoGUI Size (logical): {self.pyautogui_size}\n"
            f"  Display Size (LLM): {display_size}\n"
            f"  Resize Factor (LLM->PyAutoGUI): {self.resize_factor}\n"
            f"  Prompt Caching: {enable_prompt_caching}"
        )
|
| 140 |
+
|
| 141 |
+
def _create_client(self, api_key: Optional[str] = None):
|
| 142 |
+
"""Create Anthropic client (only supports anthropic provider)."""
|
| 143 |
+
key = api_key or self.api_key
|
| 144 |
+
return Anthropic(api_key=key, max_retries=4)
|
| 145 |
+
|
| 146 |
+
def _resize_screenshot(self, screenshot_bytes: bytes) -> bytes:
|
| 147 |
+
"""
|
| 148 |
+
Resize screenshot to display size for Computer Use API.
|
| 149 |
+
|
| 150 |
+
For computer-use-2025-01-24, the screenshot must be resized to the
|
| 151 |
+
display_width_px x display_height_px specified in the tool definition.
|
| 152 |
+
"""
|
| 153 |
+
screenshot_image = Image.open(io.BytesIO(screenshot_bytes))
|
| 154 |
+
resized_image = screenshot_image.resize(self.display_size, Image.Resampling.LANCZOS)
|
| 155 |
+
|
| 156 |
+
output_buffer = io.BytesIO()
|
| 157 |
+
resized_image.save(output_buffer, format='PNG')
|
| 158 |
+
return output_buffer.getvalue()
|
| 159 |
+
|
| 160 |
+
def _scale_coordinates(self, x: int, y: int) -> Tuple[int, int]:
|
| 161 |
+
"""
|
| 162 |
+
Scale coordinates from display size to actual screen size.
|
| 163 |
+
|
| 164 |
+
The API returns coordinates in display_size (e.g., 1024x768).
|
| 165 |
+
We need to scale them to actual screen_size (e.g., 1920x1080) for execution.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
x, y: Coordinates in display size space
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
Scaled coordinates in actual screen size space
|
| 172 |
+
"""
|
| 173 |
+
scaled_x = int(x * self.resize_factor[0])
|
| 174 |
+
scaled_y = int(y * self.resize_factor[1])
|
| 175 |
+
return scaled_x, scaled_y
|
| 176 |
+
|
| 177 |
+
async def plan_action(
|
| 178 |
+
self,
|
| 179 |
+
task_description: str,
|
| 180 |
+
screenshot: bytes,
|
| 181 |
+
action_history: List[Dict[str, Any]] = None,
|
| 182 |
+
) -> Tuple[Optional[str], List[str]]:
|
| 183 |
+
"""
|
| 184 |
+
Plan next action based on task and current screenshot.
|
| 185 |
+
Includes prompt caching, error handling, and backup API key support.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
task_description: Task to accomplish
|
| 189 |
+
screenshot: Current screenshot (PNG bytes)
|
| 190 |
+
action_history: Previous actions (for context)
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
Tuple of (reasoning, list of pyautogui commands)
|
| 194 |
+
"""
|
| 195 |
+
# Resize screenshot
|
| 196 |
+
resized_screenshot = self._resize_screenshot(screenshot)
|
| 197 |
+
screenshot_b64 = base64.b64encode(resized_screenshot).decode('utf-8')
|
| 198 |
+
|
| 199 |
+
# Initialize messages with first task + screenshot
|
| 200 |
+
if not self.messages:
|
| 201 |
+
# IMPORTANT: Image should come BEFORE text for better model understanding
|
| 202 |
+
# This matches OSWorld's implementation which has proven effectiveness
|
| 203 |
+
self.messages.append({
|
| 204 |
+
"role": "user",
|
| 205 |
+
"content": [
|
| 206 |
+
{
|
| 207 |
+
"type": "image",
|
| 208 |
+
"source": {
|
| 209 |
+
"type": "base64",
|
| 210 |
+
"media_type": "image/png",
|
| 211 |
+
"data": screenshot_b64,
|
| 212 |
+
},
|
| 213 |
+
},
|
| 214 |
+
{"type": "text", "text": task_description},
|
| 215 |
+
]
|
| 216 |
+
})
|
| 217 |
+
|
| 218 |
+
# Filter images BEFORE adding new screenshot to control message size
|
| 219 |
+
# This is critical to avoid exceeding the 25MB API limit
|
| 220 |
+
image_truncation_threshold = 10
|
| 221 |
+
if self.only_n_most_recent_images and len(self.messages) > 1:
|
| 222 |
+
# Reserve 1 slot for the screenshot we're about to add
|
| 223 |
+
maybe_filter_to_n_most_recent_images(
|
| 224 |
+
self.messages,
|
| 225 |
+
max(1, self.only_n_most_recent_images - 1),
|
| 226 |
+
min_removal_threshold=1, # More aggressive filtering
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
# Add tool result from previous action if exists
|
| 230 |
+
if self.messages and self.messages[-1]["role"] == "assistant":
|
| 231 |
+
last_content = self.messages[-1]["content"]
|
| 232 |
+
if isinstance(last_content, list) and any(
|
| 233 |
+
block.get("type") == "tool_use" for block in last_content
|
| 234 |
+
):
|
| 235 |
+
tool_use_id = next(
|
| 236 |
+
block["id"] for block in last_content
|
| 237 |
+
if block.get("type") == "tool_use"
|
| 238 |
+
)
|
| 239 |
+
self._add_tool_result(tool_use_id, "Success", resized_screenshot)
|
| 240 |
+
|
| 241 |
+
# Define tools and betas for claude-sonnet-4-5 with computer-use-2025-01-24
|
| 242 |
+
tools = [{
|
| 243 |
+
'name': 'computer',
|
| 244 |
+
'type': 'computer_20250124',
|
| 245 |
+
'display_width_px': self.display_size[0],
|
| 246 |
+
'display_height_px': self.display_size[1],
|
| 247 |
+
'display_number': 1
|
| 248 |
+
}]
|
| 249 |
+
betas = [COMPUTER_USE_BETA_FLAG]
|
| 250 |
+
|
| 251 |
+
# Prepare system prompt with optional caching
|
| 252 |
+
system = BetaTextBlockParam(
|
| 253 |
+
type="text",
|
| 254 |
+
text=get_system_prompt(self.platform)
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Enable prompt caching if supported and enabled
|
| 258 |
+
if self.enable_prompt_caching:
|
| 259 |
+
betas.append(PROMPT_CACHING_BETA_FLAG)
|
| 260 |
+
inject_prompt_caching(self.messages)
|
| 261 |
+
system["cache_control"] = {"type": "ephemeral"} # type: ignore
|
| 262 |
+
|
| 263 |
+
# Model name - use claude-sonnet-4-5 directly
|
| 264 |
+
model_name = "claude-sonnet-4-5"
|
| 265 |
+
|
| 266 |
+
# Enable thinking for complex computer use tasks
|
| 267 |
+
extra_body = {"thinking": {"type": "enabled", "budget_tokens": 2048}}
|
| 268 |
+
|
| 269 |
+
# Log request details for debugging
|
| 270 |
+
# Count current images in messages
|
| 271 |
+
total_images = sum(
|
| 272 |
+
1
|
| 273 |
+
for message in self.messages
|
| 274 |
+
for item in (message.get("content", []) if isinstance(message.get("content"), list) else [])
|
| 275 |
+
if isinstance(item, dict) and item.get("type") == "image"
|
| 276 |
+
)
|
| 277 |
+
tool_result_images = sum(
|
| 278 |
+
1
|
| 279 |
+
for message in self.messages
|
| 280 |
+
for item in (message.get("content", []) if isinstance(message.get("content"), list) else [])
|
| 281 |
+
if isinstance(item, dict) and item.get("type") == "tool_result"
|
| 282 |
+
for content in item.get("content", [])
|
| 283 |
+
if isinstance(content, dict) and content.get("type") == "image"
|
| 284 |
+
)
|
| 285 |
+
logger.info(
|
| 286 |
+
f"Anthropic API request:\n"
|
| 287 |
+
f" Model: {model_name}\n"
|
| 288 |
+
f" Display Size: {self.display_size}\n"
|
| 289 |
+
f" Betas: {betas}\n"
|
| 290 |
+
f" Images: {total_images} ({tool_result_images} in tool_results)\n"
|
| 291 |
+
f" Messages: {len(self.messages)}"
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
# Try API call with retry and backup
|
| 295 |
+
client = self._create_client()
|
| 296 |
+
response = None
|
| 297 |
+
|
| 298 |
+
try:
|
| 299 |
+
# Retry loop with automatic image count reduction on 25MB error
|
| 300 |
+
for attempt in range(API_RETRY_TIMES):
|
| 301 |
+
try:
|
| 302 |
+
response = client.beta.messages.create(
|
| 303 |
+
max_tokens=self.max_tokens,
|
| 304 |
+
messages=self.messages,
|
| 305 |
+
model=model_name,
|
| 306 |
+
system=[system],
|
| 307 |
+
tools=tools,
|
| 308 |
+
betas=betas,
|
| 309 |
+
extra_body=extra_body
|
| 310 |
+
)
|
| 311 |
+
logger.info(f"API call succeeded on attempt {attempt + 1}")
|
| 312 |
+
break
|
| 313 |
+
|
| 314 |
+
except (APIError, APIStatusError, APIResponseValidationError) as e:
|
| 315 |
+
error_msg = str(e)
|
| 316 |
+
logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
|
| 317 |
+
|
| 318 |
+
# Handle 25MB payload limit error (including HTTP 413)
|
| 319 |
+
if ("25000000" in error_msg or
|
| 320 |
+
"Member must have length less than or equal to" in error_msg or
|
| 321 |
+
"request_too_large" in error_msg or
|
| 322 |
+
"413" in str(e)):
|
| 323 |
+
logger.warning("Detected 25MB limit error, reducing image count")
|
| 324 |
+
current_count = self.only_n_most_recent_images
|
| 325 |
+
new_count = max(1, current_count // 2)
|
| 326 |
+
self.only_n_most_recent_images = new_count
|
| 327 |
+
|
| 328 |
+
maybe_filter_to_n_most_recent_images(
|
| 329 |
+
self.messages,
|
| 330 |
+
new_count,
|
| 331 |
+
min_removal_threshold=1, # Aggressive filtering when hitting limit
|
| 332 |
+
)
|
| 333 |
+
logger.info(f"Image count reduced from {current_count} to {new_count}")
|
| 334 |
+
|
| 335 |
+
if attempt < API_RETRY_TIMES - 1:
|
| 336 |
+
time.sleep(API_RETRY_INTERVAL)
|
| 337 |
+
else:
|
| 338 |
+
raise
|
| 339 |
+
|
| 340 |
+
except (APIError, APIStatusError, APIResponseValidationError) as e:
|
| 341 |
+
logger.error(f"Primary API key failed: {e}")
|
| 342 |
+
|
| 343 |
+
# Try backup API key if available
|
| 344 |
+
if self.backup_api_key:
|
| 345 |
+
logger.warning("Retrying with backup API key...")
|
| 346 |
+
try:
|
| 347 |
+
backup_client = self._create_client(self.backup_api_key)
|
| 348 |
+
response = backup_client.beta.messages.create(
|
| 349 |
+
max_tokens=self.max_tokens,
|
| 350 |
+
messages=self.messages,
|
| 351 |
+
model=model_name,
|
| 352 |
+
system=[system],
|
| 353 |
+
tools=tools,
|
| 354 |
+
betas=betas,
|
| 355 |
+
extra_body=extra_body
|
| 356 |
+
)
|
| 357 |
+
logger.info("Successfully used backup API key")
|
| 358 |
+
except Exception as backup_e:
|
| 359 |
+
logger.error(f"Backup API key also failed: {backup_e}")
|
| 360 |
+
return None, ["FAIL"]
|
| 361 |
+
else:
|
| 362 |
+
return None, ["FAIL"]
|
| 363 |
+
|
| 364 |
+
except Exception as e:
|
| 365 |
+
logger.error(f"Unexpected error: {e}")
|
| 366 |
+
return None, ["FAIL"]
|
| 367 |
+
|
| 368 |
+
if not response:
|
| 369 |
+
return None, ["FAIL"]
|
| 370 |
+
|
| 371 |
+
# Parse response using utility function
|
| 372 |
+
response_params = response_to_params(response)
|
| 373 |
+
|
| 374 |
+
# Extract reasoning and commands
|
| 375 |
+
reasoning = ""
|
| 376 |
+
commands = []
|
| 377 |
+
|
| 378 |
+
for block in response_params:
|
| 379 |
+
block_type = block.get("type")
|
| 380 |
+
|
| 381 |
+
if block_type == "text":
|
| 382 |
+
reasoning = block.get("text", "")
|
| 383 |
+
elif block_type == "thinking":
|
| 384 |
+
reasoning = block.get("thinking", "")
|
| 385 |
+
elif block_type == "tool_use":
|
| 386 |
+
tool_input = block.get("input", {})
|
| 387 |
+
command = self._parse_computer_tool_use(tool_input)
|
| 388 |
+
if command:
|
| 389 |
+
commands.append(command)
|
| 390 |
+
else:
|
| 391 |
+
logger.warning(f"Failed to parse tool_use: {tool_input}")
|
| 392 |
+
|
| 393 |
+
# Store assistant response
|
| 394 |
+
self.messages.append({
|
| 395 |
+
"role": "assistant",
|
| 396 |
+
"content": response_params
|
| 397 |
+
})
|
| 398 |
+
|
| 399 |
+
logger.info(f"Parsed {len(commands)} commands from response")
|
| 400 |
+
|
| 401 |
+
return reasoning, commands
|
| 402 |
+
|
| 403 |
+
def _add_tool_result(
|
| 404 |
+
self,
|
| 405 |
+
tool_use_id: str,
|
| 406 |
+
result: str,
|
| 407 |
+
screenshot_bytes: Optional[bytes] = None
|
| 408 |
+
):
|
| 409 |
+
"""
|
| 410 |
+
Add tool result to message history.
|
| 411 |
+
IMPORTANT: Put screenshot BEFORE text for consistency with initial message.
|
| 412 |
+
"""
|
| 413 |
+
# Build content list with image first (if provided), then text
|
| 414 |
+
content_list = []
|
| 415 |
+
|
| 416 |
+
# Add screenshot first if provided (consistent with initial message ordering)
|
| 417 |
+
if screenshot_bytes is not None:
|
| 418 |
+
screenshot_b64 = base64.b64encode(screenshot_bytes).decode('utf-8')
|
| 419 |
+
content_list.append({
|
| 420 |
+
"type": "image",
|
| 421 |
+
"source": {
|
| 422 |
+
"type": "base64",
|
| 423 |
+
"media_type": "image/png",
|
| 424 |
+
"data": screenshot_b64
|
| 425 |
+
}
|
| 426 |
+
})
|
| 427 |
+
|
| 428 |
+
# Then add text result
|
| 429 |
+
content_list.append({"type": "text", "text": result})
|
| 430 |
+
|
| 431 |
+
tool_result_content = [{
|
| 432 |
+
"type": "tool_result",
|
| 433 |
+
"tool_use_id": tool_use_id,
|
| 434 |
+
"content": content_list
|
| 435 |
+
}]
|
| 436 |
+
|
| 437 |
+
self.messages.append({
|
| 438 |
+
"role": "user",
|
| 439 |
+
"content": tool_result_content
|
| 440 |
+
})
|
| 441 |
+
|
| 442 |
+
def _parse_computer_tool_use(self, tool_input: Dict[str, Any]) -> Optional[str]:
|
| 443 |
+
"""
|
| 444 |
+
Parse Anthropic computer tool use to pyautogui command.
|
| 445 |
+
|
| 446 |
+
Args:
|
| 447 |
+
tool_input: Tool input from Anthropic (action, coordinate, text, etc.)
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
PyAutoGUI command string or control command (DONE, FAIL)
|
| 451 |
+
"""
|
| 452 |
+
action = tool_input.get("action")
|
| 453 |
+
if not action:
|
| 454 |
+
return None
|
| 455 |
+
|
| 456 |
+
# Action conversion
|
| 457 |
+
action_conversion = {
|
| 458 |
+
"left click": "click",
|
| 459 |
+
"right click": "right_click"
|
| 460 |
+
}
|
| 461 |
+
action = action_conversion.get(action, action)
|
| 462 |
+
|
| 463 |
+
text = tool_input.get("text")
|
| 464 |
+
coordinate = tool_input.get("coordinate")
|
| 465 |
+
scroll_direction = tool_input.get("scroll_direction")
|
| 466 |
+
scroll_amount = tool_input.get("scroll_amount", 5)
|
| 467 |
+
|
| 468 |
+
# Scale coordinates to actual screen size
|
| 469 |
+
if coordinate:
|
| 470 |
+
coordinate = self._scale_coordinates(coordinate[0], coordinate[1])
|
| 471 |
+
|
| 472 |
+
# Build commands
|
| 473 |
+
command = ""
|
| 474 |
+
|
| 475 |
+
if action == "mouse_move":
|
| 476 |
+
if coordinate:
|
| 477 |
+
x, y = coordinate
|
| 478 |
+
command = f"pyautogui.moveTo({x}, {y}, duration=0.5)"
|
| 479 |
+
|
| 480 |
+
elif action in ("left_click", "click"):
|
| 481 |
+
if coordinate:
|
| 482 |
+
x, y = coordinate
|
| 483 |
+
command = f"pyautogui.click({x}, {y})"
|
| 484 |
+
else:
|
| 485 |
+
command = "pyautogui.click()"
|
| 486 |
+
|
| 487 |
+
elif action == "right_click":
|
| 488 |
+
if coordinate:
|
| 489 |
+
x, y = coordinate
|
| 490 |
+
command = f"pyautogui.rightClick({x}, {y})"
|
| 491 |
+
else:
|
| 492 |
+
command = "pyautogui.rightClick()"
|
| 493 |
+
|
| 494 |
+
elif action == "double_click":
|
| 495 |
+
if coordinate:
|
| 496 |
+
x, y = coordinate
|
| 497 |
+
command = f"pyautogui.doubleClick({x}, {y})"
|
| 498 |
+
else:
|
| 499 |
+
command = "pyautogui.doubleClick()"
|
| 500 |
+
|
| 501 |
+
elif action == "middle_click":
|
| 502 |
+
if coordinate:
|
| 503 |
+
x, y = coordinate
|
| 504 |
+
command = f"pyautogui.middleClick({x}, {y})"
|
| 505 |
+
else:
|
| 506 |
+
command = "pyautogui.middleClick()"
|
| 507 |
+
|
| 508 |
+
elif action == "left_click_drag":
|
| 509 |
+
if coordinate:
|
| 510 |
+
x, y = coordinate
|
| 511 |
+
command = f"pyautogui.dragTo({x}, {y}, duration=0.5)"
|
| 512 |
+
|
| 513 |
+
elif action == "key":
|
| 514 |
+
if text:
|
| 515 |
+
keys = text.split('+')
|
| 516 |
+
# Key conversion
|
| 517 |
+
key_conversion = {
|
| 518 |
+
"page_down": "pagedown",
|
| 519 |
+
"page_up": "pageup",
|
| 520 |
+
"super_l": "win",
|
| 521 |
+
"super": "command",
|
| 522 |
+
"escape": "esc"
|
| 523 |
+
}
|
| 524 |
+
converted_keys = [key_conversion.get(k.strip().lower(), k.strip().lower()) for k in keys]
|
| 525 |
+
|
| 526 |
+
# Press and release keys
|
| 527 |
+
for key in converted_keys:
|
| 528 |
+
command += f"pyautogui.keyDown('{key}'); "
|
| 529 |
+
for key in reversed(converted_keys):
|
| 530 |
+
command += f"pyautogui.keyUp('{key}'); "
|
| 531 |
+
# Remove trailing semicolon and space
|
| 532 |
+
command = command.rstrip('; ')
|
| 533 |
+
|
| 534 |
+
elif action == "type":
|
| 535 |
+
if text:
|
| 536 |
+
command = f"pyautogui.typewrite({repr(text)}, interval=0.01)"
|
| 537 |
+
|
| 538 |
+
elif action == "scroll":
|
| 539 |
+
if scroll_direction in ("up", "down"):
|
| 540 |
+
scroll_value = scroll_amount if scroll_direction == "up" else -scroll_amount
|
| 541 |
+
if coordinate:
|
| 542 |
+
x, y = coordinate
|
| 543 |
+
command = f"pyautogui.scroll({scroll_value}, {x}, {y})"
|
| 544 |
+
else:
|
| 545 |
+
command = f"pyautogui.scroll({scroll_value})"
|
| 546 |
+
elif scroll_direction in ("left", "right"):
|
| 547 |
+
scroll_value = scroll_amount if scroll_direction == "right" else -scroll_amount
|
| 548 |
+
if coordinate:
|
| 549 |
+
x, y = coordinate
|
| 550 |
+
command = f"pyautogui.hscroll({scroll_value}, {x}, {y})"
|
| 551 |
+
else:
|
| 552 |
+
command = f"pyautogui.hscroll({scroll_value})"
|
| 553 |
+
|
| 554 |
+
elif action == "screenshot":
|
| 555 |
+
# Screenshot is automatically handled by the system
|
| 556 |
+
# Return special marker to indicate no action needed
|
| 557 |
+
return "SCREENSHOT"
|
| 558 |
+
|
| 559 |
+
elif action == "wait":
|
| 560 |
+
# Wait for specified duration
|
| 561 |
+
duration = tool_input.get("duration", 1)
|
| 562 |
+
command = f"pyautogui.sleep({duration})"
|
| 563 |
+
|
| 564 |
+
elif action == "done":
|
| 565 |
+
return "DONE"
|
| 566 |
+
|
| 567 |
+
elif action == "fail":
|
| 568 |
+
return "FAIL"
|
| 569 |
+
|
| 570 |
+
return command if command else None
|
| 571 |
+
|
| 572 |
+
def reset(self):
|
| 573 |
+
"""Reset message history."""
|
| 574 |
+
self.messages = []
|
| 575 |
+
logger.info("Reset AnthropicGUIClient message history")
|
openspace/grounding/backends/gui/anthropic_utils.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, cast
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from openspace.utils.logging import Logger
|
| 5 |
+
|
| 6 |
+
logger = Logger.get_logger(__name__)
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from anthropic.types.beta import (
|
| 10 |
+
BetaCacheControlEphemeralParam,
|
| 11 |
+
BetaContentBlockParam,
|
| 12 |
+
BetaImageBlockParam,
|
| 13 |
+
BetaMessage,
|
| 14 |
+
BetaMessageParam,
|
| 15 |
+
BetaTextBlock,
|
| 16 |
+
BetaTextBlockParam,
|
| 17 |
+
BetaToolResultBlockParam,
|
| 18 |
+
BetaToolUseBlockParam,
|
| 19 |
+
)
|
| 20 |
+
ANTHROPIC_AVAILABLE = True
|
| 21 |
+
except ImportError:
|
| 22 |
+
ANTHROPIC_AVAILABLE = False
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Beta flags
# For claude-sonnet-4-5 with computer-use-2025-01-24
# NOTE(review): these look like `anthropic-beta` header values — confirm
# against the client code that consumes them.
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"


class APIProvider(Enum):
    """API Provider enumeration"""
    # Only the first-party Anthropic API is currently enabled; the Bedrock
    # and Vertex variants are kept commented out for future use.
    ANTHROPIC = "anthropic"
    # BEDROCK = "bedrock"
    # VERTEX = "vertex"


# Provider to model name mapping (simplified for claude-sonnet-4-5 only).
# Keyed by (provider, short model name) -> provider-specific model identifier.
PROVIDER_TO_DEFAULT_MODEL_NAME: dict = {
    (APIProvider.ANTHROPIC, "claude-sonnet-4-5"): "claude-sonnet-4-5",
    # (APIProvider.BEDROCK, "claude-sonnet-4-5"): "us.anthropic.claude-sonnet-4-5-v1:0",
    # (APIProvider.VERTEX, "claude-sonnet-4-5"): "claude-sonnet-4-5-v1",
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_system_prompt(platform: str = "Ubuntu") -> str:
    """
    Return the <SYSTEM_CAPABILITY> system prompt tailored to a platform.

    Args:
        platform: Platform type (Ubuntu, Windows, macOS, or Darwin);
            matching is case-insensitive, anything unrecognized falls
            back to the Ubuntu/Linux prompt.

    Returns:
        System prompt string
    """
    # Case-insensitive dispatch on common aliases for each OS family.
    normalized = platform.lower()

    if normalized in ("windows", "win32"):
        return f"""<SYSTEM_CAPABILITY>
* You are utilising a Windows virtual machine using x86_64 architecture with internet access.
* You can use the computer tool to interact with the desktop: take screenshots, click, type, and control applications.
* To accomplish tasks, you MUST use the computer tool to see the screen and take actions.
* To open browser, please just click on the Chrome icon. Note, Chrome is what is installed on your system.
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* DO NOT ask users for clarification during task execution. DO NOT stop to request more information from users. Always take action using available tools.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
* Home directory of this Windows system is 'C:\\Users\\user'.
* When you want to open some applications on Windows, please use Double Click on it instead of clicking once.
* After each action, the system will provide you with a new screenshot showing the result.
* Continue taking actions until the task is complete.
</SYSTEM_CAPABILITY>"""

    if normalized in ("macos", "darwin", "mac"):
        return f"""<SYSTEM_CAPABILITY>
* You are utilising a macOS system with internet access.
* You can use the computer tool to interact with the desktop: take screenshots, click, type, and control applications.
* To accomplish tasks, you MUST use the computer tool to see the screen and take actions.
* To open browser, please just click on the Chrome icon. Note, Chrome is what is installed on your system.
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* DO NOT ask users for clarification during task execution. DO NOT stop to request more information from users. Always take action using available tools.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
* Home directory of this macOS system is typically '/Users/[username]' or can be accessed via '~'.
* On macOS, use Command (⌘) key combinations instead of Ctrl (e.g., Command+C for copy).
* After each action, the system will provide you with a new screenshot showing the result.
* Continue taking actions until the task is complete.
* When the task is completed, simply describe what you've done in your response WITHOUT using the tool again.
</SYSTEM_CAPABILITY>"""

    # Default: Ubuntu/Linux prompt.
    return f"""<SYSTEM_CAPABILITY>
* You are utilising an Ubuntu virtual machine using x86_64 architecture with internet access.
* You can use the computer tool to interact with the desktop: take screenshots, click, type, and control applications.
* To accomplish tasks, you MUST use the computer tool to see the screen and take actions.
* To open browser, please just click on the Chrome icon. Note, Chrome is what is installed on your system.
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* DO NOT ask users for clarification during task execution. DO NOT stop to request more information from users. Always take action using available tools.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
* Home directory of this Ubuntu system is '/home/user'.
* After each action, the system will provide you with a new screenshot showing the result.
* Continue taking actions until the task is complete.
</SYSTEM_CAPABILITY>"""
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def inject_prompt_caching(messages: List[BetaMessageParam]) -> None:
    """
    Set cache breakpoints for the 3 most recent turns.
    One cache breakpoint is left for tools/system prompt, to be shared across sessions.

    Args:
        messages: Message history (modified in place)
    """
    # No-op when the anthropic SDK is not installed.
    if not ANTHROPIC_AVAILABLE:
        return

    breakpoints_remaining = 3
    # Walk from newest to oldest: the last content block of each of the 3
    # most recent user turns gets an ephemeral cache_control marker.
    for message in reversed(messages):
        if message["role"] == "user" and isinstance(
            content := message["content"], list
        ):
            if breakpoints_remaining:
                breakpoints_remaining -= 1
                # Use type ignore to bypass TypedDict check until SDK types are updated
                content[-1]["cache_control"] = BetaCacheControlEphemeralParam(  # type: ignore
                    {"type": "ephemeral"}
                )
            else:
                # 4th-newest user turn: strip the breakpoint set on a previous
                # call, then stop — at most one marker goes stale per turn.
                content[-1].pop("cache_control", None)
                # we'll only ever have one extra turn per loop
                break
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def maybe_filter_to_n_most_recent_images(
    messages: List[BetaMessageParam],
    images_to_keep: int,
    min_removal_threshold: int,
) -> None:
    """
    Drop old screenshots from the history, keeping only the most recent ones.

    Screenshots lose value as the conversation progresses, so everything but
    the final ``images_to_keep`` tool_result images is removed in place.
    Removal happens in chunks of ``min_removal_threshold`` so the implicit
    prompt cache is invalidated less often.

    Args:
        messages: Message history (modified in place)
        images_to_keep: Number of recent images to keep
        min_removal_threshold: Minimum number of images to remove at once (for cache efficiency)
    """
    if not ANTHROPIC_AVAILABLE or images_to_keep is None:
        return

    # Gather every tool_result block across the whole history, oldest first.
    collected = []
    for msg in messages:
        parts = msg["content"] if isinstance(msg["content"], list) else []
        for part in parts:
            if isinstance(part, dict) and part.get("type") == "tool_result":
                collected.append(part)
    tool_result_blocks = cast(list[BetaToolResultBlockParam], collected)

    # Count all embedded images so we know how many are surplus.
    total_images = 0
    for tr in tool_result_blocks:
        for piece in tr.get("content", []):
            if isinstance(piece, dict) and piece.get("type") == "image":
                total_images += 1

    surplus = total_images - images_to_keep
    # Round down to a multiple of the chunk size for better cache behavior.
    surplus -= surplus % min_removal_threshold

    # Rebuild each tool_result's content, skipping the oldest `surplus` images.
    for tr in tool_result_blocks:
        if isinstance(tr.get("content"), list):
            kept = []
            for piece in tr.get("content", []):
                if (
                    isinstance(piece, dict)
                    and piece.get("type") == "image"
                    and surplus > 0
                ):
                    surplus -= 1
                    continue
                kept.append(piece)
            tr["content"] = kept
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def response_to_params(response: BetaMessage) -> List[BetaContentBlockParam]:
    """
    Convert Anthropic response to parameter list.
    Handles text blocks, tool use blocks, and thinking blocks; unknown block
    types are serialized generically via ``model_dump``.

    Args:
        response: Anthropic API response

    Returns:
        List of content blocks (empty when the SDK is unavailable or the
        response carried no content)
    """
    # No-op when the anthropic SDK is not installed.
    if not ANTHROPIC_AVAILABLE:
        return []

    res: List[BetaContentBlockParam] = []
    if response.content:
        for block in response.content:
            # Check block type using type attribute
            # Note: type may be a string or enum, so convert to string for comparison
            block_type = str(getattr(block, "type", ""))

            if block_type == "text":
                # Regular text block; empty text is dropped.
                if isinstance(block, BetaTextBlock) and block.text:
                    res.append(BetaTextBlockParam(type="text", text=block.text))
            elif block_type == "thinking":
                # Thinking block (for Claude 4 and Sonnet 3.7)
                thinking_block = {
                    "type": "thinking",
                    "thinking": getattr(block, "thinking", ""),
                }
                # Preserve the signature when present — required to replay
                # thinking blocks back to the API.
                if hasattr(block, "signature"):
                    thinking_block["signature"] = getattr(block, "signature", None)
                res.append(cast(BetaContentBlockParam, thinking_block))
            elif block_type == "tool_use":
                # Tool use block - only include required fields to avoid API errors
                # (e.g., 'caller' field is not permitted by Anthropic API)
                tool_use_dict = {
                    "type": "tool_use",
                    "id": block.id,
                    "name": block.name,
                    "input": block.input,
                }
                res.append(cast(BetaToolUseBlockParam, tool_use_dict))
            else:
                # Unknown block type - try to handle generically
                try:
                    res.append(cast(BetaContentBlockParam, block.model_dump()))
                except Exception as e:
                    logger.warning(f"Failed to parse block type {block_type}: {e}")
        return res
    else:
        return []
|
| 241 |
+
|
openspace/grounding/backends/gui/config.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, Optional
|
| 2 |
+
import os
|
| 3 |
+
import platform as platform_module
|
| 4 |
+
from openspace.utils.logging import Logger
|
| 5 |
+
|
| 6 |
+
logger = Logger.get_logger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def build_llm_config(user_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Build complete LLM configuration with auto-detection and environment variables.

    Auto-detects:
    - API key from environment variables (ANTHROPIC_API_KEY)
    - Platform from system (macOS/Windows/Ubuntu)
    - Provider defaults to 'anthropic'

    User-provided config values will override auto-detected values.

    Args:
        user_config: User-provided configuration (optional)

    Returns:
        Complete LLM configuration dict

    Example:
        >>> # Auto-detect everything
        >>> config = build_llm_config()

        >>> # Override specific values
        >>> config = build_llm_config({
        ...     "model": "claude-3-5-sonnet-20241022",
        ...     "max_tokens": 8192
        ... })
    """
    if user_config is None:
        user_config = {}

    # Auto-detect platform from the host OS.
    system = platform_module.system()
    if system == "Darwin":
        detected_platform = "macOS"
    elif system == "Windows":
        detected_platform = "Windows"
    else:  # Linux
        detected_platform = "Ubuntu"

    # Auto-detect API key from environment
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        logger.warning(
            "ANTHROPIC_API_KEY not found in environment. "
            "Please set it: export ANTHROPIC_API_KEY='your-key'"
        )

    # Build configuration with precedence: user_config > auto-detected > defaults
    config = {
        "type": user_config.get("type", "anthropic"),
        "model": user_config.get("model", "claude-sonnet-4-5"),
        "platform": user_config.get("platform", detected_platform),
        "api_key": user_config.get("api_key", api_key),
        "provider": user_config.get("provider", "anthropic"),
        "max_tokens": user_config.get("max_tokens", 4096),
        "only_n_most_recent_images": user_config.get("only_n_most_recent_images", 3),
        "enable_prompt_caching": user_config.get("enable_prompt_caching", True),
    }

    # Optional: screen_size (will be auto-detected from screenshot later)
    if "screen_size" in user_config:
        config["screen_size"] = user_config["screen_size"]

    logger.info(f"Built LLM config - Platform: {config['platform']}, Model: {config['model']}")
    if config["api_key"]:
        # Security fix: the previous code logged the first 10 characters of
        # the API key, leaking secret material into log files. Only
        # acknowledge presence and length.
        logger.info("API key loaded (%d characters)", len(config["api_key"]))

    return config
|
openspace/grounding/backends/gui/provider.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, Union
|
| 2 |
+
from openspace.grounding.core.types import BackendType, SessionConfig
|
| 3 |
+
from openspace.grounding.core.provider import Provider
|
| 4 |
+
from openspace.grounding.core.session import BaseSession
|
| 5 |
+
from openspace.config import get_config
|
| 6 |
+
from openspace.config.utils import get_config_value
|
| 7 |
+
from openspace.platforms import get_local_server_config
|
| 8 |
+
from openspace.utils.logging import Logger
|
| 9 |
+
from .transport.connector import GUIConnector
|
| 10 |
+
from .transport.local_connector import LocalGUIConnector
|
| 11 |
+
from .session import GUISession
|
| 12 |
+
|
| 13 |
+
logger = Logger.get_logger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class GUIProvider(Provider):
    """
    Provider for GUI desktop environment.
    Manages communication with desktop_env through HTTP API or local in-process execution.

    Supports two modes:
    - "local": Execute GUI operations directly in-process (no server needed)
    - "server": Connect to a running local_server via HTTP API

    Supports automatic default session creation:
    - If no session exists, a default session will be created on first use
    - Default session uses configuration from config file or environment
    """

    # Name used for the auto-created default session (the GUI backend's enum value).
    DEFAULT_SID = BackendType.GUI.value

    def __init__(self, config: Union[Dict[str, Any], None] = None):
        """
        Initialize GUI provider.

        Args:
            config: Provider configuration (optional; forwarded to the base Provider)
        """
        super().__init__(BackendType.GUI, config)
        # One connector per session, keyed by session name.
        self.connectors: Dict[str, Union[GUIConnector, LocalGUIConnector]] = {}

    async def initialize(self) -> None:
        """
        Initialize the provider and create default session.

        Idempotent: once `is_initialized` is set, further calls are no-ops.
        """
        if not self.is_initialized:
            logger.info("Initializing GUI provider")
            # Auto-create default session
            await self.create_session(SessionConfig(
                session_name=self.DEFAULT_SID,
                backend_type=BackendType.GUI,
                connection_params={}
            ))
            self.is_initialized = True

    async def create_session(self, session_config: SessionConfig) -> BaseSession:
        """
        Create GUI session.

        Args:
            session_config: Session configuration; per-session
                `connection_params` override values from the backend config.

        Returns:
            GUISession instance
        """
        # Load GUI backend configuration
        gui_config = get_config().get_backend_config("gui")

        # Determine execution mode: "local" or "server"
        # Falls back to "local" when the config object has no `mode` attribute.
        mode = getattr(gui_config, "mode", "local")

        # Extract connection parameters (session params win over backend config)
        conn_params = session_config.connection_params
        timeout = get_config_value(conn_params, 'timeout', gui_config.timeout)
        retry_times = get_config_value(conn_params, 'retry_times', gui_config.max_retries)
        retry_interval = get_config_value(conn_params, 'retry_interval', gui_config.retry_interval)

        # Build pkgs_prefix with failsafe setting.
        # The `command` placeholder is deliberately re-emitted as "{command}"
        # so the connector can substitute it later.
        failsafe_str = "True" if gui_config.failsafe else "False"
        pkgs_prefix = get_config_value(
            conn_params,
            'pkgs_prefix',
            gui_config.pkgs_prefix.format(failsafe=failsafe_str, command="{command}")
        )

        if mode == "local":
            # ---------- LOCAL MODE ----------
            logger.info("GUI backend using LOCAL mode (no server required)")
            connector = LocalGUIConnector(
                timeout=timeout,
                retry_times=retry_times,
                retry_interval=retry_interval,
                pkgs_prefix=pkgs_prefix,
            )
        else:
            # ---------- SERVER MODE ----------
            logger.info("GUI backend using SERVER mode (connecting to local_server)")
            local_server_config = get_local_server_config()
            vm_ip = get_config_value(conn_params, 'vm_ip', local_server_config['host'])
            server_port = get_config_value(conn_params, 'server_port', local_server_config['port'])

            connector = GUIConnector(
                vm_ip=vm_ip,
                server_port=server_port,
                timeout=timeout,
                retry_times=retry_times,
                retry_interval=retry_interval,
                pkgs_prefix=pkgs_prefix,
            )

        # Create session
        session = GUISession(
            connector=connector,
            session_id=session_config.session_name,
            backend_type=BackendType.GUI,
            config=session_config,
        )

        # Store connector and session.
        # NOTE(review): an existing session under the same name is silently
        # replaced without being disconnected — confirm this is intended.
        self.connectors[session_config.session_name] = connector
        self._sessions[session_config.session_name] = session

        logger.info(f"Created GUI session: {session_config.session_name} (mode={mode})")
        return session

    async def close_session(self, session_name: str) -> None:
        """
        Close GUI session.

        Disconnects and drops both the session and its connector; unknown
        names are ignored.

        Args:
            session_name: Name of the session to close
        """
        if session_name in self._sessions:
            session = self._sessions[session_name]
            await session.disconnect()
            del self._sessions[session_name]

        if session_name in self.connectors:
            connector = self.connectors[session_name]
            await connector.disconnect()
            del self.connectors[session_name]

        logger.info(f"Closed GUI session: {session_name}")
|
openspace/grounding/backends/gui/session.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, Union
|
| 2 |
+
import os
|
| 3 |
+
from openspace.grounding.core.session import BaseSession
|
| 4 |
+
from openspace.grounding.core.types import BackendType, SessionStatus, SessionConfig
|
| 5 |
+
from openspace.utils.logging import Logger
|
| 6 |
+
from .transport.connector import GUIConnector
|
| 7 |
+
from .transport.local_connector import LocalGUIConnector
|
| 8 |
+
from .tool import GUIAgentTool
|
| 9 |
+
from .config import build_llm_config
|
| 10 |
+
|
| 11 |
+
logger = Logger.get_logger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class GUISession(BaseSession):
    """
    Session for GUI desktop environment.
    Manages connection and tools for GUI automation.
    """

    def __init__(
        self,
        connector: Union[GUIConnector, LocalGUIConnector],
        session_id: str,
        backend_type: BackendType,
        config: SessionConfig,
        auto_connect: bool = True,
        auto_initialize: bool = True,
    ):
        """
        Initialize GUI session.

        Args:
            connector: GUI connector (HTTP or in-process local)
            session_id: Unique session identifier
            backend_type: Backend type (expected: BackendType.GUI)
            config: Session configuration
            auto_connect: Auto-connect on context enter
            auto_initialize: Auto-initialize on context enter
        """
        super().__init__(
            connector=connector,
            session_id=session_id,
            backend_type=backend_type,
            auto_connect=auto_connect,
            auto_initialize=auto_initialize,
        )
        self.config = config
        # Alias under a GUI-specific name for the attribute access below.
        self.gui_connector = connector

    async def initialize(self) -> Dict[str, Any]:
        """
        Initialize session: connect and discover tools.

        Optionally creates an Anthropic LLM client when an llm_config is
        supplied or ANTHROPIC_API_KEY is set; screen size is auto-detected
        from a screenshot, falling back to pyautogui, then to config.

        Returns:
            Session information dict
        """
        logger.info(f"Initializing GUI session: {self.session_id}")

        # Ensure connected
        if not self.connector.is_connected:
            await self.connect()

        # Create LLM client if configured
        llm_client = None
        user_llm_config = self.config.connection_params.get("llm_config")

        # Build complete LLM config with auto-detection.
        # If user provides llm_config, merge with auto-detected values.
        # If user doesn't provide llm_config, try to auto-build one if ANTHROPIC_API_KEY exists.
        if user_llm_config or os.environ.get("ANTHROPIC_API_KEY"):
            llm_config = build_llm_config(user_llm_config)

            if llm_config.get("type") == "anthropic":
                # Check if API key is available
                if not llm_config.get("api_key"):
                    logger.warning(
                        "Anthropic API key not found. Skipping LLM client initialization. "
                        "Set ANTHROPIC_API_KEY environment variable or provide api_key in llm_config."
                    )
                else:
                    try:
                        from .anthropic_client import AnthropicGUIClient

                        # Detect actual screen size from screenshot (most accurate).
                        # PyAutoGUI may report logical resolution, but we need the actual screenshot size.
                        try:
                            screenshot_bytes = await self.gui_connector.get_screenshot()
                            if screenshot_bytes:
                                from PIL import Image
                                import io
                                img = Image.open(io.BytesIO(screenshot_bytes))
                                actual_screen_size = img.size
                                logger.info(f"Auto-detected screen size from screenshot: {actual_screen_size}")
                                screen_size = actual_screen_size
                            else:
                                raise RuntimeError("Could not get screenshot")
                        except Exception as e:
                            # Fallback to pyautogui detection
                            actual_screen_size = await self.gui_connector.get_screen_size()
                            if actual_screen_size:
                                logger.info(f"Auto-detected screen size from pyautogui: {actual_screen_size}")
                                screen_size = actual_screen_size
                            else:
                                # Final fallback to configured value
                                screen_size = llm_config.get("screen_size", (1920, 1080))
                                logger.warning(f"Could not auto-detect screen size, using configured: {screen_size}")

                        # Detect PyAutoGUI working size (logical pixels)
                        pyautogui_size = await self.gui_connector.get_screen_size()
                        if pyautogui_size:
                            logger.info(f"PyAutoGUI working size (logical): {pyautogui_size}")
                        else:
                            # If we can't detect PyAutoGUI size, assume it's the same as screen size
                            pyautogui_size = screen_size
                            logger.warning(f"Could not detect PyAutoGUI size, assuming same as screen: {pyautogui_size}")

                        llm_client = AnthropicGUIClient(
                            model=llm_config["model"],
                            platform=llm_config["platform"],
                            api_key=llm_config["api_key"],
                            provider=llm_config["provider"],
                            screen_size=screen_size,
                            pyautogui_size=pyautogui_size,
                            max_tokens=llm_config["max_tokens"],
                            only_n_most_recent_images=llm_config["only_n_most_recent_images"],
                        )
                        logger.info(
                            f"Initialized Anthropic LLM client - "
                            f"Model: {llm_config['model']}, Platform: {llm_config['platform']}"
                        )
                    except Exception as e:
                        # Best-effort: session still works without an LLM client.
                        logger.warning(f"Failed to initialize Anthropic client: {e}")

        # Get recording_manager from connection_params if available
        recording_manager = self.config.connection_params.get("recording_manager")

        # Create GUI Agent Tool
        self.tools = [
            GUIAgentTool(
                connector=self.gui_connector,
                llm_client=llm_client,
                recording_manager=recording_manager
            )
        ]

        logger.info(f"Initialized GUI session with {len(self.tools)} tool(s)")

        # Return session info.
        # NOTE(review): assumes the connector exposes vm_ip/server_port —
        # confirm LocalGUIConnector also defines these attributes.
        session_info = {
            "session_id": self.session_id,
            "backend_type": self.backend_type.value,
            "vm_ip": self.gui_connector.vm_ip,
            "server_port": self.gui_connector.server_port,
            "num_tools": len(self.tools),
            "tools": [tool.name for tool in self.tools],
            "llm_client": "anthropic" if llm_client else "none",
        }

        return session_info

    async def connect(self) -> None:
        """Connect to GUI desktop environment"""
        # Idempotent: do nothing when already connected.
        if self.connector.is_connected:
            return

        self.status = SessionStatus.CONNECTING
        logger.info(f"Connecting to desktop_env at {self.gui_connector.base_url}")

        await self.connector.connect()

        self.status = SessionStatus.CONNECTED
        logger.info("Connected to desktop environment")

    async def disconnect(self) -> None:
        """Disconnect from GUI desktop environment"""
        # Idempotent: do nothing when not connected.
        if not self.connector.is_connected:
            return

        logger.info("Disconnecting from desktop environment")
        await self.connector.disconnect()

        self.status = SessionStatus.DISCONNECTED
        logger.info("Disconnected from desktop environment")

    @property
    def is_connected(self) -> bool:
        """Check if session is connected"""
        return self.connector.is_connected
|
openspace/grounding/backends/gui/tool.py
ADDED
|
@@ -0,0 +1,712 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
from openspace.grounding.core.tool.base import BaseTool
|
| 4 |
+
from openspace.grounding.core.types import BackendType, ToolResult, ToolStatus
|
| 5 |
+
from .transport.connector import GUIConnector
|
| 6 |
+
from .transport.actions import ACTION_SPACE, KEYBOARD_KEYS
|
| 7 |
+
from openspace.utils.logging import Logger
|
| 8 |
+
|
| 9 |
+
logger = Logger.get_logger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class GUIAgentTool(BaseTool):
|
| 13 |
+
"""
|
| 14 |
+
LLM-powered GUI Agent Tool.
|
| 15 |
+
|
| 16 |
+
This tool acts as an intelligent agent that:
|
| 17 |
+
- Takes a task description as input
|
| 18 |
+
- Observes the desktop via screenshot
|
| 19 |
+
- Uses LLM/VLM to understand and plan actions
|
| 20 |
+
- Outputs action space commands
|
| 21 |
+
- Executes actions through the connector
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
_name = "gui_agent"
|
| 25 |
+
_description = """Vision-based GUI automation agent for tasks requiring graphical interface interaction.
|
| 26 |
+
|
| 27 |
+
Use this tool when the task involves:
|
| 28 |
+
- Operating desktop applications with graphical interfaces (browsers, editors, design tools, etc.)
|
| 29 |
+
- Tasks that require visual understanding of UI elements, layouts, or content
|
| 30 |
+
- Multi-step workflows that need click, drag, type, or other GUI interactions
|
| 31 |
+
- Scenarios where programmatic APIs or command-line tools are unavailable or insufficient
|
| 32 |
+
|
| 33 |
+
The agent observes screen state through screenshots, uses vision-language models to understand
|
| 34 |
+
the interface, plans appropriate actions, and executes GUI operations autonomously.
|
| 35 |
+
|
| 36 |
+
IMPORTANT - max_steps Parameter Guidelines:
|
| 37 |
+
- Simple tasks (1-2 actions): 15-20 steps
|
| 38 |
+
- Medium tasks (3-5 actions): 25-35 steps
|
| 39 |
+
- Complex tasks (6+ actions, like web navigation): 35-50 steps
|
| 40 |
+
- When uncertain, prefer larger values (35+) to avoid premature termination
|
| 41 |
+
- Default is 25, but increase for multi-step workflows
|
| 42 |
+
|
| 43 |
+
Input:
|
| 44 |
+
- task_description: Natural language task description
|
| 45 |
+
- max_steps: Maximum actions (default 25, increase for complex tasks)
|
| 46 |
+
|
| 47 |
+
Output: Task execution results with action history and completion status
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
backend_type = BackendType.GUI
|
| 51 |
+
|
| 52 |
+
def __init__(self, connector: GUIConnector, llm_client=None, recording_manager=None, **kwargs):
|
| 53 |
+
"""
|
| 54 |
+
Initialize GUI Agent Tool.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
connector: GUI connector for communication with desktop_env
|
| 58 |
+
llm_client: LLM/VLM client for vision-based planning (optional)
|
| 59 |
+
recording_manager: RecordingManager for recording intermediate steps (optional)
|
| 60 |
+
**kwargs: Additional arguments for BaseTool
|
| 61 |
+
"""
|
| 62 |
+
super().__init__(**kwargs)
|
| 63 |
+
self.connector = connector
|
| 64 |
+
self.llm_client = llm_client # Will be injected later
|
| 65 |
+
self.recording_manager = recording_manager # For recording intermediate steps
|
| 66 |
+
self.action_history = [] # Track executed actions
|
| 67 |
+
|
| 68 |
+
async def _arun(
|
| 69 |
+
self,
|
| 70 |
+
task_description: str,
|
| 71 |
+
max_steps: int = 50,
|
| 72 |
+
) -> ToolResult:
|
| 73 |
+
"""
|
| 74 |
+
Execute a GUI automation task using LLM planning.
|
| 75 |
+
|
| 76 |
+
This is the main entry point that:
|
| 77 |
+
1. Gets current screenshot
|
| 78 |
+
2. Uses LLM to plan next action based on task and screenshot
|
| 79 |
+
3. Executes the planned action
|
| 80 |
+
4. Repeats until task is complete or max_steps reached
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
task_description: Natural language description of the task
|
| 84 |
+
max_steps: Maximum number of actions to execute (default 25)
|
| 85 |
+
Recommended values based on task complexity:
|
| 86 |
+
- Simple (1-2 actions): 15-20
|
| 87 |
+
- Medium (3-5 actions): 25-35
|
| 88 |
+
- Complex (6+ actions, web navigation, multi-app): 35-50
|
| 89 |
+
When in doubt, use higher values to avoid premature termination
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
ToolResult with task execution status
|
| 93 |
+
"""
|
| 94 |
+
if not task_description:
|
| 95 |
+
return ToolResult(
|
| 96 |
+
status=ToolStatus.ERROR,
|
| 97 |
+
error="task_description is required"
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
logger.info(f"Starting GUI task: {task_description}")
|
| 101 |
+
self.action_history = []
|
| 102 |
+
|
| 103 |
+
# Execute task with LLM planning loop
|
| 104 |
+
try:
|
| 105 |
+
result = await self._execute_task_with_planning(
|
| 106 |
+
task_description=task_description,
|
| 107 |
+
max_steps=max_steps,
|
| 108 |
+
)
|
| 109 |
+
return result
|
| 110 |
+
|
| 111 |
+
except Exception as e:
|
| 112 |
+
logger.error(f"Task execution failed: {e}")
|
| 113 |
+
return ToolResult(
|
| 114 |
+
status=ToolStatus.ERROR,
|
| 115 |
+
error=str(e),
|
| 116 |
+
metadata={
|
| 117 |
+
"task_description": task_description,
|
| 118 |
+
"actions_executed": len(self.action_history),
|
| 119 |
+
"action_history": self.action_history,
|
| 120 |
+
}
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
    async def _execute_task_with_planning(
        self,
        task_description: str,
        max_steps: int,
    ) -> ToolResult:
        """
        Execute task with LLM-based planning loop.

        Planning loop:
        1. Observe: Get screenshot
        2. Plan: LLM decides next action
        3. Execute: Perform the action
        4. Verify: Check if task is complete
        5. Repeat until done or max_steps

        Args:
            task_description: Task to complete
            max_steps: Maximum planning iterations

        Returns:
            ToolResult with execution details; metadata always carries
            action_history, screenshots, and intermediate_steps.
        """
        # Collect all screenshots for visual analysis.
        # NOTE(review): raw screenshot bytes for every step are kept in
        # memory and returned in metadata — potentially large for long runs.
        all_screenshots = []
        # Collect intermediate steps (one summary dict per planning cycle).
        intermediate_steps = []

        # NOTE(review): _record_intermediate_step is defined on this class
        # but is not invoked anywhere in this loop; steps are only
        # accumulated locally and returned in metadata.
        for step in range(max_steps):
            logger.info(f"Planning step {step + 1}/{max_steps}")

            # Step 1: Observe current state. A missing screenshot is fatal —
            # planning cannot proceed without a visual observation.
            screenshot = await self.connector.get_screenshot()
            if not screenshot:
                return ToolResult(
                    status=ToolStatus.ERROR,
                    error="Failed to get screenshot for planning",
                    metadata={"step": step, "action_history": self.action_history}
                )

            # Collect screenshot for visual analysis
            all_screenshots.append(screenshot)

            # Step 2: Plan next action using LLM
            planned_action = await self._plan_next_action(
                task_description=task_description,
                screenshot=screenshot,
                action_history=self.action_history,
            )

            # Terminal branch: LLM declared the task complete.
            if planned_action["action_type"] == "DONE":
                logger.info("Task marked as complete by LLM")
                reasoning = planned_action.get("reasoning", "Task completed successfully")

                intermediate_steps.append({
                    "step_number": step + 1,
                    "action": "DONE",
                    "reasoning": reasoning,
                    "status": "done",
                })

                return ToolResult(
                    status=ToolStatus.SUCCESS,
                    content=f"Task completed: {task_description}\n\nFinal state: {reasoning}",
                    metadata={
                        "steps_taken": step + 1,
                        "action_history": self.action_history,
                        "screenshots": all_screenshots,
                        "intermediate_steps": intermediate_steps,
                        "final_reasoning": reasoning,
                    }
                )

            # Terminal branch: LLM declared the task impossible.
            if planned_action["action_type"] == "FAIL":
                logger.warning("Task marked as failed by LLM")
                reason = planned_action.get("reason", "Task cannot be completed")

                intermediate_steps.append({
                    "step_number": step + 1,
                    "action": "FAIL",
                    "reasoning": planned_action.get("reasoning", ""),
                    "status": "failed",
                })

                return ToolResult(
                    status=ToolStatus.ERROR,
                    error=reason,
                    metadata={
                        "steps_taken": step + 1,
                        "action_history": self.action_history,
                        "screenshots": all_screenshots,
                        "intermediate_steps": intermediate_steps,
                    }
                )

            # Check if action is WAIT (screenshot observation, continue to next step)
            if planned_action["action_type"] == "WAIT":
                logger.info("Screenshot observation step, continuing planning loop")
                intermediate_steps.append({
                    "step_number": step + 1,
                    "action": "WAIT",
                    "reasoning": planned_action.get("reasoning", ""),
                    "status": "observation",
                })
                continue

            # Step 3: Execute the planned action
            execution_result = await self._execute_planned_action(planned_action)

            # Record action in history (read back by _plan_next_action as
            # context on the next iteration).
            self.action_history.append({
                "step": step + 1,
                "planned_action": planned_action,
                "execution_result": execution_result,
            })

            intermediate_steps.append({
                "step_number": step + 1,
                "action": planned_action.get("action_type", "unknown"),
                "reasoning": planned_action.get("reasoning", ""),
                "status": execution_result.get("status", "unknown"),
            })

            # A failed action is not fatal: the loop continues so the next
            # planning step can observe the failure state and retry.
            if execution_result.get("status") != "success":
                logger.warning(f"Action execution failed: {execution_result.get('error')}")
                # Continue to next iteration for retry planning

        # Max steps reached without DONE/FAIL — report incompleteness.
        return ToolResult(
            status=ToolStatus.ERROR,
            error=f"Task incomplete after {max_steps} steps",
            metadata={
                "task_description": task_description,
                "steps_taken": max_steps,
                "action_history": self.action_history,
                "screenshots": all_screenshots,
                "intermediate_steps": intermediate_steps,
            }
        )
|
| 264 |
+
|
| 265 |
+
    async def _plan_next_action(
        self,
        task_description: str,
        screenshot: bytes,
        action_history: list,
    ) -> Dict[str, Any]:
        """
        Use LLM/VLM to plan the next action.

        This method sends:
        - Task description
        - Current screenshot (vision input)
        - Action history (context)
        - Available ACTION_SPACE

        And gets back a structured action plan. Never raises: every error
        path is converted into a {"action_type": "FAIL", ...} dict.

        Args:
            task_description: The task to accomplish
            screenshot: Current desktop screenshot (PNG/JPEG bytes)
            action_history: Previously executed actions

        Returns:
            Dict with action_type and parameters
        """
        if self.llm_client is None:
            # Fallback: Simple heuristic or manual mode
            logger.warning("No LLM client configured, using fallback mode")
            return {
                "action_type": "FAIL",
                "reason": "LLM client not configured"
            }

        # Detect the Anthropic client variant; its plan_action signature and
        # return shape differ from the generic-client path below.
        try:
            from .anthropic_client import AnthropicGUIClient
            is_anthropic = isinstance(self.llm_client, AnthropicGUIClient)
        except ImportError:
            is_anthropic = False

        if is_anthropic:
            # Anthropic path: plan_action returns (reasoning, commands) where
            # commands is a list of sentinel strings or pyautogui snippets.
            try:
                reasoning, commands = await self.llm_client.plan_action(
                    task_description=task_description,
                    screenshot=screenshot,
                    action_history=action_history,
                )

                # Sentinel: planning failed outright.
                if commands == ["FAIL"]:
                    return {
                        "action_type": "FAIL",
                        "reason": "Anthropic planning failed"
                    }

                # Sentinel: task complete.
                if commands == ["DONE"]:
                    return {
                        "action_type": "DONE",
                        "reasoning": reasoning
                    }

                if commands == ["SCREENSHOT"]:
                    # Screenshot is automatically handled by system
                    # Continue to next planning step
                    logger.info("LLM requested screenshot (observation step)")
                    return {
                        "action_type": "WAIT",
                        "reasoning": reasoning or "Observing screen state"
                    }

                # If no commands but has reasoning, task is complete
                # (Anthropic returns text-only when task is done)
                if not commands and reasoning:
                    logger.info("LLM returned text-only response, interpreting as task completion")
                    return {
                        "action_type": "DONE",
                        "reasoning": reasoning
                    }

                # No commands and no reasoning = error
                if not commands:
                    return {
                        "action_type": "FAIL",
                        "reason": "No commands generated and no completion message"
                    }

                # Return first command (Anthropic returns pyautogui commands
                # directly). NOTE: only commands[0] is executed per planning
                # step; any additional commands in the list are discarded.
                return {
                    "action_type": "PYAUTOGUI_COMMAND",
                    "command": commands[0],
                    "reasoning": reasoning
                }

            except Exception as e:
                logger.error(f"Anthropic planning failed: {e}")
                return {
                    "action_type": "FAIL",
                    "reason": f"Planning error: {str(e)}"
                }

        # Generic LLM client (for future integration with other LLMs)
        # Encode screenshot to base64 for LLM
        screenshot_b64 = base64.b64encode(screenshot).decode('utf-8')

        # Prepare prompt for LLM
        prompt = self._build_planning_prompt(
            task_description=task_description,
            action_history=action_history,
        )

        # Call LLM with vision input.
        # NOTE(review): assumes a generic client exposes plan_action(prompt,
        # image_base64, action_space, keyboard_keys) returning a JSON string
        # — no such client is visible here; verify against implementations.
        try:
            response = await self.llm_client.plan_action(
                prompt=prompt,
                image_base64=screenshot_b64,
                action_space=ACTION_SPACE,
                keyboard_keys=KEYBOARD_KEYS,
            )

            # Parse LLM response to action dict
            action = self._parse_llm_response(response)

            logger.info(f"LLM planned action: {action['action_type']}")
            return action

        except Exception as e:
            logger.error(f"LLM planning failed: {e}")
            return {
                "action_type": "FAIL",
                "reason": f"Planning error: {str(e)}"
            }
|
| 396 |
+
|
| 397 |
+
def _build_planning_prompt(
|
| 398 |
+
self,
|
| 399 |
+
task_description: str,
|
| 400 |
+
action_history: list,
|
| 401 |
+
) -> str:
|
| 402 |
+
"""
|
| 403 |
+
Build prompt for LLM planning.
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
task_description: The task to accomplish
|
| 407 |
+
action_history: Previously executed actions
|
| 408 |
+
|
| 409 |
+
Returns:
|
| 410 |
+
Formatted prompt string
|
| 411 |
+
"""
|
| 412 |
+
prompt = f"""You are a GUI automation agent. Your task is to complete the following:
|
| 413 |
+
|
| 414 |
+
Task: {task_description}
|
| 415 |
+
|
| 416 |
+
You can observe the current desktop state through the provided screenshot.
|
| 417 |
+
You must plan the next action to take from the available ACTION_SPACE.
|
| 418 |
+
|
| 419 |
+
Available actions:
|
| 420 |
+
- Mouse: MOVE_TO, CLICK, RIGHT_CLICK, DOUBLE_CLICK, DRAG_TO, SCROLL
|
| 421 |
+
- Keyboard: TYPING, PRESS, KEY_DOWN, KEY_UP, HOTKEY
|
| 422 |
+
- Control: WAIT, DONE, FAIL
|
| 423 |
+
|
| 424 |
+
"""
|
| 425 |
+
|
| 426 |
+
if action_history:
|
| 427 |
+
prompt += f"\nPrevious actions taken ({len(action_history)}):\n"
|
| 428 |
+
for i, action in enumerate(action_history[-5:], 1): # Last 5 actions
|
| 429 |
+
prompt += f"{i}. {action['planned_action']['action_type']}"
|
| 430 |
+
if 'parameters' in action['planned_action']:
|
| 431 |
+
prompt += f" - {action['planned_action']['parameters']}"
|
| 432 |
+
prompt += "\n"
|
| 433 |
+
|
| 434 |
+
prompt += """
|
| 435 |
+
Based on the screenshot and task, output the next action in JSON format:
|
| 436 |
+
{
|
| 437 |
+
"action_type": "ACTION_TYPE",
|
| 438 |
+
"parameters": {...},
|
| 439 |
+
"reasoning": "Why this action is needed"
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
If the task is complete, output: {"action_type": "DONE"}
|
| 443 |
+
If the task cannot be completed, output: {"action_type": "FAIL", "reason": "explanation"}
|
| 444 |
+
"""
|
| 445 |
+
|
| 446 |
+
return prompt
|
| 447 |
+
|
| 448 |
+
def _parse_llm_response(self, response: str) -> Dict[str, Any]:
|
| 449 |
+
"""
|
| 450 |
+
Parse LLM response to extract action.
|
| 451 |
+
|
| 452 |
+
Args:
|
| 453 |
+
response: LLM response (should be JSON)
|
| 454 |
+
|
| 455 |
+
Returns:
|
| 456 |
+
Action dict with action_type and parameters
|
| 457 |
+
"""
|
| 458 |
+
import json
|
| 459 |
+
|
| 460 |
+
try:
|
| 461 |
+
# Try to parse as JSON
|
| 462 |
+
action = json.loads(response)
|
| 463 |
+
|
| 464 |
+
# Validate action
|
| 465 |
+
if "action_type" not in action:
|
| 466 |
+
raise ValueError("Missing action_type in LLM response")
|
| 467 |
+
|
| 468 |
+
return action
|
| 469 |
+
|
| 470 |
+
except json.JSONDecodeError:
|
| 471 |
+
logger.error(f"Failed to parse LLM response as JSON: {response[:200]}")
|
| 472 |
+
return {
|
| 473 |
+
"action_type": "FAIL",
|
| 474 |
+
"reason": "Invalid LLM response format"
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
async def _execute_planned_action(
|
| 478 |
+
self,
|
| 479 |
+
action: Dict[str, Any]
|
| 480 |
+
) -> Dict[str, Any]:
|
| 481 |
+
"""
|
| 482 |
+
Execute a planned action through the connector.
|
| 483 |
+
|
| 484 |
+
Args:
|
| 485 |
+
action: Action dict with action_type and parameters
|
| 486 |
+
|
| 487 |
+
Returns:
|
| 488 |
+
Execution result dict
|
| 489 |
+
"""
|
| 490 |
+
action_type = action["action_type"]
|
| 491 |
+
|
| 492 |
+
# Handle Anthropic's direct pyautogui commands
|
| 493 |
+
if action_type == "PYAUTOGUI_COMMAND":
|
| 494 |
+
command = action.get("command", "")
|
| 495 |
+
logger.info(f"Executing pyautogui command: {command}")
|
| 496 |
+
|
| 497 |
+
try:
|
| 498 |
+
result = await self.connector.execute_python_command(command)
|
| 499 |
+
return {
|
| 500 |
+
"status": "success" if result else "error",
|
| 501 |
+
"action_type": action_type,
|
| 502 |
+
"command": command,
|
| 503 |
+
"result": result
|
| 504 |
+
}
|
| 505 |
+
except Exception as e:
|
| 506 |
+
logger.error(f"Command execution error: {e}")
|
| 507 |
+
return {
|
| 508 |
+
"status": "error",
|
| 509 |
+
"action_type": action_type,
|
| 510 |
+
"error": str(e)
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
# Handle standard action space commands
|
| 514 |
+
parameters = action.get("parameters", {})
|
| 515 |
+
logger.info(f"Executing action: {action_type}")
|
| 516 |
+
|
| 517 |
+
try:
|
| 518 |
+
result = await self.connector.execute_action(action_type, parameters)
|
| 519 |
+
return result
|
| 520 |
+
|
| 521 |
+
except Exception as e:
|
| 522 |
+
logger.error(f"Action execution error: {e}")
|
| 523 |
+
return {
|
| 524 |
+
"status": "error",
|
| 525 |
+
"action_type": action_type,
|
| 526 |
+
"error": str(e)
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
# Helper methods for direct action execution
|
| 530 |
+
|
| 531 |
+
async def execute_action(
|
| 532 |
+
self,
|
| 533 |
+
action_type: str,
|
| 534 |
+
parameters: Dict[str, Any]
|
| 535 |
+
) -> ToolResult:
|
| 536 |
+
"""
|
| 537 |
+
Direct action execution (bypass LLM planning).
|
| 538 |
+
|
| 539 |
+
Args:
|
| 540 |
+
action_type: Action type from ACTION_SPACE
|
| 541 |
+
parameters: Action parameters
|
| 542 |
+
|
| 543 |
+
Returns:
|
| 544 |
+
ToolResult with execution status
|
| 545 |
+
"""
|
| 546 |
+
result = await self.connector.execute_action(action_type, parameters)
|
| 547 |
+
|
| 548 |
+
if result.get("status") == "success":
|
| 549 |
+
return ToolResult(
|
| 550 |
+
status=ToolStatus.SUCCESS,
|
| 551 |
+
content=f"Executed {action_type}",
|
| 552 |
+
metadata=result
|
| 553 |
+
)
|
| 554 |
+
else:
|
| 555 |
+
return ToolResult(
|
| 556 |
+
status=ToolStatus.ERROR,
|
| 557 |
+
error=result.get("error", "Unknown error"),
|
| 558 |
+
metadata=result
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
async def get_screenshot(self) -> ToolResult:
|
| 562 |
+
"""Get current desktop screenshot."""
|
| 563 |
+
screenshot = await self.connector.get_screenshot()
|
| 564 |
+
if screenshot:
|
| 565 |
+
return ToolResult(
|
| 566 |
+
status=ToolStatus.SUCCESS,
|
| 567 |
+
content=screenshot,
|
| 568 |
+
metadata={"type": "screenshot", "size": len(screenshot)}
|
| 569 |
+
)
|
| 570 |
+
else:
|
| 571 |
+
return ToolResult(
|
| 572 |
+
status=ToolStatus.ERROR,
|
| 573 |
+
error="Failed to capture screenshot"
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
    async def _record_intermediate_step(
        self,
        step_number: int,
        planned_action: Dict[str, Any],
        execution_result: Dict[str, Any],
        screenshot: bytes,
        task_description: str,
    ):
        """
        Record an intermediate step of GUI agent execution.

        This method records each planning-action cycle to the recording system,
        providing detailed traces of GUI agent's decision-making process.

        Best-effort: every failure path logs at debug/warning level and
        returns instead of raising, so recording problems never abort task
        execution.

        Args:
            step_number: Step number in the execution sequence
            planned_action: Action planned by LLM
            execution_result: Result of executing the action
            screenshot: Screenshot before executing the action
            task_description: Overall task description
        """
        # Try to get recording_manager dynamically if not set at initialization
        recording_manager = self.recording_manager
        if not recording_manager and hasattr(self, '_runtime_info') and self._runtime_info:
            # Fall back to the manager attached to the grounding client, if any.
            grounding_client = self._runtime_info.grounding_client
            if grounding_client and hasattr(grounding_client, 'recording_manager'):
                recording_manager = grounding_client.recording_manager
                logger.debug(f"Step {step_number}: Dynamically retrieved recording_manager from grounding_client")

        if not recording_manager:
            logger.debug(f"Step {step_number}: No recording_manager available, skipping intermediate step recording")
            return

        # Check if recording is globally active before doing any work.
        try:
            from openspace.recording.manager import RecordingManager
            if not RecordingManager.is_recording():
                logger.debug(f"Step {step_number}: RecordingManager not started")
                return
        except Exception as e:
            logger.debug(f"Step {step_number}: Failed to check recording status: {e}")
            return

        # Check if recorder is initialized.
        # NOTE(review): reaches into the private _recorder attribute of
        # recording_manager — confirm this access pattern is supported.
        if not hasattr(recording_manager, '_recorder') or not recording_manager._recorder:
            logger.warning(f"Step {step_number}: recording_manager._recorder not initialized")
            return

        # Build command string for display
        action_type = planned_action.get("action_type", "unknown")
        command = self._format_action_command(planned_action)

        # Map execution status onto a binary success flag for the record.
        status = execution_result.get("status", "unknown")
        is_success = status in ("success", "done", "observation")

        # Build a short human-readable result summary per status kind.
        if status == "done":
            result_content = f"Task completed at step {step_number}"
        elif status == "failed":
            result_content = execution_result.get("message", "Task failed")
        elif status == "observation":
            result_content = execution_result.get("message", "Screenshot observation")
        else:
            result_content = execution_result.get("result", execution_result.get("message", str(execution_result)))

        # Build parameters for recording
        parameters = {
            "task_description": task_description,
            "step_number": step_number,
            "action_type": action_type,
            "planned_action": planned_action,
        }

        # Record to trajectory recorder (handles screenshot saving)
        try:
            await recording_manager._recorder.record_step(
                backend="gui",
                tool="gui_agent_step",
                command=command,
                result={
                    "status": "success" if is_success else "error",
                    # Output is truncated to 200 chars to keep records small.
                    "output": str(result_content)[:200],
                },
                parameters=parameters,
                screenshot=screenshot,
                extra={
                    "gui_step_number": step_number,
                    "reasoning": planned_action.get("reasoning", ""),
                }
            )

            logger.info(f"✓ Recorded GUI intermediate step {step_number}: {command}")

        except Exception as e:
            logger.error(f"✗ Failed to record intermediate step {step_number}: {e}", exc_info=True)
|
| 673 |
+
|
| 674 |
+
def _format_action_command(self, planned_action: Dict[str, Any]) -> str:
|
| 675 |
+
"""
|
| 676 |
+
Format planned action into a human-readable command string.
|
| 677 |
+
|
| 678 |
+
Args:
|
| 679 |
+
planned_action: Action dictionary from LLM planning
|
| 680 |
+
|
| 681 |
+
Returns:
|
| 682 |
+
Formatted command string
|
| 683 |
+
"""
|
| 684 |
+
action_type = planned_action.get("action_type", "unknown")
|
| 685 |
+
|
| 686 |
+
# Handle special action types
|
| 687 |
+
if action_type == "DONE":
|
| 688 |
+
return "DONE (task completed)"
|
| 689 |
+
elif action_type == "FAIL":
|
| 690 |
+
reason = planned_action.get("reason", "unknown")
|
| 691 |
+
return f"FAIL ({reason})"
|
| 692 |
+
elif action_type == "WAIT":
|
| 693 |
+
return "WAIT (screenshot observation)"
|
| 694 |
+
|
| 695 |
+
# Handle PyAutoGUI commands
|
| 696 |
+
elif action_type == "PYAUTOGUI_COMMAND":
|
| 697 |
+
command = planned_action.get("command", "")
|
| 698 |
+
# Truncate long commands
|
| 699 |
+
if len(command) > 100:
|
| 700 |
+
return command[:100] + "..."
|
| 701 |
+
return command
|
| 702 |
+
|
| 703 |
+
# Handle standard action space commands
|
| 704 |
+
else:
|
| 705 |
+
parameters = planned_action.get("parameters", {})
|
| 706 |
+
if parameters:
|
| 707 |
+
# Format first 2 parameters
|
| 708 |
+
param_items = list(parameters.items())[:2]
|
| 709 |
+
param_str = ", ".join([f"{k}={v}" for k, v in param_items])
|
| 710 |
+
return f"{action_type}({param_str})"
|
| 711 |
+
else:
|
| 712 |
+
return action_type
|
openspace/grounding/backends/gui/transport/actions.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GUI Action Space Definitions.
|
| 3 |
+
"""
|
| 4 |
+
from typing import Any, Dict, Optional
|
| 5 |
+
|
| 6 |
+
# Screen resolution constants.
# NOTE(review): these are assumed default coordinate bounds used only in the
# "range" entries of ACTION_SPACE below; the real display resolution may
# differ (the connector can query it at runtime) — confirm before validating
# coordinates against them.
X_MAX = 1920
Y_MAX = 1080

# Keyboard keys constants.
# Allowed values for PRESS / KEY_DOWN / KEY_UP / HOTKEY actions.
# NOTE(review): this list appears to mirror pyautogui.KEYBOARD_KEYS — confirm
# it stays in sync with the installed pyautogui version.
KEYBOARD_KEYS = [
    '\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@',
    '[', '\\', ']', '^', '_', '`',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    '{', '|', '}', '~',
    'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
    'browserback', 'browserfavorites', 'browserforward', 'browserhome', 'browserrefresh', 'browsersearch', 'browserstop',
    'capslock', 'clear', 'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete', 'divide',
    'down', 'end', 'enter', 'esc', 'escape', 'execute',
    'f1', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19',
    'f2', 'f20', 'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
    'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja', 'kana', 'kanji',
    'launchapp1', 'launchapp2', 'launchmail', 'launchmediaselect', 'left', 'modechange', 'multiply',
    'nexttrack', 'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6', 'num7', 'num8', 'num9',
    'numlock', 'pagedown', 'pageup', 'pause', 'pgdn', 'pgup', 'playpause', 'prevtrack', 'print', 'printscreen',
    'prntscrn', 'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
    'shift', 'shiftleft', 'shiftright', 'sleep', 'stop', 'subtract', 'tab', 'up',
    'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
    'command', 'option', 'optionleft', 'optionright'
]

# Action Space Definition.
# Each entry declares an action type, a human-readable note, and (optionally)
# a "parameters" schema. In each parameter spec:
#   - "type" is the actual Python type object (not a string),
#   - "range" is either [min, max] bounds, an explicit list of allowed
#     values, or None for unconstrained parameters,
#   - "optional" indicates whether the parameter may be omitted.
# Control actions WAIT / FAIL / DONE take no parameters.
ACTION_SPACE = [
    {
        "action_type": "MOVE_TO",
        "note": "move the cursor to the specified position",
        "parameters": {
            "x": {"type": float, "range": [0, X_MAX], "optional": False},
            "y": {"type": float, "range": [0, Y_MAX], "optional": False},
        }
    },
    {
        "action_type": "CLICK",
        "note": "click the left button if button not specified, otherwise click the specified button",
        "parameters": {
            "button": {"type": str, "range": ["left", "right", "middle"], "optional": True},
            "x": {"type": float, "range": [0, X_MAX], "optional": True},
            "y": {"type": float, "range": [0, Y_MAX], "optional": True},
            "num_clicks": {"type": int, "range": [1, 2, 3], "optional": True},
        }
    },
    {
        "action_type": "MOUSE_DOWN",
        "note": "press the mouse button",
        "parameters": {
            "button": {"type": str, "range": ["left", "right", "middle"], "optional": True}
        }
    },
    {
        "action_type": "MOUSE_UP",
        "note": "release the mouse button",
        "parameters": {
            "button": {"type": str, "range": ["left", "right", "middle"], "optional": True}
        }
    },
    {
        "action_type": "RIGHT_CLICK",
        "note": "right click at position",
        "parameters": {
            "x": {"type": float, "range": [0, X_MAX], "optional": True},
            "y": {"type": float, "range": [0, Y_MAX], "optional": True}
        }
    },
    {
        "action_type": "DOUBLE_CLICK",
        "note": "double click at position",
        "parameters": {
            "x": {"type": float, "range": [0, X_MAX], "optional": True},
            "y": {"type": float, "range": [0, Y_MAX], "optional": True}
        }
    },
    {
        "action_type": "DRAG_TO",
        "note": "drag the cursor to position",
        "parameters": {
            "x": {"type": float, "range": [0, X_MAX], "optional": False},
            "y": {"type": float, "range": [0, Y_MAX], "optional": False}
        }
    },
    {
        "action_type": "SCROLL",
        "note": "scroll the mouse wheel",
        "parameters": {
            "dx": {"type": int, "range": None, "optional": False},
            "dy": {"type": int, "range": None, "optional": False}
        }
    },
    {
        "action_type": "TYPING",
        "note": "type the specified text",
        "parameters": {
            "text": {"type": str, "range": None, "optional": False}
        }
    },
    {
        "action_type": "PRESS",
        "note": "press the specified key",
        "parameters": {
            "key": {"type": str, "range": KEYBOARD_KEYS, "optional": False}
        }
    },
    {
        "action_type": "KEY_DOWN",
        "note": "press down the specified key",
        "parameters": {
            "key": {"type": str, "range": KEYBOARD_KEYS, "optional": False}
        }
    },
    {
        "action_type": "KEY_UP",
        "note": "release the specified key",
        "parameters": {
            "key": {"type": str, "range": KEYBOARD_KEYS, "optional": False}
        }
    },
    {
        "action_type": "HOTKEY",
        "note": "press key combination",
        "parameters": {
            # "range" wraps KEYBOARD_KEYS in a list because the parameter is
            # itself a list of keys, each drawn from KEYBOARD_KEYS.
            "keys": {"type": list, "range": [KEYBOARD_KEYS], "optional": False}
        }
    },
    {
        "action_type": "WAIT",
        "note": "wait until next action",
    },
    {
        "action_type": "FAIL",
        "note": "mark task as failed",
    },
    {
        "action_type": "DONE",
        "note": "mark task as done",
    }
]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def build_pyautogui_command(action_type: str, parameters: Dict[str, Any]) -> Optional[str]:
    """
    Build a pyautogui command string from an action type and parameters.

    Args:
        action_type: Type of action (e.g., 'CLICK', 'TYPING')
        parameters: Action parameters

    Returns:
        Python command string, or None when the action type is unsupported
        or required coordinates are missing.
    """
    if action_type == "MOVE_TO":
        if "x" in parameters and "y" in parameters:
            x, y = parameters["x"], parameters["y"]
            # 0.5s eased movement: looks natural and gives the UI time to react.
            return f"pyautogui.moveTo({x}, {y}, 0.5, pyautogui.easeInQuad)"
        else:
            return "pyautogui.moveTo()"

    elif action_type == "CLICK":
        button = parameters.get("button", "left")
        num_clicks = parameters.get("num_clicks", 1)

        if "x" in parameters and "y" in parameters:
            x, y = parameters["x"], parameters["y"]
            return f"pyautogui.click(button='{button}', x={x}, y={y}, clicks={num_clicks})"
        else:
            return f"pyautogui.click(button='{button}', clicks={num_clicks})"

    elif action_type == "MOUSE_DOWN":
        button = parameters.get("button", "left")
        return f"pyautogui.mouseDown(button='{button}')"

    elif action_type == "MOUSE_UP":
        button = parameters.get("button", "left")
        return f"pyautogui.mouseUp(button='{button}')"

    elif action_type == "RIGHT_CLICK":
        if "x" in parameters and "y" in parameters:
            x, y = parameters["x"], parameters["y"]
            return f"pyautogui.rightClick(x={x}, y={y})"
        else:
            return "pyautogui.rightClick()"

    elif action_type == "DOUBLE_CLICK":
        if "x" in parameters and "y" in parameters:
            x, y = parameters["x"], parameters["y"]
            return f"pyautogui.doubleClick(x={x}, y={y})"
        else:
            return "pyautogui.doubleClick()"

    elif action_type == "DRAG_TO":
        if "x" in parameters and "y" in parameters:
            x, y = parameters["x"], parameters["y"]
            # 1.0s drag duration with the left button held down.
            return f"pyautogui.dragTo({x}, {y}, 1.0, button='left')"
        # Previously this branch fell through implicitly; DRAG_TO without
        # target coordinates cannot be executed, so report it explicitly.
        return None

    elif action_type == "SCROLL":
        # NOTE: pyautogui.scroll() is vertical-only; the 'dx' parameter from
        # the action space is intentionally ignored here.
        dy = parameters.get("dy", 0)
        return f"pyautogui.scroll({dy})"

    elif action_type == "TYPING":
        text = parameters.get("text", "")
        # repr() yields a valid Python literal, escaping quotes/backslashes.
        return f"pyautogui.typewrite({repr(text)})"

    elif action_type == "PRESS":
        key = parameters.get("key", "")
        # Use repr so keys containing quotes (e.g. "'", which is a valid
        # KEYBOARD_KEYS entry) do not produce a syntactically invalid command.
        return f"pyautogui.press({key!r})"

    elif action_type == "KEY_DOWN":
        key = parameters.get("key", "")
        return f"pyautogui.keyDown({key!r})"

    elif action_type == "KEY_UP":
        key = parameters.get("key", "")
        return f"pyautogui.keyUp({key!r})"

    elif action_type == "HOTKEY":
        keys = parameters.get("keys", [])
        if keys:
            keys_str = ", ".join(repr(k) for k in keys)
            return f"pyautogui.hotkey({keys_str})"

    # Unknown action type, or HOTKEY with an empty key list.
    return None
|
openspace/grounding/backends/gui/transport/connector.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import re
|
| 3 |
+
from typing import Any, Dict, Optional
|
| 4 |
+
from openspace.grounding.core.transport.connectors import AioHttpConnector
|
| 5 |
+
from .actions import build_pyautogui_command, KEYBOARD_KEYS
|
| 6 |
+
from openspace.utils.logging import Logger
|
| 7 |
+
|
| 8 |
+
logger = Logger.get_logger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class GUIConnector(AioHttpConnector):
    """
    Connector for the desktop_env HTTP API.

    Provides action execution (via pyautogui commands run on the VM) and
    observation methods (screenshot, accessibility tree, cursor position).
    All network operations go through ``_retry_invoke`` for retry handling.
    """

    def __init__(
        self,
        vm_ip: str,
        server_port: int = 5000,
        timeout: int = 90,
        retry_times: int = 3,
        retry_interval: float = 5.0,
        pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}",
    ):
        """
        Initialize GUI connector.

        Args:
            vm_ip: IP address of the VM running desktop_env
            server_port: Port of the desktop_env HTTP server
            timeout: Request timeout in seconds
            retry_times: Number of retries for failed requests
            retry_interval: Interval between retries in seconds
            pkgs_prefix: Python command prefix for pyautogui setup; FAILSAFE
                is disabled so moving to a screen corner does not abort runs
        """
        base_url = f"http://{vm_ip}:{server_port}"
        super().__init__(base_url, timeout=timeout)

        self.vm_ip = vm_ip
        self.server_port = server_port
        self.retry_times = retry_times
        self.retry_interval = retry_interval
        self.pkgs_prefix = pkgs_prefix
        self.timeout = timeout

    async def _retry_invoke(
        self,
        operation_name: str,
        operation_func,
        *args,
        **kwargs
    ):
        """
        Execute operation with retry logic.

        Args:
            operation_name: Name of operation for logging
            operation_func: Async function to execute
            *args: Positional arguments for operation_func
            **kwargs: Keyword arguments for operation_func

        Returns:
            Operation result

        Raises:
            RuntimeError: Immediately on timeout (timeouts are NOT retried)
            Exception: Last exception after all retries fail
        """
        last_exc: Exception | None = None

        for attempt in range(1, self.retry_times + 1):
            try:
                result = await operation_func(*args, **kwargs)
                logger.debug("%s executed successfully (attempt %d/%d)", operation_name, attempt, self.retry_times)
                return result
            except asyncio.TimeoutError as exc:
                # Timeouts short-circuit: retrying a slow endpoint is unlikely
                # to help and would multiply the total wait time.
                logger.error("%s timed out", operation_name)
                raise RuntimeError(f"{operation_name} timed out after {self.timeout} seconds") from exc
            except Exception as exc:
                last_exc = exc
                if attempt == self.retry_times:
                    break
                logger.warning(
                    "%s failed (attempt %d/%d): %s, retrying in %.1f seconds...",
                    operation_name, attempt, self.retry_times, exc, self.retry_interval
                )
                await asyncio.sleep(self.retry_interval)

        error_msg = f"{operation_name} failed after {self.retry_times} retries"
        logger.error(error_msg)
        raise last_exc or RuntimeError(error_msg)

    @staticmethod
    def _is_valid_image_response(content_type: str, data: Optional[bytes]) -> bool:
        """Validate image response using magic bytes, with Content-Type fallback."""
        if not isinstance(data, (bytes, bytearray)) or not data:
            return False
        # PNG magic
        if len(data) >= 8 and data[:8] == b"\x89PNG\r\n\x1a\n":
            return True
        # JPEG magic
        if len(data) >= 3 and data[:3] == b"\xff\xd8\xff":
            return True
        # Fallback to content-type (trust the server header if bytes look odd)
        if content_type and ("image/png" in content_type or "image/jpeg" in content_type):
            return True
        return False

    @staticmethod
    def _fix_pyautogui_less_than_bug(command: str) -> str:
        """
        Fix PyAutoGUI '<' character bug by converting it to hotkey("shift", ',') calls.

        This fixes the known PyAutoGUI issue where typing '<' produces '>' instead.
        References:
        - https://github.com/asweigart/pyautogui/issues/198
        - https://github.com/xlang-ai/OSWorld/issues/257

        NOTE(review): shift+',' yields '<' on US-style keyboard layouts —
        confirm the VM layout before relying on this workaround.

        Args:
            command (str): The original pyautogui command

        Returns:
            str: The fixed command with '<' characters handled properly
        """
        # Pattern to match press('<') or press('\u003c') calls
        press_pattern = r'pyautogui\.press\(["\'](?:<|\\u003c)["\']\)'

        # Handle press('<') calls
        def replace_press_less_than(match):
            return 'pyautogui.hotkey("shift", ",")'

        # First handle press('<') calls
        command = re.sub(press_pattern, replace_press_less_than, command)

        # Pattern to match typewrite calls with quoted strings
        typewrite_pattern = r'pyautogui\.typewrite\((["\'])(.*?)\1\)'

        # Then handle typewrite calls
        def process_typewrite_match(match):
            quote_char = match.group(1)
            content = match.group(2)

            # Preprocess: Try to decode Unicode escapes like \u003c to actual '<'
            # This handles cases where '<' is represented as escaped Unicode
            try:
                # Attempt to decode unicode escapes
                decoded_content = content.encode('utf-8').decode('unicode_escape')
                content = decoded_content
            except UnicodeDecodeError:
                # If decoding fails, proceed with original content to avoid breaking existing logic
                pass  # Graceful degradation - fall back to original content if decoding fails

            # Check if content contains '<'
            if '<' not in content:
                return match.group(0)

            # Split by '<' and rebuild: interleave typewrite chunks with
            # hotkey("shift", ",") presses for each '<' occurrence.
            parts = content.split('<')
            result_parts = []

            for i, part in enumerate(parts):
                if i == 0:
                    # First part
                    if part:
                        result_parts.append(f"pyautogui.typewrite({quote_char}{part}{quote_char})")
                else:
                    # Add hotkey for '<' and then typewrite for the rest
                    result_parts.append('pyautogui.hotkey("shift", ",")')
                    if part:
                        result_parts.append(f"pyautogui.typewrite({quote_char}{part}{quote_char})")

            return '; '.join(result_parts)

        command = re.sub(typewrite_pattern, process_typewrite_match, command)

        return command

    async def get_screen_size(self) -> Optional[tuple[int, int]]:
        """
        Get actual screen size from desktop environment using pyautogui.

        Returns:
            (width, height) tuple, or None on failure
        """
        try:
            command = "print(pyautogui.size())"
            result = await self.execute_python_command(command)
            if result and result.get("status") == "success":
                output = result.get("output", "")
                # Parse output like "Size(width=2880, height=1800)"
                import re
                match = re.search(r'width=(\d+).*height=(\d+)', output)
                if match:
                    width = int(match.group(1))
                    height = int(match.group(2))
                    logger.info(f"Detected screen size: {width}x{height}")
                    return (width, height)
            logger.warning(f"Failed to detect screen size, output: {result}")
            return None
        except Exception as e:
            logger.error(f"Failed to get screen size: {e}")
            return None

    async def get_screenshot(self) -> Optional[bytes]:
        """
        Get screenshot from desktop environment.

        Returns:
            Screenshot image bytes (PNG/JPEG), or None on failure
        """
        try:
            async def _get():
                response = await self._request("GET", "/screenshot", timeout=10)
                if response.status == 200:
                    content_type = response.headers.get("Content-Type", "")
                    content = await response.read()
                    if self._is_valid_image_response(content_type, content):
                        return content
                    else:
                        # Raise so _retry_invoke treats a bad payload as retryable
                        raise ValueError("Invalid screenshot format")
                else:
                    raise RuntimeError(f"HTTP {response.status}")

            return await self._retry_invoke("get_screenshot", _get)
        except Exception as e:
            logger.error(f"Failed to get screenshot: {e}")
            return None

    async def execute_python_command(self, command: str) -> Optional[Dict[str, Any]]:
        """
        Execute a Python command on desktop environment.
        Used for pyautogui commands.

        Args:
            command: Python command to execute

        Returns:
            Response dict with execution result, or None on failure
        """
        try:
            # Apply '<' character fix for PyAutoGUI bug
            fixed_command = self._fix_pyautogui_less_than_bug(command)

            # Run as `python -c "<prefix>; <command>"` on the VM, no shell.
            command_list = ["python", "-c", self.pkgs_prefix.format(command=fixed_command)]
            payload = {"command": command_list, "shell": False}

            async def _execute():
                return await self.post_json("/execute", payload)

            return await self._retry_invoke("execute_python_command", _execute)
        except Exception as e:
            logger.error(f"Failed to execute command: {e}")
            return None

    async def execute_action(self, action_type: str, parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Execute a desktop action.
        This is the main method for action space execution.

        Args:
            action_type: Action type (e.g., 'CLICK', 'TYPING')
            parameters: Action parameters

        Returns:
            Result dict with execution status
        """
        parameters = parameters or {}

        # Handle control actions: acknowledged locally, nothing sent to the VM.
        if action_type in ['WAIT', 'FAIL', 'DONE']:
            return {
                "status": "success",
                "action_type": action_type,
                "message": f"Control action {action_type} acknowledged"
            }

        # Validate keyboard keys before building any command
        if action_type in ['PRESS', 'KEY_DOWN', 'KEY_UP']:
            key = parameters.get('key')
            if key and key not in KEYBOARD_KEYS:
                return {
                    "status": "error",
                    "action_type": action_type,
                    "error": f"Invalid key: {key}. Must be in supported keyboard keys."
                }

        if action_type == 'HOTKEY':
            keys = parameters.get('keys', [])
            invalid_keys = [k for k in keys if k not in KEYBOARD_KEYS]
            if invalid_keys:
                return {
                    "status": "error",
                    "action_type": action_type,
                    "error": f"Invalid keys: {invalid_keys}"
                }

        # Build pyautogui command
        command = build_pyautogui_command(action_type, parameters)

        if command is None:
            return {
                "status": "error",
                "action_type": action_type,
                "error": f"Unsupported action type: {action_type}"
            }

        # Execute command
        result = await self.execute_python_command(command)

        if result:
            return {
                "status": "success",
                "action_type": action_type,
                "parameters": parameters,
                "result": result
            }
        else:
            return {
                "status": "error",
                "action_type": action_type,
                "parameters": parameters,
                "error": "Command execution failed"
            }

    async def get_accessibility_tree(self, max_depth: int = 5) -> Optional[Dict[str, Any]]:
        """
        Get accessibility tree from desktop environment.

        Args:
            max_depth: Maximum depth of accessibility tree traversal.
                NOTE(review): not forwarded to the /accessibility endpoint
                in the current implementation — confirm whether the server
                applies its own depth limit.

        Returns:
            Accessibility tree as dict (the "AT" field of the response),
            or None on failure
        """
        try:
            async def _get():
                response = await self._request("GET", "/accessibility", timeout=10)
                if response.status == 200:
                    data = await response.json()
                    return data.get("AT")
                else:
                    raise RuntimeError(f"HTTP {response.status}")

            return await self._retry_invoke("get_accessibility_tree", _get)
        except Exception as e:
            logger.error(f"Failed to get accessibility tree: {e}")
            return None

    async def get_cursor_position(self) -> Optional[tuple[int, int]]:
        """
        Get current mouse cursor position.
        Useful for GUI debugging and relative positioning.

        Returns:
            (x, y) tuple, or None on failure
        """
        try:
            async def _get():
                result = await self.get_json("/cursor_position")
                return (result.get("x"), result.get("y"))

            return await self._retry_invoke("get_cursor_position", _get)
        except Exception as e:
            logger.error(f"Failed to get cursor position: {e}")
            return None

    async def invoke(self, name: str, params: dict[str, Any]) -> Any:
        """
        Unified RPC entry for operations.
        Required by BaseConnector.

        Args:
            name: Operation name (observation method: "screenshot",
                "accessibility_tree", "cursor_position" — anything else is
                uppercased and dispatched as an action type)
            params: Operation parameters

        Returns:
            Operation result
        """
        # Handle observation methods
        if name == "screenshot":
            return await self.get_screenshot()
        elif name == "accessibility_tree":
            max_depth = params.get("max_depth", 5) if params else 5
            return await self.get_accessibility_tree(max_depth)
        elif name == "cursor_position":
            return await self.get_cursor_position()
        else:
            # Treat as action
            return await self.execute_action(name.upper(), params or {})
|
openspace/grounding/backends/gui/transport/local_connector.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Local GUI Connector — execute GUI operations directly in-process.
|
| 3 |
+
|
| 4 |
+
This connector has the **same public API** as GUIConnector (HTTP version)
|
| 5 |
+
but uses local pyautogui / ScreenshotHelper / AccessibilityHelper,
|
| 6 |
+
removing the need for a local_server.
|
| 7 |
+
|
| 8 |
+
Return format is kept identical so that GUISession / GUIAgentTool
|
| 9 |
+
work without any changes.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import asyncio
|
| 13 |
+
import os
|
| 14 |
+
import platform
|
| 15 |
+
import re
|
| 16 |
+
import tempfile
|
| 17 |
+
import uuid
|
| 18 |
+
from typing import Any, Dict, Optional
|
| 19 |
+
|
| 20 |
+
from openspace.grounding.core.transport.connectors.base import BaseConnector
|
| 21 |
+
from openspace.grounding.core.transport.task_managers.noop import NoOpConnectionManager
|
| 22 |
+
from openspace.utils.logging import Logger
|
| 23 |
+
|
| 24 |
+
logger = Logger.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
# OS name ("Windows", "Linux", "Darwin"), computed once at import time.
platform_name = platform.system()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class LocalGUIConnector(BaseConnector[Any]):
    """
    GUI connector that runs desktop automation **locally** using pyautogui /
    ScreenshotHelper / AccessibilityHelper, bypassing the Flask local_server.

    Public API is compatible with ``GUIConnector`` so that ``GUISession``
    works without modification.
    """

    def __init__(
        self,
        timeout: int = 90,
        retry_times: int = 3,
        retry_interval: float = 5.0,
        pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}",
    ):
        """Initialize the local connector.

        Args:
            timeout: Per-operation timeout in seconds.
            retry_times: Number of attempts before an operation is considered failed.
            retry_interval: Seconds to wait between attempts.
            pkgs_prefix: Template wrapped around every generated pyautogui command.
        """
        super().__init__(NoOpConnectionManager())
        self.timeout = timeout
        self.retry_times = retry_times
        self.retry_interval = retry_interval
        self.pkgs_prefix = pkgs_prefix

        # Compatibility attributes expected by GUISession.
        self.vm_ip = "localhost"
        self.server_port = 0
        self.base_url = "local://localhost"

        # Lazy-initialized helpers (avoid import side effects at class load).
        self._screenshot_helper = None
        self._accessibility_helper = None

    def _get_screenshot_helper(self):
        """Lazily create and cache the ScreenshotHelper instance."""
        if self._screenshot_helper is None:
            from openspace.local_server.utils import ScreenshotHelper
            self._screenshot_helper = ScreenshotHelper()
        return self._screenshot_helper

    def _get_accessibility_helper(self):
        """Lazily create and cache the AccessibilityHelper instance."""
        if self._accessibility_helper is None:
            from openspace.local_server.utils import AccessibilityHelper
            self._accessibility_helper = AccessibilityHelper()
        return self._accessibility_helper

    # ------------------------------------------------------------------
    # connect / disconnect
    # ------------------------------------------------------------------

    async def connect(self) -> None:
        """No real connection for local mode."""
        if self._connected:
            return
        await super().connect()
        logger.info("LocalGUIConnector: ready (local mode, no server required)")

    # ------------------------------------------------------------------
    # Retry wrapper (same interface as GUIConnector._retry_invoke)
    # ------------------------------------------------------------------

    async def _retry_invoke(
        self,
        operation_name: str,
        operation_func,
        *args,
        **kwargs,
    ):
        """Run ``operation_func`` with retries.

        An ``asyncio.TimeoutError`` aborts immediately (no retry) and is
        re-raised as RuntimeError; any other exception is retried up to
        ``self.retry_times`` times with ``self.retry_interval`` seconds
        between attempts.
        """
        last_exc: Exception | None = None
        for attempt in range(1, self.retry_times + 1):
            try:
                result = await operation_func(*args, **kwargs)
                logger.debug(
                    "%s executed successfully (attempt %d/%d)",
                    operation_name, attempt, self.retry_times,
                )
                return result
            except asyncio.TimeoutError as exc:
                logger.error("%s timed out", operation_name)
                raise RuntimeError(
                    f"{operation_name} timed out after {self.timeout} seconds"
                ) from exc
            except Exception as exc:
                last_exc = exc
                if attempt == self.retry_times:
                    break
                logger.warning(
                    "%s failed (attempt %d/%d): %s, retrying in %.1f seconds...",
                    operation_name, attempt, self.retry_times, exc, self.retry_interval,
                )
                await asyncio.sleep(self.retry_interval)

        error_msg = f"{operation_name} failed after {self.retry_times} retries"
        logger.error(error_msg)
        raise last_exc or RuntimeError(error_msg)

    # ------------------------------------------------------------------
    # PyAutoGUI '<' bug fix (same as GUIConnector)
    # ------------------------------------------------------------------

    @staticmethod
    def _fix_pyautogui_less_than_bug(command: str) -> str:
        """Fix PyAutoGUI '<' character bug.

        Rewrites ``pyautogui.press("<")`` and any '<' inside
        ``pyautogui.typewrite(...)`` literals into
        ``pyautogui.hotkey("shift", ",")`` sequences.
        """
        press_pattern = r'pyautogui\.press\(["\'](?:<|\\u003c)["\']\)'

        def replace_press_less_than(match):
            return 'pyautogui.hotkey("shift", ",")'

        command = re.sub(press_pattern, replace_press_less_than, command)

        typewrite_pattern = r'pyautogui\.typewrite\((["\'])(.*?)\1\)'

        def process_typewrite_match(match):
            quote_char = match.group(1)
            content = match.group(2)
            try:
                # Unescape \uXXXX sequences so '<' hidden behind escapes is found.
                decoded_content = content.encode("utf-8").decode("unicode_escape")
                content = decoded_content
            except UnicodeDecodeError:
                pass
            if "<" not in content:
                return match.group(0)
            parts = content.split("<")
            result_parts = []
            for i, part in enumerate(parts):
                if i == 0:
                    if part:
                        result_parts.append(
                            f"pyautogui.typewrite({quote_char}{part}{quote_char})"
                        )
                else:
                    result_parts.append('pyautogui.hotkey("shift", ",")')
                    if part:
                        result_parts.append(
                            f"pyautogui.typewrite({quote_char}{part}{quote_char})"
                        )
            return "; ".join(result_parts)

        command = re.sub(typewrite_pattern, process_typewrite_match, command)
        return command

    # ------------------------------------------------------------------
    # Image response validation (same as GUIConnector)
    # ------------------------------------------------------------------

    @staticmethod
    def _is_valid_image_response(content_type: str, data: Optional[bytes]) -> bool:
        """Return True when *data* looks like a PNG/JPEG payload.

        Accepts data whose magic bytes match PNG or JPEG, or whose
        content type declares an image format.
        """
        if not isinstance(data, (bytes, bytearray)) or not data:
            return False
        if len(data) >= 8 and data[:8] == b"\x89PNG\r\n\x1a\n":
            return True
        if len(data) >= 3 and data[:3] == b"\xff\xd8\xff":
            return True
        if content_type and ("image/png" in content_type or "image/jpeg" in content_type):
            return True
        return False

    # ------------------------------------------------------------------
    # Public API (same signatures as GUIConnector)
    # ------------------------------------------------------------------

    async def get_screen_size(self) -> Optional[tuple[int, int]]:
        """Get screen size using pyautogui.

        Returns:
            A (width, height) tuple, or None on failure.
        """
        try:
            command = "print(pyautogui.size())"
            result = await self.execute_python_command(command)
            if result and result.get("status") == "success":
                output = result.get("output", "")
                # pyautogui.size() prints e.g. "Size(width=1920, height=1080)".
                match = re.search(r"width=(\d+).*height=(\d+)", output)
                if match:
                    width = int(match.group(1))
                    height = int(match.group(2))
                    logger.info("Detected screen size: %dx%d", width, height)
                    return (width, height)
            logger.warning("Failed to detect screen size, output: %s", result)
            return None
        except Exception as e:
            logger.error("Failed to get screen size: %s", e)
            return None

    async def get_screenshot(self) -> Optional[bytes]:
        """Capture screenshot locally using ScreenshotHelper.

        Returns:
            Raw image bytes, or None on failure.
        """
        try:
            async def _get():
                helper = self._get_screenshot_helper()
                tmp_path = os.path.join(
                    tempfile.gettempdir(), f"screenshot_{uuid.uuid4().hex}.png"
                )
                # Fix: remove the temp file even when capture/read fails,
                # so failed attempts don't accumulate files in tempdir.
                try:
                    if not helper.capture(tmp_path, with_cursor=True):
                        raise RuntimeError("Screenshot capture failed")
                    with open(tmp_path, "rb") as f:
                        return f.read()
                finally:
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)

            return await self._retry_invoke("get_screenshot", _get)
        except Exception as e:
            logger.error("Failed to get screenshot: %s", e)
            return None

    async def execute_python_command(self, command: str) -> Optional[Dict[str, Any]]:
        """Execute a pyautogui Python command locally via subprocess.

        Returns:
            Dict with keys status/output/error/returncode, or None on failure.
        """
        try:
            fixed_command = self._fix_pyautogui_less_than_bug(command)
            full_command = self.pkgs_prefix.format(command=fixed_command)

            async def _execute():
                # NOTE(review): uses PATH's python3 rather than sys.executable;
                # presumably intentional to mirror the remote-server behavior.
                python_cmd = "python" if platform_name == "Windows" else "python3"
                proc = await asyncio.create_subprocess_exec(
                    python_cmd, "-c", full_command,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )
                try:
                    stdout_b, stderr_b = await asyncio.wait_for(
                        proc.communicate(), timeout=self.timeout
                    )
                except asyncio.TimeoutError:
                    # Fix: don't leak the subprocess on timeout — kill it
                    # before letting _retry_invoke convert the error.
                    proc.kill()
                    raise
                stdout = stdout_b.decode("utf-8", errors="replace") if stdout_b else ""
                stderr = stderr_b.decode("utf-8", errors="replace") if stderr_b else ""
                returncode = proc.returncode or 0
                return {
                    "status": "success" if returncode == 0 else "error",
                    "output": stdout + stderr,
                    "error": stderr if returncode != 0 else "",
                    "returncode": returncode,
                }

            return await self._retry_invoke("execute_python_command", _execute)
        except Exception as e:
            logger.error("Failed to execute command: %s", e)
            return None

    async def execute_action(
        self, action_type: str, parameters: Dict[str, Any] | None = None
    ) -> Dict[str, Any]:
        """Execute a desktop action (same logic as GUIConnector).

        Args:
            action_type: Action name, e.g. "CLICK", "HOTKEY", "WAIT".
            parameters: Action-specific parameters.

        Returns:
            Result dict; the top-level "status" reflects whether the command
            could be dispatched — the inner "result" carries the command's
            own status.
        """
        parameters = parameters or {}

        # Control actions are acknowledged without touching the desktop.
        if action_type in ["WAIT", "FAIL", "DONE"]:
            return {
                "status": "success",
                "action_type": action_type,
                "message": f"Control action {action_type} acknowledged",
            }

        # Import action builder (same module used by GUIConnector).
        from openspace.grounding.backends.gui.transport.actions import (
            build_pyautogui_command,
            KEYBOARD_KEYS,
        )

        # Validate key names before building any command.
        if action_type in ["PRESS", "KEY_DOWN", "KEY_UP"]:
            key = parameters.get("key")
            if key and key not in KEYBOARD_KEYS:
                return {
                    "status": "error",
                    "action_type": action_type,
                    "error": f"Invalid key: {key}. Must be in supported keyboard keys.",
                }
        if action_type == "HOTKEY":
            keys = parameters.get("keys", [])
            invalid_keys = [k for k in keys if k not in KEYBOARD_KEYS]
            if invalid_keys:
                return {
                    "status": "error",
                    "action_type": action_type,
                    "error": f"Invalid keys: {invalid_keys}",
                }

        command = build_pyautogui_command(action_type, parameters)
        if command is None:
            return {
                "status": "error",
                "action_type": action_type,
                "error": f"Unsupported action type: {action_type}",
            }

        result = await self.execute_python_command(command)
        if result:
            return {
                "status": "success",
                "action_type": action_type,
                "parameters": parameters,
                "result": result,
            }
        else:
            return {
                "status": "error",
                "action_type": action_type,
                "parameters": parameters,
                "error": "Command execution failed",
            }

    async def get_accessibility_tree(
        self, max_depth: int = 5
    ) -> Optional[Dict[str, Any]]:
        """Get accessibility tree locally.

        Args:
            max_depth: Maximum tree depth to traverse.

        Returns:
            Tree dict from AccessibilityHelper, or None on failure.
        """
        try:
            async def _get():
                helper = self._get_accessibility_helper()
                return helper.get_tree(max_depth=max_depth)

            return await self._retry_invoke("get_accessibility_tree", _get)
        except Exception as e:
            logger.error("Failed to get accessibility tree: %s", e)
            return None

    async def get_cursor_position(self) -> Optional[tuple[int, int]]:
        """Get cursor position locally.

        Returns:
            An (x, y) tuple, or None on failure.
        """
        try:
            async def _get():
                helper = self._get_screenshot_helper()
                return helper.get_cursor_position()

            return await self._retry_invoke("get_cursor_position", _get)
        except Exception as e:
            logger.error("Failed to get cursor position: %s", e)
            return None

    # ------------------------------------------------------------------
    # BaseConnector abstract methods
    # ------------------------------------------------------------------

    async def invoke(self, name: str, params: dict[str, Any]) -> Any:
        """Dispatch a named operation to the matching public method.

        Unknown names are treated as desktop actions (upper-cased).
        """
        if name == "screenshot":
            return await self.get_screenshot()
        elif name == "accessibility_tree":
            max_depth = params.get("max_depth", 5) if params else 5
            return await self.get_accessibility_tree(max_depth)
        elif name == "cursor_position":
            return await self.get_cursor_position()
        else:
            return await self.execute_action(name.upper(), params or {})

    async def request(self, *args: Any, **kwargs: Any) -> Any:
        """Raw HTTP requests are meaningless in local mode."""
        raise NotImplementedError(
            "LocalGUIConnector does not support raw HTTP requests"
        )
openspace/grounding/backends/mcp/__init__.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MCP Backend for OpenSpace Grounding.
|
| 3 |
+
|
| 4 |
+
This module provides the MCP (Model Context Protocol) backend implementation
|
| 5 |
+
for the grounding framework. It includes:
|
| 6 |
+
|
| 7 |
+
- MCPProvider: Manages multiple MCP server sessions
|
| 8 |
+
- MCPSession: Handles individual MCP server connections
|
| 9 |
+
- MCPClient: High-level client for MCP server configuration
|
| 10 |
+
- MCPInstallerManager: Manages automatic installation of MCP dependencies
|
| 11 |
+
- MCPToolCache: Caches tool metadata to avoid starting servers on list_tools
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from .provider import MCPProvider
|
| 15 |
+
from .session import MCPSession
|
| 16 |
+
from .client import MCPClient
|
| 17 |
+
from .installer import (
|
| 18 |
+
MCPInstallerManager,
|
| 19 |
+
get_global_installer,
|
| 20 |
+
set_global_installer,
|
| 21 |
+
MCPDependencyError,
|
| 22 |
+
MCPCommandNotFoundError,
|
| 23 |
+
MCPInstallationCancelledError,
|
| 24 |
+
MCPInstallationFailedError,
|
| 25 |
+
)
|
| 26 |
+
from .tool_cache import MCPToolCache, get_tool_cache
|
| 27 |
+
|
| 28 |
+
__all__ = [
|
| 29 |
+
"MCPProvider",
|
| 30 |
+
"MCPSession",
|
| 31 |
+
"MCPClient",
|
| 32 |
+
"MCPInstallerManager",
|
| 33 |
+
"get_global_installer",
|
| 34 |
+
"set_global_installer",
|
| 35 |
+
"MCPDependencyError",
|
| 36 |
+
"MCPCommandNotFoundError",
|
| 37 |
+
"MCPInstallationCancelledError",
|
| 38 |
+
"MCPInstallationFailedError",
|
| 39 |
+
"MCPToolCache",
|
| 40 |
+
"get_tool_cache",
|
| 41 |
+
]
|
openspace/grounding/backends/mcp/client.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Client for managing MCP servers and sessions.
|
| 3 |
+
|
| 4 |
+
This module provides a high-level client that manages MCP servers, connectors,
|
| 5 |
+
and sessions from configuration.
|
| 6 |
+
"""
|
| 7 |
+
import asyncio
|
| 8 |
+
import warnings
|
| 9 |
+
from typing import Any, Optional
|
| 10 |
+
|
| 11 |
+
from openspace.grounding.core.types import SandboxOptions
|
| 12 |
+
from openspace.config.utils import get_config_value, save_json_file, load_json_file
|
| 13 |
+
from .config import create_connector_from_config
|
| 14 |
+
from .session import MCPSession
|
| 15 |
+
from .installer import MCPInstallerManager, MCPDependencyError
|
| 16 |
+
|
| 17 |
+
from openspace.utils.logging import Logger
|
| 18 |
+
|
| 19 |
+
logger = Logger.get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class MCPClient:
|
| 23 |
+
"""Client for managing MCP servers and sessions.
|
| 24 |
+
|
| 25 |
+
This class provides a unified interface for working with MCP servers,
|
| 26 |
+
handling configuration, connector creation, and session management.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
config: str | dict[str, Any] | None = None,
|
| 32 |
+
sandbox: bool = False,
|
| 33 |
+
sandbox_options: SandboxOptions | None = None,
|
| 34 |
+
timeout: float = 30.0,
|
| 35 |
+
sse_read_timeout: float = 300.0,
|
| 36 |
+
max_retries: int = 3,
|
| 37 |
+
retry_interval: float = 2.0,
|
| 38 |
+
installer: Optional[MCPInstallerManager] = None,
|
| 39 |
+
check_dependencies: bool = True,
|
| 40 |
+
tool_call_max_retries: int = 3,
|
| 41 |
+
tool_call_retry_delay: float = 1.0,
|
| 42 |
+
) -> None:
|
| 43 |
+
"""Initialize a new MCP client.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
config: Either a dict containing configuration or a path to a JSON config file.
|
| 47 |
+
If None, an empty configuration is used.
|
| 48 |
+
sandbox: Whether to use sandboxed execution mode for running MCP servers.
|
| 49 |
+
sandbox_options: Optional sandbox configuration options.
|
| 50 |
+
timeout: Timeout for operations in seconds (default: 30.0)
|
| 51 |
+
sse_read_timeout: SSE read timeout in seconds (default: 300.0)
|
| 52 |
+
max_retries: Maximum number of retry attempts for failed operations (default: 3)
|
| 53 |
+
retry_interval: Wait time between retries in seconds (default: 2.0)
|
| 54 |
+
installer: Optional installer manager for dependency installation
|
| 55 |
+
check_dependencies: Whether to check and install dependencies (default: True)
|
| 56 |
+
tool_call_max_retries: Maximum number of retries for tool calls (default: 3)
|
| 57 |
+
tool_call_retry_delay: Initial delay between tool call retries in seconds (default: 1.0)
|
| 58 |
+
"""
|
| 59 |
+
self.config: dict[str, Any] = {}
|
| 60 |
+
self.sandbox = sandbox
|
| 61 |
+
self.sandbox_options = sandbox_options
|
| 62 |
+
self.timeout = timeout
|
| 63 |
+
self.sse_read_timeout = sse_read_timeout
|
| 64 |
+
self.max_retries = max_retries
|
| 65 |
+
self.retry_interval = retry_interval
|
| 66 |
+
self.installer = installer
|
| 67 |
+
self.check_dependencies = check_dependencies
|
| 68 |
+
self.tool_call_max_retries = tool_call_max_retries
|
| 69 |
+
self.tool_call_retry_delay = tool_call_retry_delay
|
| 70 |
+
self.sessions: dict[str, MCPSession] = {}
|
| 71 |
+
self.active_sessions: list[str] = []
|
| 72 |
+
|
| 73 |
+
# Load configuration if provided
|
| 74 |
+
if config is not None:
|
| 75 |
+
if isinstance(config, str):
|
| 76 |
+
self.config = load_json_file(config)
|
| 77 |
+
else:
|
| 78 |
+
self.config = config
|
| 79 |
+
|
| 80 |
+
def _get_mcp_servers(self) -> dict[str, Any]:
|
| 81 |
+
"""Internal helper to get mcpServers configuration.
|
| 82 |
+
|
| 83 |
+
Tries both 'mcpServers' and 'servers' keys for compatibility.
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
Dictionary of MCP server configurations, empty dict if none found.
|
| 87 |
+
"""
|
| 88 |
+
servers = get_config_value(self.config, "mcpServers", None)
|
| 89 |
+
if servers is None:
|
| 90 |
+
servers = get_config_value(self.config, "servers", {})
|
| 91 |
+
return servers or {}
|
| 92 |
+
|
| 93 |
+
@classmethod
|
| 94 |
+
def from_dict(
|
| 95 |
+
cls,
|
| 96 |
+
config: dict[str, Any],
|
| 97 |
+
sandbox: bool = False,
|
| 98 |
+
sandbox_options: SandboxOptions | None = None,
|
| 99 |
+
timeout: float = 30.0,
|
| 100 |
+
sse_read_timeout: float = 300.0,
|
| 101 |
+
max_retries: int = 3,
|
| 102 |
+
retry_interval: float = 2.0,
|
| 103 |
+
) -> "MCPClient":
|
| 104 |
+
"""Create a MCPClient from a dictionary.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
config: The configuration dictionary.
|
| 108 |
+
sandbox: Whether to use sandboxed execution mode for running MCP servers.
|
| 109 |
+
sandbox_options: Optional sandbox configuration options.
|
| 110 |
+
timeout: Timeout for operations in seconds (default: 30.0)
|
| 111 |
+
sse_read_timeout: SSE read timeout in seconds (default: 300.0)
|
| 112 |
+
max_retries: Maximum number of retry attempts (default: 3)
|
| 113 |
+
retry_interval: Wait time between retries in seconds (default: 2.0)
|
| 114 |
+
"""
|
| 115 |
+
return cls(config=config, sandbox=sandbox, sandbox_options=sandbox_options,
|
| 116 |
+
timeout=timeout, sse_read_timeout=sse_read_timeout,
|
| 117 |
+
max_retries=max_retries, retry_interval=retry_interval)
|
| 118 |
+
|
| 119 |
+
@classmethod
|
| 120 |
+
def from_config_file(
|
| 121 |
+
cls, filepath: str, sandbox: bool = False, sandbox_options: SandboxOptions | None = None,
|
| 122 |
+
timeout: float = 30.0, sse_read_timeout: float = 300.0,
|
| 123 |
+
max_retries: int = 3, retry_interval: float = 2.0,
|
| 124 |
+
) -> "MCPClient":
|
| 125 |
+
"""Create a MCPClient from a configuration file.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
filepath: The path to the configuration file.
|
| 129 |
+
sandbox: Whether to use sandboxed execution mode for running MCP servers.
|
| 130 |
+
sandbox_options: Optional sandbox configuration options.
|
| 131 |
+
timeout: Timeout for operations in seconds (default: 30.0)
|
| 132 |
+
sse_read_timeout: SSE read timeout in seconds (default: 300.0)
|
| 133 |
+
max_retries: Maximum number of retry attempts (default: 3)
|
| 134 |
+
retry_interval: Wait time between retries in seconds (default: 2.0)
|
| 135 |
+
"""
|
| 136 |
+
return cls(config=load_json_file(filepath), sandbox=sandbox, sandbox_options=sandbox_options,
|
| 137 |
+
timeout=timeout, sse_read_timeout=sse_read_timeout,
|
| 138 |
+
max_retries=max_retries, retry_interval=retry_interval)
|
| 139 |
+
|
| 140 |
+
def add_server(
|
| 141 |
+
self,
|
| 142 |
+
name: str,
|
| 143 |
+
server_config: dict[str, Any],
|
| 144 |
+
) -> None:
|
| 145 |
+
"""Add a server configuration.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
name: The name to identify this server.
|
| 149 |
+
server_config: The server configuration.
|
| 150 |
+
"""
|
| 151 |
+
mcp_servers = self._get_mcp_servers()
|
| 152 |
+
if "mcpServers" not in self.config:
|
| 153 |
+
self.config["mcpServers"] = {}
|
| 154 |
+
|
| 155 |
+
self.config["mcpServers"][name] = server_config
|
| 156 |
+
logger.debug(f"Added MCP server configuration: {name}")
|
| 157 |
+
|
| 158 |
+
def remove_server(self, name: str) -> None:
|
| 159 |
+
"""Remove a server configuration.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
name: The name of the server to remove.
|
| 163 |
+
"""
|
| 164 |
+
mcp_servers = self._get_mcp_servers()
|
| 165 |
+
if name in mcp_servers:
|
| 166 |
+
# Remove from config
|
| 167 |
+
if "mcpServers" in self.config:
|
| 168 |
+
self.config["mcpServers"].pop(name, None)
|
| 169 |
+
elif "servers" in self.config:
|
| 170 |
+
self.config["servers"].pop(name, None)
|
| 171 |
+
|
| 172 |
+
# If we removed an active session, remove it from active_sessions
|
| 173 |
+
if name in self.active_sessions:
|
| 174 |
+
self.active_sessions.remove(name)
|
| 175 |
+
|
| 176 |
+
logger.debug(f"Removed MCP server configuration: {name}")
|
| 177 |
+
else:
|
| 178 |
+
logger.warning(f"Server '{name}' not found in configuration")
|
| 179 |
+
|
| 180 |
+
def get_server_names(self) -> list[str]:
|
| 181 |
+
"""Get the list of configured server names.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
List of server names.
|
| 185 |
+
"""
|
| 186 |
+
return list(self._get_mcp_servers().keys())
|
| 187 |
+
|
| 188 |
+
def save_config(self, filepath: str) -> None:
|
| 189 |
+
"""Save the current configuration to a file.
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
filepath: The path to save the configuration to.
|
| 193 |
+
"""
|
| 194 |
+
save_json_file(self.config, filepath)
|
| 195 |
+
|
| 196 |
+
async def create_session(self, server_name: str, auto_initialize: bool = True) -> MCPSession:
|
| 197 |
+
"""Create a session for the specified server with retry logic.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
server_name: The name of the server to create a session for.
|
| 201 |
+
auto_initialize: Whether to automatically initialize the session.
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
The created MCPSession.
|
| 205 |
+
|
| 206 |
+
Raises:
|
| 207 |
+
ValueError: If the specified server doesn't exist.
|
| 208 |
+
Exception: If session creation fails after all retries.
|
| 209 |
+
"""
|
| 210 |
+
# Check if session already exists
|
| 211 |
+
if server_name in self.sessions:
|
| 212 |
+
logger.debug(f"Session for server '{server_name}' already exists, returning existing session")
|
| 213 |
+
return self.sessions[server_name]
|
| 214 |
+
|
| 215 |
+
# Get server config
|
| 216 |
+
servers = self._get_mcp_servers()
|
| 217 |
+
|
| 218 |
+
if not servers:
|
| 219 |
+
warnings.warn("No MCP servers defined in config", UserWarning, stacklevel=2)
|
| 220 |
+
return None
|
| 221 |
+
|
| 222 |
+
if server_name not in servers:
|
| 223 |
+
raise ValueError(f"Server '{server_name}' not found in config. Available: {list(servers.keys())}")
|
| 224 |
+
|
| 225 |
+
server_config = servers[server_name]
|
| 226 |
+
|
| 227 |
+
# Retry logic for session creation
|
| 228 |
+
last_exc: Exception | None = None
|
| 229 |
+
|
| 230 |
+
for attempt in range(1, self.max_retries + 1):
|
| 231 |
+
try:
|
| 232 |
+
# Create connector with options (now async)
|
| 233 |
+
connector = await create_connector_from_config(
|
| 234 |
+
server_config,
|
| 235 |
+
server_name=server_name,
|
| 236 |
+
sandbox=self.sandbox,
|
| 237 |
+
sandbox_options=self.sandbox_options,
|
| 238 |
+
timeout=self.timeout,
|
| 239 |
+
sse_read_timeout=self.sse_read_timeout,
|
| 240 |
+
installer=self.installer,
|
| 241 |
+
check_dependencies=self.check_dependencies,
|
| 242 |
+
tool_call_max_retries=self.tool_call_max_retries,
|
| 243 |
+
tool_call_retry_delay=self.tool_call_retry_delay,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# Create the session with proper initialization parameters
|
| 247 |
+
session = MCPSession(
|
| 248 |
+
connector=connector,
|
| 249 |
+
session_id=f"mcp-{server_name}",
|
| 250 |
+
auto_connect=True,
|
| 251 |
+
auto_initialize=False, # We'll handle initialization explicitly below
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
# Initialize if requested
|
| 255 |
+
if auto_initialize:
|
| 256 |
+
await session.initialize()
|
| 257 |
+
logger.debug(f"Initialized session for server '{server_name}'")
|
| 258 |
+
|
| 259 |
+
# Store session
|
| 260 |
+
self.sessions[server_name] = session
|
| 261 |
+
|
| 262 |
+
# Add to active sessions
|
| 263 |
+
if server_name not in self.active_sessions:
|
| 264 |
+
self.active_sessions.append(server_name)
|
| 265 |
+
|
| 266 |
+
logger.info(f"Created session for MCP server '{server_name}' (attempt {attempt}/{self.max_retries})")
|
| 267 |
+
return session
|
| 268 |
+
|
| 269 |
+
except MCPDependencyError as e:
|
| 270 |
+
# Don't retry dependency errors - they won't succeed on retry
|
| 271 |
+
# Error already shown to user by installer, just re-raise
|
| 272 |
+
logger.debug(f"Dependency error for server '{server_name}': {type(e).__name__}")
|
| 273 |
+
raise
|
| 274 |
+
except Exception as e:
|
| 275 |
+
last_exc = e
|
| 276 |
+
if attempt == self.max_retries:
|
| 277 |
+
break
|
| 278 |
+
|
| 279 |
+
# Use info level for first attempt (common after fresh install), warning for subsequent
|
| 280 |
+
log_level = logger.info if attempt == 1 else logger.warning
|
| 281 |
+
log_level(
|
| 282 |
+
f"Failed to create session for server '{server_name}' (attempt {attempt}/{self.max_retries}): {e}, "
|
| 283 |
+
f"retrying in {self.retry_interval} seconds..."
|
| 284 |
+
)
|
| 285 |
+
await asyncio.sleep(self.retry_interval)
|
| 286 |
+
|
| 287 |
+
# All retries failed
|
| 288 |
+
error_msg = f"Failed to create session for server '{server_name}' after {self.max_retries} retries"
|
| 289 |
+
logger.error(error_msg)
|
| 290 |
+
raise last_exc or RuntimeError(error_msg)
|
| 291 |
+
|
| 292 |
+
async def create_all_sessions(
|
| 293 |
+
self,
|
| 294 |
+
auto_initialize: bool = True,
|
| 295 |
+
) -> dict[str, MCPSession]:
|
| 296 |
+
"""Create sessions for all configured servers.
|
| 297 |
+
|
| 298 |
+
Args:
|
| 299 |
+
auto_initialize: Whether to automatically initialize the sessions.
|
| 300 |
+
|
| 301 |
+
Returns:
|
| 302 |
+
Dictionary mapping server names to their MCPSession instances.
|
| 303 |
+
|
| 304 |
+
Warns:
|
| 305 |
+
UserWarning: If no servers are configured.
|
| 306 |
+
"""
|
| 307 |
+
servers = self._get_mcp_servers()
|
| 308 |
+
|
| 309 |
+
if not servers:
|
| 310 |
+
warnings.warn("No MCP servers defined in config", UserWarning, stacklevel=2)
|
| 311 |
+
return {}
|
| 312 |
+
|
| 313 |
+
# Create sessions for all servers (create_session already handles initialization)
|
| 314 |
+
logger.debug(f"Creating sessions for {len(servers)} servers")
|
| 315 |
+
for name in servers:
|
| 316 |
+
try:
|
| 317 |
+
await self.create_session(name, auto_initialize)
|
| 318 |
+
except Exception as e:
|
| 319 |
+
logger.error(f"Failed to create session for server '{name}': {e}")
|
| 320 |
+
|
| 321 |
+
logger.info(f"Created {len(self.sessions)} MCP sessions")
|
| 322 |
+
return self.sessions
|
| 323 |
+
|
| 324 |
+
def get_session(self, server_name: str) -> MCPSession:
    """Look up a previously created session by server name.

    Args:
        server_name: The name of the server to get the session for.

    Returns:
        The MCPSession for the specified server.

    Raises:
        ValueError: If no session has been created for ``server_name``.
    """
    try:
        return self.sessions[server_name]
    except KeyError:
        # Surface a ValueError (matching the documented contract) without
        # exposing the internal KeyError as chained context.
        raise ValueError(f"No session exists for server '{server_name}'") from None
|
| 341 |
+
|
| 342 |
+
def get_all_active_sessions(self) -> dict[str, MCPSession]:
    """Return every tracked session whose name is marked active.

    Returns:
        Dictionary mapping server names to their MCPSession instances,
        restricted to names present in both ``active_sessions`` and
        ``sessions``.
    """
    active = {}
    for name in self.active_sessions:
        if name in self.sessions:
            active[name] = self.sessions[name]
    return active
|
| 349 |
+
|
| 350 |
+
async def close_session(self, server_name: str) -> None:
    """Disconnect and forget the session for ``server_name``.

    The session is removed from tracking even when the disconnect fails,
    so a broken session is never handed out again. A missing session is
    logged and ignored rather than raised.

    Args:
        server_name: The name of the server whose session should be closed.
    """
    session = self.sessions.get(server_name)
    if session is None:
        logger.warning(f"No session exists for server '{server_name}', nothing to close")
        return

    disconnect_failed = False
    try:
        logger.debug(f"Closing session for server '{server_name}'")
        await session.disconnect()
        logger.info(f"Successfully closed session for server '{server_name}'")
    except Exception as exc:
        disconnect_failed = True
        logger.error(f"Error closing session for server '{server_name}': {exc}")
    finally:
        # Drop the session from tracking regardless of the disconnect outcome.
        self.sessions.pop(server_name, None)
        if server_name in self.active_sessions:
            self.active_sessions.remove(server_name)

    if disconnect_failed:
        logger.warning(f"Session for '{server_name}' removed from tracking despite disconnect error")
|
| 386 |
+
|
| 387 |
+
async def close_all_sessions(self) -> None:
    """Close all active sessions.

    Every session is attempted even when some fail; failures are
    collected and summarized instead of aborting the loop.
    """
    # Snapshot the names first: close_session mutates self.sessions.
    failures = []
    for server_name in list(self.sessions.keys()):
        try:
            logger.debug(f"Closing session for server '{server_name}'")
            await self.close_session(server_name)
        except Exception as exc:
            message = f"Failed to close session for server '{server_name}': {exc}"
            logger.error(message)
            failures.append(message)

    if failures:
        logger.error(f"Encountered {len(failures)} errors while closing sessions")
    else:
        logger.debug("All sessions closed successfully")
|
openspace/grounding/backends/mcp/config.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration loader for MCP session.
|
| 3 |
+
|
| 4 |
+
This module provides functionality to load MCP configuration from JSON files.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Any, Optional
|
| 8 |
+
|
| 9 |
+
from openspace.grounding.core.types import SandboxOptions
|
| 10 |
+
from openspace.config.utils import get_config_value
|
| 11 |
+
from .transport.connectors import (
|
| 12 |
+
MCPBaseConnector,
|
| 13 |
+
HttpConnector,
|
| 14 |
+
SandboxConnector,
|
| 15 |
+
StdioConnector,
|
| 16 |
+
WebSocketConnector,
|
| 17 |
+
)
|
| 18 |
+
from .transport.connectors.utils import is_stdio_server
|
| 19 |
+
from .installer import MCPInstallerManager
|
| 20 |
+
|
| 21 |
+
# Import E2BSandbox
|
| 22 |
+
try:
|
| 23 |
+
from openspace.grounding.core.security import E2BSandbox
|
| 24 |
+
E2B_AVAILABLE = True
|
| 25 |
+
except ImportError:
|
| 26 |
+
E2BSandbox = None
|
| 27 |
+
E2B_AVAILABLE = False
|
| 28 |
+
|
| 29 |
+
async def create_connector_from_config(
    server_config: dict[str, Any],
    server_name: str = "unknown",
    sandbox: bool = False,
    sandbox_options: SandboxOptions | None = None,
    timeout: float = 30.0,
    sse_read_timeout: float = 300.0,
    installer: Optional[MCPInstallerManager] = None,
    check_dependencies: bool = True,
    tool_call_max_retries: int = 3,
    tool_call_retry_delay: float = 1.0,
) -> MCPBaseConnector:
    """Create a connector based on server configuration.

    Dispatches on the config shape: a ``command`` entry yields a stdio (or
    E2B-sandboxed) connector, a ``url`` entry an HTTP connector, and a
    ``ws_url`` entry a WebSocket connector.

    Args:
        server_config: The server configuration section
        server_name: Name of the MCP server (for display purposes)
        sandbox: Whether to use sandboxed execution mode for running MCP servers.
        sandbox_options: Optional sandbox configuration options.
        timeout: Timeout for operations in seconds (default: 30.0)
        sse_read_timeout: SSE read timeout in seconds (default: 300.0)
        installer: Optional installer manager for dependency installation
        check_dependencies: Whether to check and install dependencies (default: True)
        tool_call_max_retries: Maximum number of retries for tool calls (default: 3)
        tool_call_retry_delay: Initial delay between retries in seconds (default: 1.0)

    Returns:
        A configured connector instance

    Raises:
        ImportError: If sandbox mode is requested but E2B support is missing.
        RuntimeError: If dependencies are not installed and user declines installation
        ValueError: If the config matches no known connector type.
    """
    # Fetch command/args once; reused for the dependency check below.
    original_command = get_config_value(server_config, "command")
    original_args = get_config_value(server_config, "args", [])

    # Stdio-style server (command-based), plain or sandboxed.
    if is_stdio_server(server_config):
        # Dependency checking only applies to locally spawned (stdio) servers.
        if check_dependencies:
            if installer is None:
                from .installer import get_global_installer
                installer = get_global_installer()
            await installer.ensure_dependencies(server_name, original_command, original_args)

        # NOTE: connector construction deliberately uses the same config
        # lookups as before (no default for "args") to preserve behavior
        # when keys are absent.
        connector_args = get_config_value(server_config, "args")
        connector_env = get_config_value(server_config, "env", None)

        if not sandbox:
            return StdioConnector(
                command=original_command,
                args=connector_args,
                env=connector_env,
            )

        # Sandboxed execution requires the optional E2B dependency.
        if not E2B_AVAILABLE:
            raise ImportError(
                "E2B sandbox support not available. Please install e2b-code-interpreter: "
                "'pip install e2b-code-interpreter'"
            )

        _sandbox_options = sandbox_options or {}
        e2b_sandbox = E2BSandbox(_sandbox_options)

        # Sandbox options may override the generic timeouts.
        return SandboxConnector(
            sandbox=e2b_sandbox,
            command=original_command,
            args=connector_args,
            env=connector_env,
            supergateway_command=_sandbox_options.get("supergateway_command", "npx -y supergateway"),
            port=_sandbox_options.get("port", 3000),
            timeout=_sandbox_options.get("timeout", timeout),
            sse_read_timeout=_sandbox_options.get("sse_read_timeout", sse_read_timeout),
        )

    # HTTP connector
    if "url" in server_config:
        return HttpConnector(
            base_url=get_config_value(server_config, "url"),
            headers=get_config_value(server_config, "headers", None),
            auth_token=get_config_value(server_config, "auth_token", None),
            timeout=timeout,
            sse_read_timeout=sse_read_timeout,
            tool_call_max_retries=tool_call_max_retries,
            tool_call_retry_delay=tool_call_retry_delay,
        )

    # WebSocket connector
    if "ws_url" in server_config:
        return WebSocketConnector(
            url=get_config_value(server_config, "ws_url"),
            headers=get_config_value(server_config, "headers", None),
            auth_token=get_config_value(server_config, "auth_token", None),
        )

    raise ValueError("Cannot determine connector type from config")
|
openspace/grounding/backends/mcp/installer.py
ADDED
|
@@ -0,0 +1,697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import sys
|
| 3 |
+
import shutil
|
| 4 |
+
from typing import Callable, Awaitable, Optional, Dict, List
|
| 5 |
+
from openspace.utils.logging import Logger
|
| 6 |
+
|
| 7 |
+
logger = Logger.get_logger(__name__)
|
| 8 |
+
|
| 9 |
+
# Signature for async user-confirmation callbacks: given a prompt message,
# resolves to True when the user approves the installation.
PromptFunc = Callable[[str], Awaitable[bool]]

# Global lock to prevent concurrent user prompts
# NOTE(review): created at import time; on Python < 3.10 asyncio.Lock() may
# bind to the loop current at import — confirm the target runtime version.
_prompt_lock = asyncio.Lock()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MCPDependencyError(RuntimeError):
    """Base exception for MCP dependency errors."""


class MCPCommandNotFoundError(MCPDependencyError):
    """Raised when a required command is not available."""


class MCPInstallationCancelledError(MCPDependencyError):
    """Raised when user cancels installation."""


class MCPInstallationFailedError(MCPDependencyError):
    """Raised when installation fails."""
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Colors:
    """ANSI escape codes used to style installer CLI output."""
    RESET = "\033[0m"   # restore default terminal style
    BOLD = "\033[1m"
    RED = "\033[91m"    # bright red
    YELLOW = "\033[93m"
    GREEN = "\033[92m"
    CYAN = "\033[96m"
    GRAY = "\033[90m"
    WHITE = "\033[97m"
    BLUE = "\033[94m"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class MCPInstallerManager:
|
| 48 |
+
"""
|
| 49 |
+
MCP dependencies package installer manager.
|
| 50 |
+
|
| 51 |
+
Responsible for detecting if the MCP server dependencies are installed, and if not, asking the user whether to install them.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(self, prompt: PromptFunc | None = None, auto_install: bool = False, verbose: bool = False):
    """Initialize the installer manager.

    Args:
        prompt: Custom async confirmation callback; when None the built-in
            CLI prompt is used.
        auto_install: If True, install dependencies without asking the user.
        verbose: If True, stream full installation logs; otherwise only a
            progress indicator is shown.
    """
    # Fall back to the interactive CLI prompt when no callback is supplied.
    self._prompt: PromptFunc | None = prompt or self._default_cli_prompt
    self._auto_install = auto_install
    self._verbose = verbose
    # Memoized results of prior package checks, keyed "command:arg:arg...".
    self._installed_cache: Dict[str, bool] = {}
    # Installations that already failed, kept so we do not retry them.
    self._failed_installations: Dict[str, str] = {}
|
| 67 |
+
|
| 68 |
+
async def _default_cli_prompt(self, message: str) -> bool:
    """Default CLI prompt function (called within lock by ensure_dependencies).

    Renders a framed confirmation dialog on stdout and reads one line from
    stdin, then echoes the decision back to the user.

    Args:
        message: Human-readable description of what would be installed.

    Returns:
        bool: True when the user answered "y" or "yes" (case-insensitive).
    """
    from openspace.utils.display import print_separator, colorize

    print()
    print_separator(70, 'c', 2)
    print(f" {colorize('MCP dependencies installation prompt', color=Colors.BLUE, bold=True)}")
    print_separator(70, 'c', 2)
    print(f" {message}")
    print_separator(70, 'gr', 2)
    print(f" {colorize('[y/yes]', color=Colors.GREEN)} Install | {colorize('[n/no]', color=Colors.RED)} Cancel")
    print_separator(70, 'gr', 2)
    print(f" {colorize('Your choice:', bold=True)} ", end="", flush=True)

    # Blocking stdin read runs in the default executor so the event loop
    # stays responsive while waiting for the user.
    answer = await asyncio.get_running_loop().run_in_executor(None, sys.stdin.readline)
    response = answer.strip().lower() in {"y", "yes"}

    if response:
        print(f"{Colors.GREEN}✓ Installation confirmed{Colors.RESET}\n")
    else:
        print(f"{Colors.RED}✗ Installation cancelled{Colors.RESET}\n")

    return response
|
| 91 |
+
|
| 92 |
+
async def _ask_user(self, message: str) -> bool:
|
| 93 |
+
"""Ask the user whether to install."""
|
| 94 |
+
if self._auto_install:
|
| 95 |
+
logger.info("Automatic installation mode enabled, will automatically install dependencies")
|
| 96 |
+
return True
|
| 97 |
+
|
| 98 |
+
if self._prompt:
|
| 99 |
+
try:
|
| 100 |
+
return await self._prompt(message)
|
| 101 |
+
except Exception as e:
|
| 102 |
+
logger.error(f"Error asking user: {e}")
|
| 103 |
+
return False
|
| 104 |
+
return False
|
| 105 |
+
|
| 106 |
+
def _check_command_available(self, command: str) -> bool:
|
| 107 |
+
"""Check if the command is available.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
command: The command to check (e.g. "npx", "uvx")
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
bool: Whether the command is available
|
| 114 |
+
"""
|
| 115 |
+
return shutil.which(command) is not None
|
| 116 |
+
|
| 117 |
+
async def _check_package_installed(self, command: str, args: List[str]) -> bool:
|
| 118 |
+
"""Check if the package is installed.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
command: The command to check (e.g. "npx", "uvx")
|
| 122 |
+
args: The arguments list
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
bool: Whether the package is installed
|
| 126 |
+
"""
|
| 127 |
+
# Build cache key
|
| 128 |
+
cache_key = f"{command}:{':'.join(args)}"
|
| 129 |
+
|
| 130 |
+
# Check cache
|
| 131 |
+
if cache_key in self._installed_cache:
|
| 132 |
+
return self._installed_cache[cache_key]
|
| 133 |
+
|
| 134 |
+
# For different types of commands, use different check methods
|
| 135 |
+
try:
|
| 136 |
+
if command == "npx":
|
| 137 |
+
# For npx, check if the npm package exists
|
| 138 |
+
package_name = self._extract_npm_package(args)
|
| 139 |
+
if package_name:
|
| 140 |
+
result = await self._check_npm_package(package_name)
|
| 141 |
+
self._installed_cache[cache_key] = result
|
| 142 |
+
return result
|
| 143 |
+
elif command == "uvx":
|
| 144 |
+
# For uvx, check if the Python package exists
|
| 145 |
+
package_name = self._extract_python_package(args)
|
| 146 |
+
if package_name:
|
| 147 |
+
result = await self._check_python_package(package_name)
|
| 148 |
+
self._installed_cache[cache_key] = result
|
| 149 |
+
return result
|
| 150 |
+
elif command == "uv":
|
| 151 |
+
# For "uv run --with package ...", check if the Python package exists
|
| 152 |
+
package_name = self._extract_uv_package(args)
|
| 153 |
+
if package_name:
|
| 154 |
+
result = await self._check_uv_pip_package(package_name)
|
| 155 |
+
self._installed_cache[cache_key] = result
|
| 156 |
+
return result
|
| 157 |
+
except Exception as e:
|
| 158 |
+
logger.debug(f"Error checking package installation status: {e}")
|
| 159 |
+
|
| 160 |
+
# Default to assuming not installed
|
| 161 |
+
return False
|
| 162 |
+
|
| 163 |
+
def _extract_npm_package(self, args: List[str]) -> Optional[str]:
|
| 164 |
+
"""Extract package name from npx arguments.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
args: npx arguments list, e.g. ["-y", "mcp-excalidraw-server"] or ["bazi-mcp"]
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
Package name (without version tag) or None
|
| 171 |
+
"""
|
| 172 |
+
for i, arg in enumerate(args):
|
| 173 |
+
# Skip option parameters
|
| 174 |
+
if arg.startswith("-"):
|
| 175 |
+
continue
|
| 176 |
+
|
| 177 |
+
# Found package name, now strip version tag
|
| 178 |
+
package_name = arg
|
| 179 |
+
|
| 180 |
+
# Handle scoped packages: @scope/package@version -> @scope/package
|
| 181 |
+
if package_name.startswith("@"):
|
| 182 |
+
# Scoped package like @rtuin/mcp-mermaid-validator@latest
|
| 183 |
+
parts = package_name.split("/", 1)
|
| 184 |
+
if len(parts) == 2:
|
| 185 |
+
scope = parts[0]
|
| 186 |
+
name_with_version = parts[1]
|
| 187 |
+
# Remove version tag from name part (e.g., "pkg@latest" -> "pkg")
|
| 188 |
+
name = name_with_version.split("@")[0] if "@" in name_with_version else name_with_version
|
| 189 |
+
return f"{scope}/{name}"
|
| 190 |
+
return package_name
|
| 191 |
+
else:
|
| 192 |
+
# Regular package like mcp-deepwiki@latest -> mcp-deepwiki
|
| 193 |
+
return package_name.split("@")[0] if "@" in package_name else package_name
|
| 194 |
+
|
| 195 |
+
return None
|
| 196 |
+
|
| 197 |
+
def _extract_python_package(self, args: List[str]) -> Optional[str]:
|
| 198 |
+
"""Extract package name from uvx arguments.
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
args: uvx arguments list, e.g. ["--from", "office-powerpoint-mcp-server", "ppt_mcp_server"]
|
| 202 |
+
or ["--with", "mcp==1.9.0", "sitemap-mcp-server"]
|
| 203 |
+
or ["arxiv-mcp-server", "--storage-path", "./path"]
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
Package name or None
|
| 207 |
+
"""
|
| 208 |
+
# Find --from parameter (this is the package to install)
|
| 209 |
+
for i, arg in enumerate(args):
|
| 210 |
+
if arg == "--from" and i + 1 < len(args):
|
| 211 |
+
return args[i + 1]
|
| 212 |
+
|
| 213 |
+
# Skip option flags and their values, find the main package (FIRST positional arg)
|
| 214 |
+
# Options that take a value: --with, --python, --from, --storage-path, etc.
|
| 215 |
+
options_with_value = {"--with", "--from", "--python", "-p", "--storage-path"}
|
| 216 |
+
skip_next = False
|
| 217 |
+
|
| 218 |
+
for arg in args:
|
| 219 |
+
if skip_next:
|
| 220 |
+
skip_next = False
|
| 221 |
+
continue
|
| 222 |
+
if arg in options_with_value:
|
| 223 |
+
skip_next = True
|
| 224 |
+
continue
|
| 225 |
+
if arg.startswith("-"):
|
| 226 |
+
# Other flags without values (or unknown options with values)
|
| 227 |
+
# Also skip the next arg if it looks like an option value (doesn't start with -)
|
| 228 |
+
continue
|
| 229 |
+
# First non-option argument is the package name
|
| 230 |
+
return arg
|
| 231 |
+
|
| 232 |
+
return None
|
| 233 |
+
|
| 234 |
+
def _extract_uv_package(self, args: List[str]) -> Optional[str]:
|
| 235 |
+
"""Extract package name from uv run arguments.
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
args: uv arguments list, e.g. ["run", "--with", "biomcp-python", "biomcp", "run"]
|
| 239 |
+
|
| 240 |
+
Returns:
|
| 241 |
+
Package name or None
|
| 242 |
+
"""
|
| 243 |
+
# Find --with parameter (this specifies the package to install)
|
| 244 |
+
for i, arg in enumerate(args):
|
| 245 |
+
if arg == "--with" and i + 1 < len(args):
|
| 246 |
+
package_name = args[i + 1]
|
| 247 |
+
# Remove version specifier if present (e.g., "mcp==1.9.0" -> "mcp")
|
| 248 |
+
if "==" in package_name:
|
| 249 |
+
return package_name.split("==")[0]
|
| 250 |
+
if ">=" in package_name:
|
| 251 |
+
return package_name.split(">=")[0]
|
| 252 |
+
return package_name
|
| 253 |
+
|
| 254 |
+
return None
|
| 255 |
+
|
| 256 |
+
async def _check_npm_package(self, package_name: str) -> bool:
    """Return True when *package_name* is installed globally via npm.

    Runs ``npm list -g <package>``; npm exits with 0 only when the
    package is present.

    Args:
        package_name: npm package name
    """
    try:
        proc = await asyncio.create_subprocess_exec(
            "npm", "list", "-g", package_name,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        # Drain the pipes; only the exit status matters here.
        await proc.communicate()
        return proc.returncode == 0
    except Exception as exc:
        logger.debug(f"Error checking npm package {package_name}: {exc}")
        return False
|
| 278 |
+
|
| 279 |
+
async def _check_python_package(self, package_name: str) -> bool:
    """Check if the Python package is installed as a uvx tool.

    uvx tools are installed in ~/.local/share/uv/tools/ directory,
    not in the current pip environment, so that directory is checked
    before falling back to probing the tool itself.

    Args:
        package_name: Python package/tool name (may carry a version spec)

    Returns:
        bool: Whether the uvx tool is installed
    """
    from pathlib import Path

    # Strip version specifier if present (e.g., "mcp==1.9.0" -> "mcp")
    clean_name = package_name.split("==")[0].split(">=")[0].split("<=")[0].split(">")[0].split("<")[0]

    # Check if uvx tool exists in the standard uv tools directory
    uv_tools_dir = Path.home() / ".local" / "share" / "uv" / "tools"
    tool_dir = uv_tools_dir / clean_name

    if tool_dir.exists():
        logger.debug(f"uvx tool '{clean_name}' found at {tool_dir}")
        return True

    # Fallback: probe via "uvx <tool> --help".
    # NOTE(review): uvx can download a tool on demand, so a zero exit here
    # does not strictly prove a *prior* install — confirm this is intended.
    try:
        process = await asyncio.create_subprocess_exec(
            "uvx", clean_name, "--help",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )
        # Bound the probe; some tools block waiting for input.
        try:
            await asyncio.wait_for(process.communicate(), timeout=5.0)
        except asyncio.TimeoutError:
            process.kill()
            await process.wait()

        # If it didn't error immediately, the tool likely exists
        return process.returncode == 0
    except Exception as e:
        logger.debug(f"Error checking uvx tool {clean_name}: {e}")

    return False
|
| 325 |
+
|
| 326 |
+
async def _check_uv_pip_package(self, package_name: str) -> bool:
    """Return True when a Python package is visible to ``uv pip`` (or pip).

    ``uv pip show`` is preferred; when it errors or reports the package
    missing, plain ``pip show`` is consulted as a fallback.

    Args:
        package_name: Python package name (optionally with a version spec)

    Returns:
        bool: Whether the package is installed
    """
    # Drop any version specifier before querying (same order as siblings).
    clean_name = package_name
    for sep in ("==", ">=", "<=", ">", "<"):
        clean_name = clean_name.split(sep)[0]

    for cmd in (("uv", "pip", "show", clean_name), ("pip", "show", clean_name)):
        is_uv = cmd[0] == "uv"
        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            await proc.communicate()
            if proc.returncode == 0:
                if is_uv:
                    logger.debug(f"uv pip package '{clean_name}' found")
                return True
        except Exception as exc:
            if is_uv:
                logger.debug(f"Error checking uv pip package {clean_name}: {exc}")
            else:
                logger.debug(f"Error checking pip package {clean_name}: {exc}")

    return False
|
| 367 |
+
|
| 368 |
+
async def _install_package(self, command: str, args: List[str], use_sudo: bool = False) -> bool:
    """Execute the install command.

    Derives the actual install command from (command, args), runs it as a
    subprocess, and reports progress to the user.  On a permission failure
    (EACCES / "permission denied" in the output) the user is offered a
    single sudo retry, implemented by recursing with use_sudo=True.

    Args:
        command: The command to execute (e.g. "npx", "uvx")
        args: The arguments list
        use_sudo: Whether to use sudo for installation

    Returns:
        bool: Whether the installation is successful
    """
    install_command = self._get_install_command(command, args)

    if not install_command:
        logger.error("Cannot determine install command")
        return False

    # Add sudo if requested
    if use_sudo:
        install_command = ["sudo"] + install_command

    logger.info(f"Executing install command: {' '.join(install_command)}")

    try:
        # For sudo commands, always show verbose output so password prompt is visible
        if self._verbose or use_sudo:
            # Verbose mode: show all installation logs
            from openspace.utils.display import print_separator, colorize

            print_separator(70, 'c', 2)
            if use_sudo:
                print(f" {colorize('Installing with administrator privileges...', color=Colors.BLUE)}")
                print(f" {colorize('>> You will be prompted for your password below <<', color=Colors.YELLOW)}")
            else:
                print(f" {colorize('Installing dependencies...', color=Colors.BLUE)}")
            print(f" {colorize('Command: ' + ' '.join(install_command), color=Colors.GRAY)}")
            print_separator(70, 'c', 2)
            print()

            # For sudo, don't redirect stdin so password prompt works
            if use_sudo:
                process = await asyncio.create_subprocess_exec(
                    *install_command,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.STDOUT,
                    stdin=None  # Let sudo use terminal for password
                )
            else:
                process = await asyncio.create_subprocess_exec(
                    *install_command,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.STDOUT
                )

            # Real-time output of installation logs
            # (stderr is merged into stdout above, so one stream covers both)
            output_lines = []
            while True:
                line = await process.stdout.readline()
                if not line:
                    break
                line_str = line.decode().rstrip()
                output_lines.append(line_str)
                print(f"{Colors.GRAY}{line_str}{Colors.RESET}")

            await process.wait()
            full_output = '\n'.join(output_lines)
        else:
            # Quiet mode: only show progress indicator
            print(f"\n{Colors.BLUE}Installing dependencies...{Colors.RESET} ", end="", flush=True)

            process = await asyncio.create_subprocess_exec(
                *install_command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE
            )

            # Show spinner animation while installing
            spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
            spinner_idx = 0

            # Poll with a short timeout so the spinner advances ~10x/second
            # until the subprocess exits.
            while True:
                try:
                    await asyncio.wait_for(process.wait(), timeout=0.1)
                    break
                except asyncio.TimeoutError:
                    print(f"\r{Colors.BLUE}Installing dependencies...{Colors.RESET} {Colors.CYAN}{spinner[spinner_idx]}{Colors.RESET}", end="", flush=True)
                    spinner_idx = (spinner_idx + 1) % len(spinner)

            # Clear the spinner line
            print(f"\r{' ' * 100}\r", end="", flush=True)

            # Collect output
            stdout, stderr = await process.communicate()
            full_output = (stdout or stderr).decode() if (stdout or stderr) else ""

        if process.returncode == 0:
            print(f"{Colors.GREEN}✓ Dependencies installed successfully{Colors.RESET}")
            if not use_sudo:
                print(f"{Colors.GRAY}(Note: First connection may take a moment to initialize){Colors.RESET}")
            # Update cache
            # NOTE(review): cache key format here ("command:args") differs from
            # the "server:command:args" key used by _ensure_dependencies_impl's
            # failure cache — these are separate caches; confirm intended.
            cache_key = f"{command}:{':'.join(args)}"
            self._installed_cache[cache_key] = True
            return True
        else:
            # Check if it's a permission error
            is_permission_error = "EACCES" in full_output or "permission denied" in full_output.lower()

            if is_permission_error and not use_sudo:
                print(f"\n{Colors.YELLOW}Permission denied{Colors.RESET}")
                print(f"{Colors.GRAY}The installation requires administrator privileges.{Colors.RESET}\n")

                # Ask user if they want to use sudo
                message = (
                    f"\n{Colors.WHITE}Administrator privileges required{Colors.RESET}\n\n"
                    f"Command: {Colors.GRAY}{' '.join(install_command)}{Colors.RESET}\n\n"
                    f"{Colors.YELLOW}Do you want to retry with sudo (requires password)?{Colors.RESET}"
                )

                if await self._ask_user(message):
                    # No extra print needed, the verbose mode will show clear instructions
                    return await self._install_package(command, args, use_sudo=True)
                else:
                    print(f"\n{Colors.RED}✗ Installation cancelled{Colors.RESET}")
                    return False
            else:
                print(f"{Colors.RED}✗ Dependencies installation failed (return code: {process.returncode}){Colors.RESET}")
                # Show error output if not already shown
                if not self._verbose and full_output:
                    # Limit error output to last 20 lines
                    error_lines = full_output.split('\n')
                    if len(error_lines) > 20:
                        error_lines = ['...(truncated)...'] + error_lines[-20:]
                    print(f"{Colors.GRAY}Error output:\n{chr(10).join(error_lines)}{Colors.RESET}")

                # Add general guidance for manual installation
                print(f"\n{Colors.YELLOW}Tip:{Colors.RESET} {Colors.GRAY}If automatic installation fails, please refer to the")
                print(f"official documentation of the MCP server for manual installation instructions.{Colors.RESET}\n")

                return False

    except Exception as e:
        logger.error(f"Error installing dependencies: {e}")
        print(f"{Colors.RED}✗ Error occurred during installation: {e}{Colors.RESET}")
        return False
+
def _get_install_command(self, command: str, args: List[str]) -> Optional[List[str]]:
|
| 514 |
+
"""Generate install command based on command type.
|
| 515 |
+
|
| 516 |
+
Args:
|
| 517 |
+
command: The command to execute (e.g. "npx", "uvx", "uv")
|
| 518 |
+
args: The original arguments list
|
| 519 |
+
|
| 520 |
+
Returns:
|
| 521 |
+
Install command list or None
|
| 522 |
+
"""
|
| 523 |
+
if command == "npx":
|
| 524 |
+
package_name = self._extract_npm_package(args)
|
| 525 |
+
if package_name:
|
| 526 |
+
return ["npm", "install", "-g", package_name]
|
| 527 |
+
elif command == "uvx":
|
| 528 |
+
package_name = self._extract_python_package(args)
|
| 529 |
+
if package_name:
|
| 530 |
+
return ["pip", "install", package_name]
|
| 531 |
+
elif command == "uv":
|
| 532 |
+
# Handle "uv run --with package_name ..." format
|
| 533 |
+
package_name = self._extract_uv_package(args)
|
| 534 |
+
if package_name:
|
| 535 |
+
return ["uv", "pip", "install", package_name]
|
| 536 |
+
|
| 537 |
+
return None
|
| 538 |
+
|
| 539 |
+
async def ensure_dependencies(self, server_name: str, command: str, args: List[str]) -> bool:
    """Make sure an MCP server's dependencies are present, installing on demand.

    Checks whether the package backing the server is installed and, if it
    is not, prompts the user for permission to install it.

    Args:
        server_name: MCP server name (for display purposes)
        command: The command to execute (e.g. "npx", "uvx")
        args: The arguments list

    Returns:
        bool: True when the dependencies are already present or were
        installed successfully.

    Raises:
        RuntimeError: When the command is not available or the user refuses to install
    """
    # Serialize the whole check→prompt→install flow under one lock so
    # concurrent callers cannot interleave prompts or installations.
    async with _prompt_lock:
        return await self._ensure_dependencies_impl(server_name, command, args)
|
| 564 |
+
async def _ensure_dependencies_impl(
    self,
    server_name: str,
    command: str,
    args: List[str]
) -> bool:
    """Internal implementation of ensure_dependencies (called within lock).

    Flow: skip direct-script commands → skip GitHub npx packages →
    short-circuit on previously failed installs → verify the launcher
    command exists → check the package → otherwise prompt and install.
    Raises an MCP*Error subclass on every failure path; failures are
    memoized in self._failed_installations so they are not retried.
    """
    # Skip dependency checking for direct script execution commands
    # These commands run scripts directly and don't need package installation
    SKIP_COMMANDS = {"node", "python", "python3", "bash", "sh", "deno", "bun"}

    if command.lower() in SKIP_COMMANDS:
        logger.debug(f"Skipping dependency check for direct script execution command: {command}")
        return True

    # Skip dependency checking for GitHub-based npx packages
    # These packages are handled directly by npx which downloads, builds, and runs them
    # npm install -g doesn't work properly for GitHub packages that require building
    if command == "npx":
        package_name = self._extract_npm_package(args)
        if package_name and package_name.startswith("github:"):
            logger.debug(f"Skipping dependency check for GitHub-based npx package: {package_name}")
            return True

    # Check if this server has already failed installation
    cache_key = f"{server_name}:{command}:{':'.join(args)}"
    if cache_key in self._failed_installations:
        error_msg = self._failed_installations[cache_key]
        logger.debug(f"Skipping installation for '{server_name}' - previously failed")
        raise MCPDependencyError(error_msg)

    # Special handling for uvx - check if uv is installed
    if command == "uvx":
        if not self._check_command_available("uv"):
            # Only show once to user, no verbose logging
            print(f"\n{Colors.RED}✗ Server '{server_name}' requires 'uv' to be installed{Colors.RESET}")
            print(f"{Colors.YELLOW}Please install uv first:")
            print(f"  • macOS/Linux: curl -LsSf https://astral.sh/uv/install.sh | sh")
            print(f"  • Or with pip: pip install uv")
            print(f"  • Or with brew: brew install uv{Colors.RESET}\n")

            error_msg = f"uvx requires 'uv' to be installed (server: {server_name})"
            self._failed_installations[cache_key] = error_msg
            raise MCPCommandNotFoundError(error_msg)

    # Check if the command is available
    if not self._check_command_available(command):
        error_msg = (
            f"Command '{command}' is not available.\n"
            f"Please install the necessary tools first."
        )
        logger.error(error_msg)
        self._failed_installations[cache_key] = error_msg
        raise MCPCommandNotFoundError(error_msg)

    # Check if the package is installed
    if await self._check_package_installed(command, args):
        logger.debug(f"The dependencies of the MCP server '{server_name}' are installed")
        return True

    # Extract package name for display
    if command == "npx":
        package_name = self._extract_npm_package(args)
        package_type = "npm"
    elif command == "uvx":
        package_name = self._extract_python_package(args)
        package_type = "Python"
    elif command == "uv":
        package_name = self._extract_uv_package(args)
        package_type = "Python"
    else:
        # Unknown launcher: fall back to showing the raw invocation.
        package_name = f"{command} {' '.join(args)}"
        package_type = "package"

    # Build the message for displaying the install command
    install_cmd = self._get_install_command(command, args)

    # If we can't determine an install command, show helpful message
    if not install_cmd:
        print(f"\n{Colors.YELLOW}Cannot automatically install dependencies for '{server_name}'{Colors.RESET}")
        print(f"{Colors.GRAY}Command: {command} {' '.join(args)}{Colors.RESET}")
        print(f"\n{Colors.WHITE}This MCP server may require manual installation or configuration.{Colors.RESET}")
        print(f"{Colors.GRAY}Please refer to the MCP server's official documentation for installation instructions.{Colors.RESET}\n")

        error_msg = f"Manual installation required for '{server_name}' (command: {command})"
        self._failed_installations[cache_key] = error_msg
        raise MCPDependencyError(error_msg)

    install_cmd_str = ' '.join(install_cmd)

    # Build the message
    message = (
        f"\n{Colors.WHITE}The MCP server needs to install dependencies{Colors.RESET}\n\n"
        f"Server name: {Colors.CYAN}{server_name}{Colors.RESET}\n"
        f"Package type: {Colors.YELLOW}{package_type}{Colors.RESET}\n"
        f"Package name: {Colors.YELLOW}{package_name or 'Unknown'}{Colors.RESET}\n"
        f"Install command: {Colors.GRAY}{install_cmd_str}{Colors.RESET}\n\n"
        f"{Colors.YELLOW}Whether to install this dependency package?{Colors.RESET}"
    )

    # Ask the user
    if not await self._ask_user(message):
        error_msg = f"User cancelled the dependency installation for '{server_name}'"
        logger.warning(error_msg)
        self._failed_installations[cache_key] = error_msg
        raise MCPInstallationCancelledError(error_msg)

    # Execute installation
    success = await self._install_package(command, args)

    if not success:
        error_msg = f"Dependency installation failed for '{server_name}'"
        logger.error(error_msg)
        self._failed_installations[cache_key] = error_msg
        raise MCPInstallationFailedError(error_msg)

    return True
|
| 683 |
+
# Global singleton instance
_global_installer: Optional[MCPInstallerManager] = None


def get_global_installer() -> MCPInstallerManager:
    """Return the process-wide installer manager, creating it lazily.

    The first call constructs a default MCPInstallerManager; later calls
    return that same instance (unless replaced via set_global_installer).
    """
    global _global_installer
    if _global_installer is None:
        _global_installer = MCPInstallerManager()
    return _global_installer
|
| 694 |
+
def set_global_installer(installer: MCPInstallerManager) -> None:
    """Replace the process-wide installer manager with *installer*.

    Subsequent get_global_installer() calls return this instance.
    """
    global _global_installer
    _global_installer = installer
openspace/grounding/backends/mcp/provider.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MCP Provider implementation.
|
| 3 |
+
|
| 4 |
+
This module provides a provider for managing MCP server sessions.
|
| 5 |
+
"""
|
| 6 |
+
import asyncio
|
| 7 |
+
from typing import Dict, List, Optional
|
| 8 |
+
|
| 9 |
+
from openspace.grounding.backends.mcp.session import MCPSession
|
| 10 |
+
from openspace.grounding.core.provider import Provider
|
| 11 |
+
from openspace.grounding.core.types import SessionConfig, BackendType, ToolSchema
|
| 12 |
+
from openspace.grounding.backends.mcp.client import MCPClient
|
| 13 |
+
from openspace.grounding.backends.mcp.installer import MCPInstallerManager, MCPDependencyError
|
| 14 |
+
from openspace.grounding.backends.mcp.tool_cache import get_tool_cache
|
| 15 |
+
from openspace.grounding.backends.mcp.tool_converter import _sanitize_mcp_schema
|
| 16 |
+
from openspace.grounding.core.tool import BaseTool, RemoteTool
|
| 17 |
+
from openspace.utils.logging import Logger
|
| 18 |
+
from openspace.config.utils import get_config_value
|
| 19 |
+
|
| 20 |
+
logger = Logger.get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class MCPProvider(Provider[MCPSession]):
|
| 24 |
+
"""
|
| 25 |
+
MCP Provider manages multiple MCP server sessions.
|
| 26 |
+
|
| 27 |
+
Each MCP server defined in config corresponds to one session.
|
| 28 |
+
The provider handles lazy/eager session creation and tool aggregation.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self, config: Dict | None = None, installer: Optional[MCPInstallerManager] = None):
    """Initialize MCP Provider.

    Reads MCP-specific settings out of *config* (falling back to the
    defaults shown below) and builds the underlying MCPClient with them.

    Args:
        config: Configuration dict with MCP server definitions.
            Example: {"mcpServers": {"server1": {...}, "server2": {...}}}
        installer: Optional installer manager for dependency installation
    """
    super().__init__(BackendType.MCP, config)

    # Extract MCP-specific configuration
    # (get_config_value handles both plain dicts and config objects)
    sandbox = get_config_value(config, "sandbox", False)
    timeout = get_config_value(config, "timeout", 30)
    sse_read_timeout = get_config_value(config, "sse_read_timeout", 300.0)
    max_retries = get_config_value(config, "max_retries", 3)
    retry_interval = get_config_value(config, "retry_interval", 2.0)
    check_dependencies = get_config_value(config, "check_dependencies", True)
    auto_install = get_config_value(config, "auto_install", False)
    # Tool call retry settings (for transient errors like 400, 500, etc.)
    tool_call_max_retries = get_config_value(config, "tool_call_max_retries", 3)
    tool_call_retry_delay = get_config_value(config, "tool_call_retry_delay", 1.0)

    # Create sandbox options if sandbox is enabled
    sandbox_options = None
    if sandbox:
        sandbox_options = {
            "timeout": timeout,
            "sse_read_timeout": sse_read_timeout,
        }

    # Create installer with auto_install setting if not provided
    if installer is None and auto_install:
        installer = MCPInstallerManager(auto_install=True)

    # Initialize MCPClient with configuration
    self._client = MCPClient(
        config=config or {},
        sandbox=sandbox,
        sandbox_options=sandbox_options,
        timeout=timeout,
        sse_read_timeout=sse_read_timeout,
        max_retries=max_retries,
        retry_interval=retry_interval,
        installer=installer,
        check_dependencies=check_dependencies,
        tool_call_max_retries=tool_call_max_retries,
        tool_call_retry_delay=tool_call_retry_delay,
    )

    # Map server name to session for quick lookup
    # (sessions are also tracked by session_id in the base class's _sessions)
    self._server_sessions: Dict[str, MCPSession] = {}
|
| 83 |
+
async def initialize(self) -> None:
    """Initialize the MCP provider (idempotent).

    When config["eager_sessions"] is truthy, a session is created up
    front for every configured server; otherwise sessions are created
    lazily on first access.
    """
    if self.is_initialized:
        return

    # config can be dict or Pydantic model, use utility function
    eager = get_config_value(self.config, "eager_sessions", False)
    if eager:
        server_names = self.list_servers()
        logger.debug(f"Eagerly initializing {len(server_names)} MCP server sessions")
        for name in server_names:
            if name in self._server_sessions:
                continue
            session_cfg = SessionConfig(
                session_name=f"mcp-{name}",
                backend_type=BackendType.MCP,
                connection_params={"server": name},
            )
            await self.create_session(session_cfg)

    self.is_initialized = True
    logger.info(
        f"MCPProvider initialized with {len(self.list_servers())} servers (eager={eager})"
    )
| 111 |
+
def list_servers(self) -> List[str]:
    """All MCP server names known to the underlying client's configuration.

    Returns:
        List of server names
    """
    names = self._client.get_server_names()
    return names
| 119 |
+
async def create_session(self, session_config: SessionConfig) -> MCPSession:
    """Create a new MCP session for a specific server.

    Idempotent per server: if a session for the target server already
    exists, it is returned instead of creating a second one.

    Args:
        session_config: Must contain 'server' in connection_params

    Returns:
        MCPSession instance

    Raises:
        ValueError: If 'server' not in connection_params
        Exception: If session creation or initialization fails
    """
    server = get_config_value(session_config.connection_params, "server")
    if not server:
        raise ValueError("MCPProvider.create_session requires 'server' in connection_params")

    # Generate session_id: mcp-<server_name>
    session_id = f"{self.backend_type.value}-{server}"

    # Check if session already exists
    if server in self._server_sessions:
        logger.debug(f"Session for server '{server}' already exists, returning existing session")
        return self._server_sessions[server]

    # Create session through MCPClient
    try:
        logger.debug(f"Creating new session for MCP server: {server}")
        session = await self._client.create_session(server, auto_initialize=True)
        session.session_id = session_id

        # Store in both maps so the session can be found by server name
        # or by session_id.
        self._server_sessions[server] = session
        self._sessions[session_id] = session

        logger.info(f"Created MCP session '{session_id}' for server '{server}'")
        return session
    except MCPDependencyError as e:
        # Dependency errors already shown to user, just debug log
        logger.debug(f"Dependency error for server '{server}': {type(e).__name__}")
        raise
    except Exception as e:
        logger.error(f"Failed to create session for server '{server}': {e}")
        raise
| 164 |
+
async def close_session(self, session_name: str) -> None:
    """Close an MCP session by session name.

    Best-effort: an invalid name or a close failure is logged rather than
    raised, and the session is always removed from internal tracking.

    Args:
        session_name: Session name in format 'mcp-<server_name>'
    """
    # Parse server name from session_name (format: mcp-<server_name>)
    # Both failure modes funnel into the same warning path: a missing '-'
    # makes the unpack raise ValueError, and a wrong prefix raises one
    # explicitly — each is caught by the except below.
    try:
        prefix, server_name = session_name.split("-", 1)
        if prefix != self.backend_type.value:
            raise ValueError(f"Invalid MCP session name format: {session_name}, expected 'mcp-<server_name>'")
    except ValueError as e:
        logger.warning(f"Invalid session_name format: {session_name} - {e}")
        return

    # Check if session exists
    if session_name not in self._sessions and server_name not in self._server_sessions:
        logger.warning(f"Session '{session_name}' not found, nothing to close")
        return

    error_occurred = False
    try:
        logger.debug(f"Closing MCP session '{session_name}' (server: {server_name})")
        await self._client.close_session(server_name)
        logger.info(f"Successfully closed MCP session '{session_name}'")
    except Exception as e:
        error_occurred = True
        logger.error(f"Error closing MCP session '{session_name}': {e}")
    finally:
        # Clean up both maps regardless of errors
        self._server_sessions.pop(server_name, None)
        self._sessions.pop(session_name, None)

    if error_occurred:
        logger.warning(f"Session '{session_name}' removed from tracking despite close error")
| 200 |
+
async def list_tools(self, session_name: str | None = None, use_cache: bool = True) -> List[BaseTool]:
    """List tools from MCP sessions.

    Args:
        session_name: If provided, only list tools from that session.
            If None, list tools from all sessions.
        use_cache: If True, try to load from cache first (no server startup).
            If False, start servers and get live tools.

    Returns:
        List of BaseTool instances; empty on unknown session or listing error.
    """
    await self.ensure_initialized()

    # Case 1: List tools from specific session (always live, no cache)
    if session_name:
        sess = self._sessions.get(session_name)
        if sess:
            try:
                tools = await sess.list_tools()
                # Recover the server name by stripping the 'mcp-' prefix.
                server_name = session_name.replace(f"{self.backend_type.value}-", "", 1)
                for tool in tools:
                    tool.bind_runtime_info(
                        backend=self.backend_type,
                        session_name=session_name,
                        server_name=server_name,
                    )
                return tools
            except Exception as e:
                logger.error(f"Error listing tools from session '{session_name}': {e}")
                return []
        else:
            logger.warning(f"Session '{session_name}' not found")
            return []

    # Case 2: List tools from all servers
    # Try cache first if enabled
    if use_cache:
        cache = get_tool_cache()
        if cache.has_cache():
            tools = self._load_tools_from_cache()
            if tools:
                logger.info(f"Loaded {len(tools)} tools from cache (no server startup)")
                return tools

    # No cache or cache disabled, start servers
    return await self._list_tools_live()
|
| 248 |
+
def _load_tools_from_cache(self) -> List[BaseTool]:
    """Build tools from the on-disk cache without starting any server.

    Priority:
    1. Use the sanitized cache (mcp_tool_cache_sanitized.json) when present.
    2. Otherwise sanitize the raw cache, persist the sanitized version,
       and build from that.
    Returns an empty list when neither cache exists.
    """
    tool_cache = get_tool_cache()
    servers = self.list_servers()

    # Fast path: a pre-sanitized cache can be consumed directly.
    if tool_cache.has_sanitized_cache():
        logger.debug("Loading from sanitized cache")
        cached = tool_cache.get_all_sanitized_tools()
        return self._build_tools_from_cache(cached, servers)

    # Slow path: sanitize the raw cache once and save it for next time.
    if tool_cache.has_cache():
        logger.info("Sanitized cache not found, building from raw cache...")
        raw = tool_cache.get_all_tools()
        sanitized = self._sanitize_and_save_cache(raw, tool_cache)
        return self._build_tools_from_cache(sanitized, servers)

    return []
|
| 273 |
+
def _sanitize_and_save_cache(self, raw_tools: Dict[str, List[Dict]], cache) -> Dict[str, List[Dict]]:
    """Sanitize raw cached tool metadata and persist the sanitized cache.

    Each tool's parameter schema is run through _sanitize_mcp_schema; the
    resulting per-server mapping is saved via cache.save_sanitized and
    also returned.
    """
    sanitized_servers: Dict[str, List[Dict]] = {
        server_name: [
            {
                "name": meta["name"],
                "description": meta.get("description", ""),
                "parameters": _sanitize_mcp_schema(meta.get("parameters", {})),
            }
            for meta in tool_list
        ]
        for server_name, tool_list in raw_tools.items()
    }

    # Save sanitized cache for future use
    cache.save_sanitized(sanitized_servers)
    logger.info(f"Created sanitized cache with {len(sanitized_servers)} servers")

    return sanitized_servers
|
| 299 |
+
def _build_tools_from_cache(self, all_cached_tools: Dict[str, List[Dict]], config_servers: List[str]) -> List[BaseTool]:
    """Turn cached tool metadata into RemoteTool instances.

    Only servers present in *config_servers* are considered; cached
    entries for servers with no metadata are skipped. Each tool gets its
    backend/session/server runtime info bound before being returned.
    """
    built: List[BaseTool] = []

    for server_name in config_servers:
        metadata_list = all_cached_tools.get(server_name)
        if not metadata_list:
            continue

        session_name = f"{self.backend_type.value}-{server_name}"
        for meta in metadata_list:
            tool_schema = ToolSchema(
                name=meta["name"],
                description=meta.get("description", ""),
                parameters=meta.get("parameters", {}),
                backend_type=BackendType.MCP,
            )
            # connector=None: cached tools carry no live connection.
            remote = RemoteTool(schema=tool_schema, connector=None, backend=BackendType.MCP)
            remote.bind_runtime_info(
                backend=self.backend_type,
                session_name=session_name,
                server_name=server_name,
            )
            built.append(remote)

    return built
|
| 330 |
+
    async def _list_tools_live(self) -> List[BaseTool]:
        """List tools by starting all servers.

        Uses a semaphore to serialize session creation, avoiding TaskGroup race conditions
        that occur when multiple MCP connections are initialized concurrently.
        """
        servers = self.list_servers()

        if not servers:
            logger.warning("No MCP servers configured")
            return []

        # Find servers that don't have sessions yet
        to_create = [s for s in servers if s not in self._server_sessions]

        # Create missing sessions with serialized execution using semaphore
        if to_create:
            logger.info(f"Creating {len(to_create)} MCP sessions (serialized to avoid race conditions)")

            # Use semaphore with limit=1 to serialize session creation
            # This avoids TaskGroup race conditions in concurrent HTTP connection setup
            # NOTE: with limit=1 the gather below effectively runs creations one at a
            # time, but still collects every exception via return_exceptions=True.
            semaphore = asyncio.Semaphore(1)

            async def _create_with_semaphore(server: str):
                # One creation at a time; _lazy_create re-raises failures so
                # gather() can report them per-server.
                async with semaphore:
                    logger.debug(f"Creating session for '{server}'")
                    return await self._lazy_create(server)

            tasks = [_create_with_semaphore(s) for s in to_create]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Log errors
            for i, result in enumerate(results):
                if isinstance(result, MCPDependencyError):
                    # Dependency errors were already surfaced to the user; keep quiet.
                    logger.debug(f"Dependency error for '{to_create[i]}': {type(result).__name__}")
                elif isinstance(result, Exception):
                    logger.error(f"Failed to create session for '{to_create[i]}': {result}")

        # Aggregate tools from all sessions
        # Keyed by (server, tool name) so the same tool name on two servers is
        # kept twice, but duplicates within one server are dropped.
        uniq: Dict[tuple[str, str], BaseTool] = {}
        failed_servers = []

        logger.debug(f"Listing tools from {len(self._server_sessions)} sessions")
        for server, sess in self._server_sessions.items():
            try:
                tools = await sess.list_tools()
                session_name = f"{self.backend_type.value}-{server}"
                for tool in tools:
                    key = (server, tool.schema.name)
                    if key not in uniq:
                        # Record which backend/session/server the tool belongs
                        # to so it can be executed later.
                        tool.bind_runtime_info(
                            backend=self.backend_type,
                            session_name=session_name,
                            server_name=server,
                        )
                        uniq[key] = tool
            except Exception as e:
                # A failing server does not abort listing from the others.
                failed_servers.append(server)
                logger.error(f"Error listing tools from server '{server}': {e}")

        if failed_servers:
            logger.warning(f"Failed to list tools from {len(failed_servers)} server(s): {failed_servers}")

        tools_list = list(uniq.values())
        logger.debug(f"Listed {len(tools_list)} unique tools from {len(self._server_sessions)} MCP servers")

        # Save to cache for next time
        await self._save_tools_to_cache(tools_list)

        return tools_list
|
| 400 |
+
|
| 401 |
+
async def _save_tools_to_cache(self, tools: List[BaseTool]) -> None:
|
| 402 |
+
"""Save tools metadata to cache file."""
|
| 403 |
+
cache = get_tool_cache()
|
| 404 |
+
|
| 405 |
+
# Group tools by server
|
| 406 |
+
servers: Dict[str, List[Dict]] = {}
|
| 407 |
+
for tool in tools:
|
| 408 |
+
server_name = tool.runtime_info.server_name if tool.is_bound else "unknown"
|
| 409 |
+
if server_name not in servers:
|
| 410 |
+
servers[server_name] = []
|
| 411 |
+
servers[server_name].append({
|
| 412 |
+
"name": tool.schema.name,
|
| 413 |
+
"description": tool.schema.description or "",
|
| 414 |
+
"parameters": tool.schema.parameters or {},
|
| 415 |
+
})
|
| 416 |
+
|
| 417 |
+
cache.save(servers)
|
| 418 |
+
|
| 419 |
+
async def ensure_server_session(self, server_name: str) -> Optional[MCPSession]:
|
| 420 |
+
"""Ensure a server session exists, creating it if needed.
|
| 421 |
+
|
| 422 |
+
This is used for on-demand server startup when executing tools.
|
| 423 |
+
"""
|
| 424 |
+
if server_name in self._server_sessions:
|
| 425 |
+
return self._server_sessions[server_name]
|
| 426 |
+
|
| 427 |
+
# Server not running, start it
|
| 428 |
+
logger.info(f"Starting MCP server on-demand: {server_name}")
|
| 429 |
+
cfg = SessionConfig(
|
| 430 |
+
session_name=f"mcp-{server_name}",
|
| 431 |
+
backend_type=BackendType.MCP,
|
| 432 |
+
connection_params={"server": server_name},
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
try:
|
| 436 |
+
session = await self.create_session(cfg)
|
| 437 |
+
return session
|
| 438 |
+
except Exception as e:
|
| 439 |
+
logger.error(f"Failed to start server '{server_name}': {e}")
|
| 440 |
+
return None
|
| 441 |
+
|
| 442 |
+
async def _lazy_create(self, server: str) -> None:
|
| 443 |
+
"""Internal helper for lazy session creation.
|
| 444 |
+
|
| 445 |
+
Args:
|
| 446 |
+
server: Server name to create session for
|
| 447 |
+
|
| 448 |
+
Raises:
|
| 449 |
+
Exception: Re-raises any exception from session creation for error tracking
|
| 450 |
+
"""
|
| 451 |
+
# Double-check to avoid race conditions
|
| 452 |
+
if server in self._server_sessions:
|
| 453 |
+
logger.debug(f"Session for server '{server}' already exists, skipping lazy creation")
|
| 454 |
+
return
|
| 455 |
+
|
| 456 |
+
cfg = SessionConfig(
|
| 457 |
+
session_name=f"mcp-{server}",
|
| 458 |
+
backend_type=BackendType.MCP,
|
| 459 |
+
connection_params={"server": server},
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
try:
|
| 463 |
+
await self.create_session(cfg)
|
| 464 |
+
logger.debug(f"Lazily created session for server '{server}'")
|
| 465 |
+
except MCPDependencyError as e:
|
| 466 |
+
# Dependency errors already shown to user
|
| 467 |
+
logger.debug(f"Dependency error for server '{server}': {type(e).__name__}")
|
| 468 |
+
# Re-raise so that asyncio.gather can track the error
|
| 469 |
+
raise
|
| 470 |
+
except Exception as e:
|
| 471 |
+
logger.error(f"Failed to lazily create session for server '{server}': {e}")
|
| 472 |
+
# Re-raise so that asyncio.gather can track the error
|
| 473 |
+
raise
|
openspace/grounding/backends/mcp/session.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Session manager for MCP connections.
|
| 3 |
+
|
| 4 |
+
This module provides a session manager for MCP connections,
|
| 5 |
+
which handles authentication, initialization, and tool discovery.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Any, Dict
|
| 9 |
+
|
| 10 |
+
from openspace.grounding.backends.mcp.transport.connectors import MCPBaseConnector
|
| 11 |
+
from openspace.grounding.backends.mcp.tool_converter import convert_mcp_tool_to_base_tool
|
| 12 |
+
from openspace.grounding.core.session import BaseSession
|
| 13 |
+
from openspace.grounding.core.types import BackendType
|
| 14 |
+
from openspace.utils.logging import Logger
|
| 15 |
+
|
| 16 |
+
logger = Logger.get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class MCPSession(BaseSession):
    """Lifecycle manager for a single MCP connection.

    Handles connecting, protocol initialization, and discovery/conversion of
    the server's tools into framework BaseTool instances.
    """

    def __init__(
        self,
        connector: MCPBaseConnector,
        *,
        session_id: str = "",
        auto_connect: bool = True,
        auto_initialize: bool = True,
    ) -> None:
        """Create an MCP session bound to *connector*.

        Args:
            connector: Transport connector used to talk to the MCP implementation.
            session_id: Unique identifier for this session.
            auto_connect: Connect automatically before initialization.
            auto_initialize: Initialize the session automatically.
        """
        super().__init__(
            connector=connector,
            session_id=session_id,
            backend_type=BackendType.MCP,
            auto_connect=auto_connect,
            auto_initialize=auto_initialize,
        )

    async def initialize(self) -> Dict[str, Any]:
        """Initialize the MCP session and discover available tools.

        Returns:
            The session information returned by the MCP implementation.
        """
        # Connect first if allowed and not yet connected.
        if not self.is_connected and self.auto_connect:
            await self.connect()

        logger.debug(f"Initializing MCP session {self.session_id}")
        info = await self.connector.initialize()

        # The connector caches the server's tool list after initialize();
        # convert each MCP tool into a framework BaseTool.
        discovered = self.connector.tools
        logger.debug(f"Converting {len(discovered)} MCP tools to BaseTool")
        self.tools = [
            convert_mcp_tool_to_base_tool(mcp_tool, self.connector)
            for mcp_tool in discovered
        ]

        logger.debug(f"MCP session {self.session_id} initialized with {len(self.tools)} tools")

        return info
|
openspace/grounding/backends/mcp/tool_cache.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
from openspace.utils.logging import Logger
|
| 7 |
+
|
| 8 |
+
logger = Logger.get_logger(__name__)
|
| 9 |
+
|
| 10 |
+
# Cache path in project root directory (OpenSpace/)
|
| 11 |
+
# __file__ = .../OpenSpace/openspace/grounding/backends/mcp/tool_cache.py
|
| 12 |
+
# parent x5 = .../OpenSpace/
|
| 13 |
+
DEFAULT_CACHE_PATH = Path(__file__).parent.parent.parent.parent.parent / "mcp_tool_cache.json"
|
| 14 |
+
# Sanitized cache path (Claude API compatible JSON Schema)
|
| 15 |
+
DEFAULT_SANITIZED_CACHE_PATH = Path(__file__).parent.parent.parent.parent.parent / "mcp_tool_cache_sanitized.json"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MCPToolCache:
    """Simple file-based cache for MCP tool metadata.

    Two JSON files are maintained:

    * ``cache_path`` -- raw tool metadata as reported by the servers, plus
      bookkeeping (``failed_servers``, ``updated_at``).
    * ``sanitized_cache_path`` -- the same metadata with schemas sanitized
      for Claude API compatibility.

    Both files share the layout
    ``{"version": int, "servers": {server_name: [tool_dict, ...]}}``.
    Loaded data is memoized in memory for the lifetime of the instance.
    """

    # Written into every cache file; informational only (not auto-invalidating).
    CACHE_VERSION = 1

    def __init__(self, cache_path: Optional[Path] = None, sanitized_cache_path: Optional[Path] = None):
        """Create a cache backed by the given files (defaults to project-root paths)."""
        self.cache_path = cache_path or DEFAULT_CACHE_PATH
        self.sanitized_cache_path = sanitized_cache_path or DEFAULT_SANITIZED_CACHE_PATH
        self._cache: Optional[Dict] = None            # memoized raw cache
        self._sanitized_cache: Optional[Dict] = None  # memoized sanitized cache
        self._server_order: Optional[List[str]] = None  # preferred on-disk ordering

    def set_server_order(self, order: List[str]):
        """Set expected server order (from config). Used when saving to disk."""
        self._server_order = order

    def _reorder_servers(self, servers: Dict[str, List[Dict]]) -> Dict[str, List[Dict]]:
        """Return *servers* reordered to match ``_server_order``.

        Servers from the configured order come first (in that order); any
        servers not mentioned in the config are appended afterwards.
        """
        if not self._server_order:
            return servers

        ordered = {}
        # First add servers in config order
        for name in self._server_order:
            if name in servers:
                ordered[name] = servers[name]
        # Then add any remaining servers (not in config)
        for name in servers:
            if name not in ordered:
                ordered[name] = servers[name]
        return ordered

    def _ensure_dir(self):
        """Ensure the parent directories of BOTH cache files exist.

        Fix: previously only ``cache_path``'s parent was created, so
        ``save_sanitized`` could fail with FileNotFoundError when a custom
        ``sanitized_cache_path`` pointed into a directory that did not
        exist yet.
        """
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        self.sanitized_cache_path.parent.mkdir(parents=True, exist_ok=True)

    def load(self) -> Dict[str, Any]:
        """Load cache from disk. Returns empty dict if not exists.

        The result is memoized; subsequent calls return the in-memory copy.
        A missing or unreadable file yields a fresh empty cache structure.
        """
        if self._cache is not None:
            return self._cache

        if not self.cache_path.exists():
            self._cache = {"version": self.CACHE_VERSION, "servers": {}}
            return self._cache

        try:
            with open(self.cache_path, "r", encoding="utf-8") as f:
                self._cache = json.load(f)
            logger.info(f"Loaded MCP tool cache: {len(self._cache.get('servers', {}))} servers")
            return self._cache
        except Exception as e:
            # Corrupt cache is treated as absent rather than fatal.
            logger.warning(f"Failed to load cache: {e}")
            self._cache = {"version": self.CACHE_VERSION, "servers": {}}
            return self._cache

    def save(self, servers: Dict[str, List[Dict]]):
        """
        Save tool metadata to disk (overwrites existing cache).

        Args:
            servers: Dict mapping server_name -> list of tool metadata dicts
                     Each tool dict should have: name, description, parameters
        """
        self._ensure_dir()

        cache_data = {
            "version": self.CACHE_VERSION,
            "updated_at": datetime.now().isoformat(),
            "servers": servers,
        }

        try:
            with open(self.cache_path, "w", encoding="utf-8") as f:
                json.dump(cache_data, f, indent=2, ensure_ascii=False)
            self._cache = cache_data
            logger.info(f"Saved MCP tool cache: {len(servers)} servers")
        except Exception as e:
            # Persisting is best-effort; an unwritable cache must not crash callers.
            logger.error(f"Failed to save cache: {e}")

    def save_server(self, server_name: str, tools: List[Dict]):
        """
        Save/update a single server's tools to cache (incremental append).

        Args:
            server_name: Name of the MCP server
            tools: List of tool metadata dicts for this server
        """
        self._ensure_dir()

        # Load existing cache so other servers' entries are preserved.
        cache = self.load()

        # Update server entry
        if "servers" not in cache:
            cache["servers"] = {}
        cache["servers"][server_name] = tools
        cache["servers"] = self._reorder_servers(cache["servers"])
        cache["updated_at"] = datetime.now().isoformat()

        # Save back
        try:
            with open(self.cache_path, "w", encoding="utf-8") as f:
                json.dump(cache, f, indent=2, ensure_ascii=False)
            self._cache = cache
            logger.debug(f"Saved {len(tools)} tools for server '{server_name}'")
        except Exception as e:
            logger.error(f"Failed to save cache for server '{server_name}': {e}")

    def get_server_tools(self, server_name: str) -> Optional[List[Dict]]:
        """Get cached tools for a specific server (None when absent)."""
        cache = self.load()
        return cache.get("servers", {}).get(server_name)

    def get_all_tools(self) -> Dict[str, List[Dict]]:
        """Get all cached tools, grouped by server."""
        cache = self.load()
        return cache.get("servers", {})

    def has_cache(self) -> bool:
        """Check if cache exists and has data."""
        cache = self.load()
        return bool(cache.get("servers"))

    def clear(self):
        """Clear the cache (delete file and drop in-memory copy)."""
        if self.cache_path.exists():
            self.cache_path.unlink()
        self._cache = None
        logger.info("MCP tool cache cleared")

    def save_failed_server(self, server_name: str, error: str):
        """
        Record a failed server to cache.

        Args:
            server_name: Name of the failed MCP server
            error: Error message
        """
        self._ensure_dir()

        # Load existing cache
        cache = self.load()

        # Add to failed_servers list
        if "failed_servers" not in cache:
            cache["failed_servers"] = {}
        cache["failed_servers"][server_name] = {
            "error": error,
            "failed_at": datetime.now().isoformat(),
        }
        cache["updated_at"] = datetime.now().isoformat()

        # Save back
        try:
            with open(self.cache_path, "w", encoding="utf-8") as f:
                json.dump(cache, f, indent=2, ensure_ascii=False)
            self._cache = cache
        except Exception as e:
            logger.error(f"Failed to save failed server '{server_name}': {e}")

    def get_failed_servers(self) -> Dict[str, Dict]:
        """Get list of failed servers from cache."""
        cache = self.load()
        return cache.get("failed_servers", {})

    def load_sanitized(self) -> Dict[str, Any]:
        """Load sanitized cache from disk. Returns empty dict if not exists.

        Mirrors :meth:`load` but for the sanitized cache file.
        """
        if self._sanitized_cache is not None:
            return self._sanitized_cache

        if not self.sanitized_cache_path.exists():
            self._sanitized_cache = {"version": self.CACHE_VERSION, "servers": {}}
            return self._sanitized_cache

        try:
            with open(self.sanitized_cache_path, "r", encoding="utf-8") as f:
                self._sanitized_cache = json.load(f)
            logger.info(f"Loaded sanitized MCP tool cache: {len(self._sanitized_cache.get('servers', {}))} servers")
            return self._sanitized_cache
        except Exception as e:
            logger.warning(f"Failed to load sanitized cache: {e}")
            self._sanitized_cache = {"version": self.CACHE_VERSION, "servers": {}}
            return self._sanitized_cache

    def save_sanitized(self, servers: Dict[str, List[Dict]]):
        """
        Save sanitized tool metadata to disk.

        Args:
            servers: Dict mapping server_name -> list of sanitized tool metadata dicts
        """
        self._ensure_dir()

        cache_data = {
            "version": self.CACHE_VERSION,
            "updated_at": datetime.now().isoformat(),
            "sanitized": True,
            "servers": servers,
        }

        try:
            with open(self.sanitized_cache_path, "w", encoding="utf-8") as f:
                json.dump(cache_data, f, indent=2, ensure_ascii=False)
            self._sanitized_cache = cache_data
            logger.info(f"Saved sanitized MCP tool cache: {len(servers)} servers")
        except Exception as e:
            logger.error(f"Failed to save sanitized cache: {e}")

    def get_all_sanitized_tools(self) -> Dict[str, List[Dict]]:
        """Get all sanitized cached tools, grouped by server."""
        cache = self.load_sanitized()
        return cache.get("servers", {})

    def has_sanitized_cache(self) -> bool:
        """Check if sanitized cache exists and has data."""
        cache = self.load_sanitized()
        return bool(cache.get("servers"))

    def clear_sanitized(self):
        """Clear the sanitized cache (delete file and drop in-memory copy)."""
        if self.sanitized_cache_path.exists():
            self.sanitized_cache_path.unlink()
        self._sanitized_cache = None
        logger.info("Sanitized MCP tool cache cleared")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# Global instance
|
| 245 |
+
_tool_cache: Optional[MCPToolCache] = None
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def get_tool_cache() -> MCPToolCache:
    """Return the process-wide MCPToolCache singleton, creating it lazily."""
    global _tool_cache
    if _tool_cache is None:
        _tool_cache = MCPToolCache()
    return _tool_cache
|
| 254 |
+
|
openspace/grounding/backends/mcp/tool_converter.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tool converter for MCP.
|
| 3 |
+
|
| 4 |
+
This module provides utilities to convert MCP tools to BaseTool instances.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import copy
|
| 8 |
+
from typing import Any, Dict
|
| 9 |
+
from mcp.types import Tool as MCPTool
|
| 10 |
+
|
| 11 |
+
from openspace.grounding.core.tool import BaseTool, RemoteTool
|
| 12 |
+
from openspace.grounding.core.types import BackendType, ToolSchema
|
| 13 |
+
from openspace.grounding.core.transport.connectors import BaseConnector
|
| 14 |
+
from openspace.utils.logging import Logger
|
| 15 |
+
|
| 16 |
+
logger = Logger.get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _sanitize_mcp_schema(params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Sanitize MCP tool schema to ensure Claude API compatibility (JSON Schema draft 2020-12).

    Fixes:
    - Empty schemas -> valid object schema
    - Missing required fields (type, properties, required)
    - Removes non-standard fields (title, examples, nullable, default, etc.)
    - Recursively cleans nested properties and items
    - Ensures every property has a valid type
    - Ensures top-level type is 'object' (Anthropic API requirement)
    """
    if not params:
        # No/empty schema means a tool with no parameters.
        return {"type": "object", "properties": {}, "required": []}

    # Deep-copy so the caller's schema is never mutated.
    cleaned = _deep_sanitize(copy.deepcopy(params))

    # Anthropic requires the top-level schema to describe an object.  Any
    # other top-level type is wrapped as a single required "value" property.
    declared = cleaned.get("type")
    if declared and declared != "object":
        logger.debug(f"[MCP_SCHEMA_SANITIZE] Wrapping non-object schema (type={declared}) into object")
        cleaned = {
            "type": "object",
            "properties": {
                "value": cleaned  # The original schema becomes a property
            },
            "required": ["value"],
        }

    return cleaned
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _deep_sanitize(schema: Dict[str, Any]) -> Dict[str, Any]:
|
| 55 |
+
"""
|
| 56 |
+
Recursively sanitize a JSON schema to conform to JSON Schema draft 2020-12.
|
| 57 |
+
Removes non-standard fields and ensures valid structure.
|
| 58 |
+
"""
|
| 59 |
+
if not isinstance(schema, dict):
|
| 60 |
+
return {"type": "string"}
|
| 61 |
+
|
| 62 |
+
# Allowed top-level keys for Claude API compatibility
|
| 63 |
+
allowed_keys = {
|
| 64 |
+
"type", "properties", "required", "items",
|
| 65 |
+
"description", "enum", "const",
|
| 66 |
+
"minimum", "maximum", "minLength", "maxLength",
|
| 67 |
+
"minItems", "maxItems", "pattern",
|
| 68 |
+
"additionalProperties", "anyOf", "oneOf", "allOf"
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
# Remove disallowed keys
|
| 72 |
+
keys_to_remove = [k for k in schema if k not in allowed_keys]
|
| 73 |
+
for k in keys_to_remove:
|
| 74 |
+
schema.pop(k, None)
|
| 75 |
+
|
| 76 |
+
# Ensure type exists
|
| 77 |
+
if "type" not in schema:
|
| 78 |
+
# Type is defined via anyOf/oneOf/allOf - don't add default type
|
| 79 |
+
# These combination keywords define the type themselves
|
| 80 |
+
if "anyOf" in schema or "oneOf" in schema or "allOf" in schema:
|
| 81 |
+
pass # Type is defined through combination keywords, do not add default type
|
| 82 |
+
# Try to infer type
|
| 83 |
+
elif "properties" in schema:
|
| 84 |
+
schema["type"] = "object"
|
| 85 |
+
elif "items" in schema:
|
| 86 |
+
schema["type"] = "array"
|
| 87 |
+
elif "enum" in schema:
|
| 88 |
+
# For enum, try to infer from values
|
| 89 |
+
enum_vals = schema.get("enum", [])
|
| 90 |
+
if enum_vals and all(isinstance(v, str) for v in enum_vals):
|
| 91 |
+
schema["type"] = "string"
|
| 92 |
+
elif enum_vals and all(isinstance(v, (int, float)) for v in enum_vals):
|
| 93 |
+
schema["type"] = "number"
|
| 94 |
+
else:
|
| 95 |
+
schema["type"] = "string"
|
| 96 |
+
elif not schema:
|
| 97 |
+
# Empty schema (e.g., only had $schema which was removed) -> no parameters needed
|
| 98 |
+
schema["type"] = "object"
|
| 99 |
+
schema["properties"] = {}
|
| 100 |
+
schema["required"] = []
|
| 101 |
+
else:
|
| 102 |
+
schema["type"] = "object"
|
| 103 |
+
|
| 104 |
+
# Handle object type
|
| 105 |
+
if schema.get("type") == "object":
|
| 106 |
+
if "properties" not in schema:
|
| 107 |
+
schema["properties"] = {}
|
| 108 |
+
if "required" not in schema:
|
| 109 |
+
schema["required"] = []
|
| 110 |
+
|
| 111 |
+
# Recursively sanitize properties
|
| 112 |
+
if isinstance(schema.get("properties"), dict):
|
| 113 |
+
for prop_name, prop_schema in list(schema["properties"].items()):
|
| 114 |
+
if isinstance(prop_schema, dict):
|
| 115 |
+
schema["properties"][prop_name] = _deep_sanitize(prop_schema)
|
| 116 |
+
else:
|
| 117 |
+
# Invalid property schema, replace with string
|
| 118 |
+
schema["properties"][prop_name] = {"type": "string"}
|
| 119 |
+
|
| 120 |
+
# Sanitize additionalProperties if present
|
| 121 |
+
if "additionalProperties" in schema and isinstance(schema["additionalProperties"], dict):
|
| 122 |
+
schema["additionalProperties"] = _deep_sanitize(schema["additionalProperties"])
|
| 123 |
+
|
| 124 |
+
# Handle array type
|
| 125 |
+
elif schema.get("type") == "array":
|
| 126 |
+
if "items" in schema:
|
| 127 |
+
if isinstance(schema["items"], dict):
|
| 128 |
+
schema["items"] = _deep_sanitize(schema["items"])
|
| 129 |
+
elif isinstance(schema["items"], list):
|
| 130 |
+
# Tuple validation - sanitize each item
|
| 131 |
+
schema["items"] = [_deep_sanitize(item) if isinstance(item, dict) else {"type": "string"} for item in schema["items"]]
|
| 132 |
+
else:
|
| 133 |
+
schema["items"] = {"type": "string"}
|
| 134 |
+
else:
|
| 135 |
+
# Default items to string if not specified
|
| 136 |
+
schema["items"] = {"type": "string"}
|
| 137 |
+
|
| 138 |
+
# Handle anyOf/oneOf/allOf
|
| 139 |
+
for combo_key in ["anyOf", "oneOf", "allOf"]:
|
| 140 |
+
if combo_key in schema and isinstance(schema[combo_key], list):
|
| 141 |
+
schema[combo_key] = [
|
| 142 |
+
_deep_sanitize(sub) if isinstance(sub, dict) else {"type": "string"}
|
| 143 |
+
for sub in schema[combo_key]
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
return schema
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def convert_mcp_tool_to_base_tool(
    mcp_tool: MCPTool,
    connector: BaseConnector
) -> BaseTool:
    """
    Convert an MCP Tool to a BaseTool (RemoteTool) instance.

    Extracts the tool's name, description, and input schema from the MCP SDK
    object, sanitizes the schema for API compatibility, and wraps everything
    in a RemoteTool usable by the grounding framework.

    Args:
        mcp_tool: MCP Tool object from the MCP SDK
        connector: Connector instance for communicating with the MCP server

    Returns:
        RemoteTool instance wrapping the MCP tool
    """
    tool_name = mcp_tool.name
    tool_description = getattr(mcp_tool, 'description', None) or ""

    # Sanitize the MCP input schema; fall back to a no-parameter object schema
    # when the tool declares none.
    if hasattr(mcp_tool, 'inputSchema') and mcp_tool.inputSchema:
        input_schema: Dict[str, Any] = _sanitize_mcp_schema(mcp_tool.inputSchema)
    else:
        input_schema = {"type": "object", "properties": {}, "required": []}

    tool_schema = ToolSchema(
        name=tool_name,
        description=tool_description,
        parameters=input_schema,
        backend_type=BackendType.MCP,
    )

    wrapped = RemoteTool(
        connector=connector,
        remote_name=tool_name,
        schema=tool_schema,
        backend=BackendType.MCP,
    )

    logger.debug(f"Converted MCP tool '{tool_name}' to RemoteTool")
    return wrapped
|
openspace/grounding/backends/mcp/transport/connectors/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Connectors for various MCP transports.
|
| 3 |
+
|
| 4 |
+
This module provides interfaces for connecting to MCP implementations
|
| 5 |
+
through different transport mechanisms.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .base import MCPBaseConnector # noqa: F401
|
| 9 |
+
from .http import HttpConnector # noqa: F401
|
| 10 |
+
from .sandbox import SandboxConnector # noqa: F401
|
| 11 |
+
from .stdio import StdioConnector # noqa: F401
|
| 12 |
+
from .websocket import WebSocketConnector # noqa: F401
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"MCPBaseConnector",
|
| 16 |
+
"StdioConnector",
|
| 17 |
+
"HttpConnector",
|
| 18 |
+
"WebSocketConnector",
|
| 19 |
+
"SandboxConnector",
|
| 20 |
+
]
|
openspace/grounding/backends/mcp/transport/connectors/base.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base connector for MCP implementations.
|
| 3 |
+
|
| 4 |
+
This module provides the base connector interface that all MCP connectors must implement.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
from abc import abstractmethod
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
from mcp import ClientSession
|
| 12 |
+
from mcp.shared.exceptions import McpError
|
| 13 |
+
from mcp.types import CallToolResult, GetPromptResult, Prompt, ReadResourceResult, Resource, Tool
|
| 14 |
+
|
| 15 |
+
from openspace.grounding.core.transport.task_managers import BaseConnectionManager
|
| 16 |
+
from openspace.grounding.core.transport.connectors import BaseConnector
|
| 17 |
+
from openspace.utils.logging import Logger
|
| 18 |
+
|
| 19 |
+
logger = Logger.get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
# Default retry settings for tool calls
|
| 22 |
+
DEFAULT_TOOL_CALL_MAX_RETRIES = 3
|
| 23 |
+
DEFAULT_TOOL_CALL_RETRY_DELAY = 1.0
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class MCPBaseConnector(BaseConnector[ClientSession]):
    """Base class for MCP connectors.

    This class defines the interface that all MCP connectors must implement,
    and centralizes the shared plumbing: ClientSession lifecycle management,
    liveness checks, automatic reconnection, and retry logic for tool calls.
    """

    def __init__(
        self,
        connection_manager: BaseConnectionManager[ClientSession],
        tool_call_max_retries: int = DEFAULT_TOOL_CALL_MAX_RETRIES,
        tool_call_retry_delay: float = DEFAULT_TOOL_CALL_RETRY_DELAY,
    ):
        """Initialize base connector with common attributes.

        Args:
            connection_manager: The connection manager to use for the connection.
            tool_call_max_retries: Maximum number of retries for tool calls (default: 3)
            tool_call_retry_delay: Initial delay between retries in seconds (default: 1.0)
        """
        super().__init__(connection_manager)
        self.client_session: ClientSession | None = None
        # None means "not initialized yet"; an empty list means "server offers none".
        self._tools: list[Tool] | None = None
        self._resources: list[Resource] | None = None
        self._prompts: list[Prompt] | None = None
        self.auto_reconnect = True  # Whether to automatically reconnect on connection loss (not configurable for now)
        self.tool_call_max_retries = tool_call_max_retries
        self.tool_call_retry_delay = tool_call_retry_delay

    @property
    @abstractmethod
    def public_identifier(self) -> str:
        """Get the identifier for the connector."""
        pass

    async def _get_streams_from_connection(self):
        """Get read and write streams from the connection. Override in subclasses if needed."""
        # Default implementation for most MCP connectors (stdio, HTTP).
        # Returns the connection directly as it should be a tuple of (read_stream, write_stream).
        return self._connection

    async def _after_connect(self) -> None:
        """Create ClientSession after connection is established.

        Some connectors (like WebSocket) don't use ClientSession and may override this method.
        """
        # Get streams from the connection
        streams = await self._get_streams_from_connection()

        if streams is None:
            # Some connectors (like WebSocket) don't use ClientSession.
            # They should override this method to set up their own resources.
            logger.debug("No streams returned, ClientSession creation skipped")
            return

        if isinstance(streams, tuple) and len(streams) == 2:
            read_stream, write_stream = streams
            # Create the client session and enter its async context manually;
            # the matching __aexit__ happens in _before_disconnect().
            self.client_session = ClientSession(read_stream, write_stream, sampling_callback=None)
            await self.client_session.__aenter__()
            logger.debug("MCP ClientSession created successfully")
        else:
            raise RuntimeError(f"Invalid streams format: expected tuple of 2 elements, got {type(streams)}")

    async def _before_disconnect(self) -> None:
        """Clean up MCP-specific resources before disconnection."""
        errors = []

        # Close the client session (best-effort; errors are logged, not raised,
        # so the outer disconnect flow can still proceed).
        if self.client_session:
            try:
                logger.debug("Closing MCP client session")
                await self.client_session.__aexit__(None, None, None)
            except Exception as e:
                error_msg = f"Error closing client session: {e}"
                logger.warning(error_msg)
                errors.append(error_msg)
            finally:
                self.client_session = None

        # Reset cached tools, resources, and prompts back to "uninitialized".
        self._tools = None
        self._resources = None
        self._prompts = None

        if errors:
            logger.warning(f"Encountered {len(errors)} errors during MCP resource cleanup")

    async def _cleanup_on_connect_failure(self) -> None:
        """Override to add MCP-specific cleanup on connection failure."""
        # Clean up client session if it was created before the failure.
        if self.client_session:
            try:
                await self.client_session.__aexit__(None, None, None)
            except Exception:
                pass
            finally:
                self.client_session = None

        # Call parent cleanup
        await super()._cleanup_on_connect_failure()

    async def initialize(self) -> dict[str, Any]:
        """Initialize the MCP session and return session information.

        Note: the returned object is the mcp SDK's InitializeResult; the
        dict annotation is kept for backward compatibility with callers.

        Raises:
            RuntimeError: If the client session has not been created yet.
        """
        if not self.client_session:
            raise RuntimeError("MCP client is not connected")

        logger.debug("Initializing MCP session")

        # Perform the MCP initialize handshake.
        result = await self.client_session.initialize()

        server_capabilities = result.capabilities

        # Only query each capability the server actually advertises;
        # otherwise record an empty inventory.
        if server_capabilities.tools:
            tools_result = await self.list_tools()
            self._tools = tools_result or []
        else:
            self._tools = []

        if server_capabilities.resources:
            resources_result = await self.list_resources()
            self._resources = resources_result or []
        else:
            self._resources = []

        if server_capabilities.prompts:
            prompts_result = await self.list_prompts()
            self._prompts = prompts_result or []
        else:
            self._prompts = []

        logger.debug(
            f"MCP session initialized with {len(self._tools)} tools, "
            f"{len(self._resources)} resources, "
            f"and {len(self._prompts)} prompts"
        )

        return result

    @property
    def tools(self) -> list[Tool]:
        """Get the list of available tools."""
        if self._tools is None:
            raise RuntimeError("MCP client is not initialized")
        return self._tools

    @property
    def resources(self) -> list[Resource]:
        """Get the list of available resources."""
        if self._resources is None:
            raise RuntimeError("MCP client is not initialized")
        return self._resources

    @property
    def prompts(self) -> list[Prompt]:
        """Get the list of available prompts."""
        if self._prompts is None:
            raise RuntimeError("MCP client is not initialized")
        return self._prompts

    @property
    def is_connected(self) -> bool:
        """Check if the connector is actually connected and the connection is alive.

        This property checks not only the connected flag but also verifies that
        the client session exists and the underlying connection is still active.
        It may flip self._connected to False as a side effect when staleness is
        detected.

        Returns:
            True if the connector is connected and the connection is alive, False otherwise.
        """
        # First check the basic connected flag
        if not self._connected:
            return False

        # A connected connector must have a live client session.
        if not self.client_session:
            self._connected = False
            return False

        # Check if the connection manager's background task is still running
        # (only managers that expose a private _task attribute are checked).
        if self._connection_manager and hasattr(self._connection_manager, "_task"):
            task = self._connection_manager._task
            if task and task.done():
                logger.debug("Connection manager task is done, marking as disconnected")
                self._connected = False
                return False

        return True

    async def _ensure_connected(self) -> None:
        """Ensure the connector is connected, reconnecting if necessary.

        Raises:
            RuntimeError: If connection cannot be established and auto_reconnect is False.
        """
        if not self.client_session:
            raise RuntimeError("MCP client is not connected")

        if not self.is_connected:
            if self.auto_reconnect:
                logger.debug("Connection lost, attempting to reconnect...")
                try:
                    await self.connect()
                    logger.debug("Reconnection successful")
                except Exception as e:
                    raise RuntimeError(f"Failed to reconnect to MCP server: {e}") from e
            else:
                raise RuntimeError(
                    "Connection to MCP server has been lost. Auto-reconnection is disabled. Please reconnect manually."
                )

    async def call_tool(self, name: str, arguments: dict[str, Any]) -> CallToolResult:
        """Call an MCP tool with automatic reconnection handling and retry logic.

        Args:
            name: The name of the tool to call.
            arguments: The arguments to pass to the tool.

        Returns:
            The result of the tool call.

        Raises:
            RuntimeError: If the connection is lost and cannot be reestablished.
            Exception: If the tool call fails after all retries.
        """
        last_error: Exception | None = None

        for attempt in range(self.tool_call_max_retries):
            # Ensure we're connected (may transparently reconnect).
            await self._ensure_connected()

            logger.debug(f"Calling tool '{name}' with arguments: {arguments} (attempt {attempt + 1}/{self.tool_call_max_retries})")
            try:
                result = await self.client_session.call_tool(name, arguments)
                logger.debug(f"Tool '{name}' called successfully")
                return result
            except Exception as e:
                last_error = e
                error_str = str(e).lower()

                # Check if the error might be due to connection loss
                if not self.is_connected:
                    logger.warning(f"Tool call '{name}' failed due to connection loss: {e}")
                    # Try to reconnect on next iteration
                    continue

                # Substring match for retryable HTTP errors (400, 500, 502, 503, 504).
                # NOTE(review): retrying 400 Bad Request is unusual (it's a client
                # error); kept as-is since it may be deliberate for flaky servers.
                is_retryable = any(code in error_str for code in ['400', '500', '502', '503', '504', 'bad request', 'internal server error', 'service unavailable', 'gateway timeout'])

                if is_retryable and attempt < self.tool_call_max_retries - 1:
                    delay = self.tool_call_retry_delay * (2 ** attempt)  # Exponential backoff
                    logger.warning(
                        f"Tool call '{name}' failed with retryable error: {e}, "
                        f"retrying in {delay:.1f}s (attempt {attempt + 1}/{self.tool_call_max_retries})"
                    )
                    await asyncio.sleep(delay)
                    continue

                # Non-retryable error or max retries reached, re-raise
                raise

        # All retries exhausted
        error_msg = f"Tool call '{name}' failed after {self.tool_call_max_retries} retries"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from last_error

    async def list_tools(self) -> list[Tool]:
        """List all available tools from the MCP implementation.

        Returns an empty list (rather than raising) on MCP protocol errors.
        """
        # Ensure we're connected
        await self._ensure_connected()

        logger.debug("Listing tools")
        try:
            result = await self.client_session.list_tools()
            return result.tools
        except McpError as e:
            logger.error(f"Error listing tools: {e}")
            return []

    async def list_resources(self) -> list[Resource]:
        """List all available resources from the MCP implementation.

        Returns an empty list (rather than raising) on MCP protocol errors.
        """
        # Ensure we're connected
        await self._ensure_connected()

        logger.debug("Listing resources")
        try:
            result = await self.client_session.list_resources()
            return result.resources
        except McpError as e:
            logger.error(f"Error listing resources: {e}")
            return []

    async def read_resource(self, uri: str) -> ReadResourceResult:
        """Read a resource by URI.

        Goes through _ensure_connected() like the other accessors, so a
        dropped connection is transparently re-established when
        auto_reconnect is enabled (previously this method only checked for
        a session and could not recover from a stale connection).
        """
        # Ensure we're connected (consistent with list_resources/get_prompt).
        await self._ensure_connected()

        logger.debug(f"Reading resource: {uri}")
        result = await self.client_session.read_resource(uri)
        return result

    async def list_prompts(self) -> list[Prompt]:
        """List all available prompts from the MCP implementation.

        Returns an empty list (rather than raising) on MCP protocol errors.
        """
        # Ensure we're connected
        await self._ensure_connected()

        logger.debug("Listing prompts")
        try:
            result = await self.client_session.list_prompts()
            return result.prompts
        except McpError as e:
            logger.error(f"Error listing prompts: {e}")
            return []

    async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult:
        """Get a prompt by name."""
        # Ensure we're connected
        await self._ensure_connected()

        logger.debug(f"Getting prompt: {name}")
        result = await self.client_session.get_prompt(name, arguments)
        return result

    async def request(self, method: str, params: dict[str, Any] | None = None) -> Any:
        """Send a raw request to the MCP implementation.

        NOTE(review): the mcp SDK's ClientSession.request typically expects a
        typed request object plus a result type — confirm this raw-dict form
        against the installed SDK version.
        """
        # Ensure we're connected
        await self._ensure_connected()

        logger.debug(f"Sending request: {method} with params: {params}")
        return await self.client_session.request({"method": method, "params": params or {}})

    async def invoke(self, name: str, params: dict[str, Any]) -> Any:
        """Dispatch an invocation: plain names are tool calls, dunder-prefixed
        names select the resource/prompt helpers.

        Raises:
            ValueError: If a dunder-prefixed name is not recognized.
        """
        await self._ensure_connected()

        if not name.startswith("__"):
            return await self.call_tool(name, params)

        if name == "__read_resource__":
            return await self.read_resource(params["uri"])
        if name == "__list_prompts__":
            return await self.list_prompts()
        if name == "__get_prompt__":
            return await self.get_prompt(params["name"], params.get("args"))

        raise ValueError(f"Unsupported MCP invoke name: {name}")
|
openspace/grounding/backends/mcp/transport/connectors/http.py
ADDED
|
@@ -0,0 +1,705 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HTTP connector for MCP implementations.
|
| 3 |
+
|
| 4 |
+
This module provides a connector for communicating with MCP implementations
|
| 5 |
+
through HTTP APIs with SSE, Streamable HTTP, or simple JSON-RPC for transport.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import anyio
|
| 10 |
+
import httpx
|
| 11 |
+
from typing import Any, Dict, List
|
| 12 |
+
from mcp import ClientSession
|
| 13 |
+
from mcp.types import (
|
| 14 |
+
CallToolResult,
|
| 15 |
+
TextContent,
|
| 16 |
+
ImageContent,
|
| 17 |
+
EmbeddedResource,
|
| 18 |
+
Tool,
|
| 19 |
+
Resource,
|
| 20 |
+
Prompt,
|
| 21 |
+
GetPromptResult,
|
| 22 |
+
ReadResourceResult,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from openspace.utils.logging import Logger
|
| 26 |
+
from openspace.grounding.core.transport.task_managers.base import BaseConnectionManager
|
| 27 |
+
from openspace.grounding.backends.mcp.transport.task_managers import SseConnectionManager, StreamableHttpConnectionManager
|
| 28 |
+
from openspace.grounding.backends.mcp.transport.connectors.base import MCPBaseConnector, DEFAULT_TOOL_CALL_MAX_RETRIES, DEFAULT_TOOL_CALL_RETRY_DELAY
|
| 29 |
+
|
| 30 |
+
logger = Logger.get_logger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class HttpConnector(MCPBaseConnector):
|
| 34 |
+
"""Connector for MCP implementations using HTTP transport.
|
| 35 |
+
|
| 36 |
+
This connector uses HTTP/SSE or streamable HTTP to communicate with remote MCP implementations,
|
| 37 |
+
using a connection manager to handle the proper lifecycle management.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
base_url: str,
|
| 43 |
+
auth_token: str | None = None,
|
| 44 |
+
headers: dict[str, str] | None = None,
|
| 45 |
+
timeout: float = 5,
|
| 46 |
+
sse_read_timeout: float = 60 * 5,
|
| 47 |
+
tool_call_max_retries: int = DEFAULT_TOOL_CALL_MAX_RETRIES,
|
| 48 |
+
tool_call_retry_delay: float = DEFAULT_TOOL_CALL_RETRY_DELAY,
|
| 49 |
+
):
|
| 50 |
+
"""Initialize a new HTTP connector.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
base_url: The base URL of the MCP HTTP API.
|
| 54 |
+
auth_token: Optional authentication token.
|
| 55 |
+
headers: Optional additional headers.
|
| 56 |
+
timeout: Timeout for HTTP operations in seconds.
|
| 57 |
+
sse_read_timeout: Timeout for SSE read operations in seconds.
|
| 58 |
+
tool_call_max_retries: Maximum number of retries for tool calls (default: 3)
|
| 59 |
+
tool_call_retry_delay: Initial delay between retries in seconds (default: 1.0)
|
| 60 |
+
"""
|
| 61 |
+
self.base_url = base_url.rstrip("/")
|
| 62 |
+
self.auth_token = auth_token
|
| 63 |
+
self.headers = headers or {}
|
| 64 |
+
if auth_token:
|
| 65 |
+
self.headers["Authorization"] = f"Bearer {auth_token}"
|
| 66 |
+
self.timeout = timeout
|
| 67 |
+
self.sse_read_timeout = sse_read_timeout
|
| 68 |
+
|
| 69 |
+
# JSON-RPC HTTP mode fields
|
| 70 |
+
self._use_jsonrpc = False
|
| 71 |
+
self._jsonrpc_client: httpx.AsyncClient | None = None
|
| 72 |
+
self._jsonrpc_request_id = 0
|
| 73 |
+
|
| 74 |
+
# Create a placeholder connection manager (will be set up later in connect())
|
| 75 |
+
# We use a placeholder here because the actual transport type (SSE vs Streamable HTTP)
|
| 76 |
+
# can only be determined at runtime through server negotiation as per MCP specification
|
| 77 |
+
from openspace.grounding.core.transport.task_managers import PlaceholderConnectionManager
|
| 78 |
+
connection_manager = PlaceholderConnectionManager()
|
| 79 |
+
super().__init__(
|
| 80 |
+
connection_manager,
|
| 81 |
+
tool_call_max_retries=tool_call_max_retries,
|
| 82 |
+
tool_call_retry_delay=tool_call_retry_delay,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
async def connect(self) -> None:
|
| 86 |
+
"""Create the underlying session/connection.
|
| 87 |
+
|
| 88 |
+
For JSON-RPC mode, we don't use a connection manager.
|
| 89 |
+
"""
|
| 90 |
+
if self._connected:
|
| 91 |
+
return
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
# Hook: before connection - this sets up transport type
|
| 95 |
+
await self._before_connect()
|
| 96 |
+
|
| 97 |
+
if self._use_jsonrpc:
|
| 98 |
+
# JSON-RPC mode doesn't use connection manager
|
| 99 |
+
# Just call _after_connect to set up the HTTP client
|
| 100 |
+
await self._after_connect()
|
| 101 |
+
self._connected = True
|
| 102 |
+
else:
|
| 103 |
+
# Use normal connection flow with connection manager
|
| 104 |
+
# If _before_connect() already established a connection, reuse it
|
| 105 |
+
if self._connection is None:
|
| 106 |
+
self._connection = await self._connection_manager.start()
|
| 107 |
+
await self._after_connect()
|
| 108 |
+
self._connected = True
|
| 109 |
+
except Exception:
|
| 110 |
+
await self._cleanup_on_connect_failure()
|
| 111 |
+
raise
|
| 112 |
+
|
| 113 |
+
async def disconnect(self) -> None:
|
| 114 |
+
"""Close the session/connection and reset state."""
|
| 115 |
+
if not self._connected:
|
| 116 |
+
return
|
| 117 |
+
|
| 118 |
+
# Hook: before disconnection
|
| 119 |
+
await self._before_disconnect()
|
| 120 |
+
|
| 121 |
+
if not self._use_jsonrpc:
|
| 122 |
+
# Stop the connection manager only for non-JSON-RPC modes
|
| 123 |
+
if self._connection_manager:
|
| 124 |
+
await self._connection_manager.stop()
|
| 125 |
+
self._connection = None
|
| 126 |
+
|
| 127 |
+
# Hook: after disconnection
|
| 128 |
+
await self._after_disconnect()
|
| 129 |
+
|
| 130 |
+
self._connected = False
|
| 131 |
+
|
| 132 |
+
async def _before_connect(self) -> None:
|
| 133 |
+
"""Negotiate transport type and set up the appropriate connection manager.
|
| 134 |
+
|
| 135 |
+
Tries transports in order:
|
| 136 |
+
1. Streamable HTTP (new MCP transport)
|
| 137 |
+
2. SSE (legacy MCP transport)
|
| 138 |
+
3. Simple JSON-RPC HTTP (for custom servers)
|
| 139 |
+
|
| 140 |
+
This implements backwards compatibility per MCP specification.
|
| 141 |
+
"""
|
| 142 |
+
self.transport_type = None
|
| 143 |
+
self._use_jsonrpc = False
|
| 144 |
+
connection_manager = None
|
| 145 |
+
streamable_error = None
|
| 146 |
+
sse_error = None
|
| 147 |
+
|
| 148 |
+
# First, try the new streamable HTTP transport
|
| 149 |
+
try:
|
| 150 |
+
logger.debug(f"Attempting streamable HTTP connection to: {self.base_url}")
|
| 151 |
+
connection_manager = StreamableHttpConnectionManager(
|
| 152 |
+
self.base_url, self.headers, self.timeout, self.sse_read_timeout
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# Test the connection by starting it with built-in timeout
|
| 156 |
+
read_stream, write_stream = await connection_manager.start(timeout=self.timeout)
|
| 157 |
+
|
| 158 |
+
# Create and verify ClientSession
|
| 159 |
+
test_client = ClientSession(read_stream, write_stream, sampling_callback=None)
|
| 160 |
+
|
| 161 |
+
# Add timeout to __aenter__ - use asyncio.wait_for instead of anyio.fail_after
|
| 162 |
+
# to avoid cancel scope conflicts with background tasks
|
| 163 |
+
try:
|
| 164 |
+
await asyncio.wait_for(test_client.__aenter__(), timeout=self.timeout)
|
| 165 |
+
except asyncio.TimeoutError:
|
| 166 |
+
raise TimeoutError(f"ClientSession enter timed out after {self.timeout}s")
|
| 167 |
+
|
| 168 |
+
try:
|
| 169 |
+
# Add timeout to initialize() using asyncio.wait_for to prevent hanging
|
| 170 |
+
try:
|
| 171 |
+
await asyncio.wait_for(test_client.initialize(), timeout=self.timeout)
|
| 172 |
+
except asyncio.TimeoutError:
|
| 173 |
+
raise TimeoutError(f"initialize() timed out after {self.timeout}s")
|
| 174 |
+
|
| 175 |
+
try:
|
| 176 |
+
await asyncio.wait_for(test_client.list_tools(), timeout=self.timeout)
|
| 177 |
+
except asyncio.TimeoutError:
|
| 178 |
+
raise TimeoutError(f"list_tools() timed out after {self.timeout}s")
|
| 179 |
+
|
| 180 |
+
# SUCCESS! Keep the client session (don't close it, closing destroys the streams)
|
| 181 |
+
# Store it directly as the client_session for later use
|
| 182 |
+
self.transport_type = "streamable HTTP"
|
| 183 |
+
self._connection_manager = connection_manager
|
| 184 |
+
self._connection = connection_manager.get_streams()
|
| 185 |
+
self.client_session = test_client # Reuse the working session
|
| 186 |
+
logger.debug("Streamable HTTP transport selected")
|
| 187 |
+
return
|
| 188 |
+
except TimeoutError:
|
| 189 |
+
try:
|
| 190 |
+
await asyncio.wait_for(test_client.__aexit__(None, None, None), timeout=2)
|
| 191 |
+
except (asyncio.TimeoutError, Exception):
|
| 192 |
+
pass
|
| 193 |
+
raise
|
| 194 |
+
except Exception as init_error:
|
| 195 |
+
# Clean up the test client only on error
|
| 196 |
+
try:
|
| 197 |
+
await asyncio.wait_for(test_client.__aexit__(None, None, None), timeout=2)
|
| 198 |
+
except (asyncio.TimeoutError, Exception):
|
| 199 |
+
pass
|
| 200 |
+
raise init_error
|
| 201 |
+
|
| 202 |
+
except Exception as e:
|
| 203 |
+
streamable_error = e
|
| 204 |
+
logger.debug(f"Streamable HTTP failed: {e}")
|
| 205 |
+
|
| 206 |
+
# Clean up the failed connection manager
|
| 207 |
+
if connection_manager:
|
| 208 |
+
try:
|
| 209 |
+
await asyncio.wait_for(connection_manager.stop(), timeout=2)
|
| 210 |
+
except (asyncio.TimeoutError, Exception):
|
| 211 |
+
pass
|
| 212 |
+
|
| 213 |
+
# Try SSE fallback
|
| 214 |
+
try:
|
| 215 |
+
logger.debug(f"Attempting SSE fallback connection to: {self.base_url}")
|
| 216 |
+
connection_manager = SseConnectionManager(
|
| 217 |
+
self.base_url, self.headers, self.timeout, self.sse_read_timeout
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
# Test the connection by starting it with built-in timeout
|
| 221 |
+
read_stream, write_stream = await connection_manager.start(timeout=self.timeout)
|
| 222 |
+
|
| 223 |
+
# Create and verify ClientSession
|
| 224 |
+
test_client = ClientSession(read_stream, write_stream, sampling_callback=None)
|
| 225 |
+
|
| 226 |
+
# Add timeout to __aenter__ - use asyncio.wait_for instead of anyio.fail_after
|
| 227 |
+
# to avoid cancel scope conflicts with background tasks
|
| 228 |
+
try:
|
| 229 |
+
await asyncio.wait_for(test_client.__aenter__(), timeout=self.timeout)
|
| 230 |
+
except asyncio.TimeoutError:
|
| 231 |
+
raise TimeoutError(f"ClientSession enter timed out after {self.timeout}s")
|
| 232 |
+
|
| 233 |
+
try:
|
| 234 |
+
try:
|
| 235 |
+
await asyncio.wait_for(test_client.initialize(), timeout=self.timeout)
|
| 236 |
+
except asyncio.TimeoutError:
|
| 237 |
+
raise TimeoutError(f"initialize() timed out after {self.timeout}s")
|
| 238 |
+
|
| 239 |
+
try:
|
| 240 |
+
await asyncio.wait_for(test_client.list_tools(), timeout=self.timeout)
|
| 241 |
+
except asyncio.TimeoutError:
|
| 242 |
+
raise TimeoutError(f"list_tools() timed out after {self.timeout}s")
|
| 243 |
+
|
| 244 |
+
# SUCCESS! Keep the client session (don't close it, closing destroys the streams)
|
| 245 |
+
# Store it directly as the client_session for later use
|
| 246 |
+
self.transport_type = "SSE"
|
| 247 |
+
self._connection_manager = connection_manager
|
| 248 |
+
self._connection = connection_manager.get_streams()
|
| 249 |
+
self.client_session = test_client # Reuse the working session
|
| 250 |
+
logger.debug("SSE transport selected")
|
| 251 |
+
return
|
| 252 |
+
except TimeoutError:
|
| 253 |
+
try:
|
| 254 |
+
await asyncio.wait_for(test_client.__aexit__(None, None, None), timeout=2)
|
| 255 |
+
except (asyncio.TimeoutError, Exception):
|
| 256 |
+
pass
|
| 257 |
+
raise
|
| 258 |
+
except Exception as init_error:
|
| 259 |
+
# Clean up the test client only on error
|
| 260 |
+
try:
|
| 261 |
+
await asyncio.wait_for(test_client.__aexit__(None, None, None), timeout=2)
|
| 262 |
+
except (asyncio.TimeoutError, Exception):
|
| 263 |
+
pass
|
| 264 |
+
raise init_error
|
| 265 |
+
|
| 266 |
+
except Exception as e:
|
| 267 |
+
sse_error = e
|
| 268 |
+
logger.debug(f"SSE failed: {e}")
|
| 269 |
+
|
| 270 |
+
# Clean up the failed connection manager
|
| 271 |
+
if connection_manager:
|
| 272 |
+
try:
|
| 273 |
+
await asyncio.wait_for(connection_manager.stop(), timeout=2)
|
| 274 |
+
except (asyncio.TimeoutError, Exception):
|
| 275 |
+
pass
|
| 276 |
+
|
| 277 |
+
# Both MCP transports failed, try simple JSON-RPC HTTP as last resort
|
| 278 |
+
# This is useful for custom MCP servers that don't implement proper MCP transports
|
| 279 |
+
logger.debug(f"Attempting JSON-RPC HTTP fallback to: {self.base_url}")
|
| 280 |
+
try:
|
| 281 |
+
# Test JSON-RPC connection
|
| 282 |
+
await self._try_jsonrpc_connection()
|
| 283 |
+
|
| 284 |
+
self.transport_type = "JSON-RPC HTTP"
|
| 285 |
+
self._use_jsonrpc = True
|
| 286 |
+
logger.info(f"JSON-RPC HTTP transport selected for: {self.base_url}")
|
| 287 |
+
return
|
| 288 |
+
|
| 289 |
+
except Exception as jsonrpc_error:
|
| 290 |
+
# All transports failed
|
| 291 |
+
logger.error(
|
| 292 |
+
f"All transport methods failed for {self.base_url}. "
|
| 293 |
+
f"Streamable HTTP: {streamable_error}, SSE: {sse_error}, JSON-RPC: {jsonrpc_error}"
|
| 294 |
+
)
|
| 295 |
+
# Raise the most relevant error - prefer the original streamable error
|
| 296 |
+
raise streamable_error or sse_error or jsonrpc_error
|
| 297 |
+
|
| 298 |
+
async def _try_jsonrpc_connection(self) -> None:
    """Probe the server for plain JSON-RPC-over-HTTP support.

    Sends a single MCP ``initialize`` request to ``self.base_url``; any HTTP
    failure or JSON-RPC error response raises, signalling that this transport
    is unavailable.
    """
    request_headers = {**self.headers, "Content-Type": "application/json"}

    init_payload = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2024-11-05",
            "capabilities": {},
            "clientInfo": {"name": "OpenSpace", "version": "1.0.0"},
        }
    }

    async with httpx.AsyncClient(timeout=httpx.Timeout(self.timeout), headers=request_headers) as probe:
        reply = await probe.post(self.base_url, json=init_payload)
        reply.raise_for_status()

        body = reply.json()

        # A well-formed JSON-RPC failure still comes back as HTTP 200.
        if "error" in body:
            error = body["error"]
            raise RuntimeError(f"JSON-RPC error: {error.get('message', str(error))}")

        logger.debug(f"JSON-RPC test succeeded: {body.get('result', {})}")
|
| 326 |
+
|
| 327 |
+
async def _after_connect(self) -> None:
    """Create ClientSession (or set up JSON-RPC client) and log success."""
    if not self._use_jsonrpc:
        # Only build a ClientSession when _before_connect() did not already
        # leave one behind.
        if self.client_session is None:
            await super()._after_connect()
        else:
            logger.debug("Reusing ClientSession from _before_connect()")
    else:
        # JSON-RPC fallback mode: a plain httpx client is all we need.
        self._jsonrpc_client = httpx.AsyncClient(
            timeout=httpx.Timeout(self.timeout),
            headers={**self.headers, "Content-Type": "application/json"},
        )
        logger.debug(f"JSON-RPC HTTP client set up for: {self.base_url}")

    logger.debug(f"Successfully connected to MCP implementation via {self.transport_type}: {self.base_url}")
|
| 345 |
+
|
| 346 |
+
async def _before_disconnect(self) -> None:
    """Clean up resources before disconnection."""
    # Close the JSON-RPC HTTP client first, if one was created.
    client = self._jsonrpc_client
    if client:
        try:
            await client.aclose()
        except Exception as exc:
            logger.warning(f"Error closing JSON-RPC client: {exc}")
        finally:
            # Drop the reference even when aclose() failed.
            self._jsonrpc_client = None

    # Delegate the MCP-level teardown to the base class.
    await super()._before_disconnect()
|
| 359 |
+
|
| 360 |
+
@property
def public_identifier(self) -> dict[str, str]:
    """Identifying metadata for this connector.

    Returns:
        A mapping with the negotiated transport type and the target base URL.
        (The return annotation previously claimed ``str`` although a dict has
        always been returned.)
    """
    return {"type": self.transport_type, "base_url": self.base_url}
|
| 364 |
+
|
| 365 |
+
# =====================
|
| 366 |
+
# JSON-RPC HTTP Methods
|
| 367 |
+
# =====================
|
| 368 |
+
|
| 369 |
+
def _next_jsonrpc_id(self) -> int:
|
| 370 |
+
"""Get next JSON-RPC request ID."""
|
| 371 |
+
self._jsonrpc_request_id += 1
|
| 372 |
+
return self._jsonrpc_request_id
|
| 373 |
+
|
| 374 |
+
async def _send_jsonrpc_request(
|
| 375 |
+
self,
|
| 376 |
+
method: str,
|
| 377 |
+
params: Dict[str, Any] = None,
|
| 378 |
+
max_retries: int = 3,
|
| 379 |
+
retry_delay: float = 1.0,
|
| 380 |
+
) -> Any:
|
| 381 |
+
"""Send a JSON-RPC request and return the result.
|
| 382 |
+
|
| 383 |
+
Args:
|
| 384 |
+
method: The JSON-RPC method name (e.g., "tools/list", "tools/call")
|
| 385 |
+
params: The method parameters
|
| 386 |
+
max_retries: Maximum number of retries for transient errors (400, 503, etc.)
|
| 387 |
+
retry_delay: Initial delay between retries (doubles each retry)
|
| 388 |
+
|
| 389 |
+
Returns:
|
| 390 |
+
The result field from the JSON-RPC response
|
| 391 |
+
"""
|
| 392 |
+
if not self._jsonrpc_client:
|
| 393 |
+
raise RuntimeError("JSON-RPC client not initialized")
|
| 394 |
+
|
| 395 |
+
last_error = None
|
| 396 |
+
|
| 397 |
+
for attempt in range(max_retries):
|
| 398 |
+
request_id = self._next_jsonrpc_id()
|
| 399 |
+
payload = {
|
| 400 |
+
"jsonrpc": "2.0",
|
| 401 |
+
"id": request_id,
|
| 402 |
+
"method": method,
|
| 403 |
+
"params": params or {},
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
logger.debug(f"Sending JSON-RPC request: {method} (id={request_id}, attempt {attempt + 1}/{max_retries})")
|
| 407 |
+
|
| 408 |
+
try:
|
| 409 |
+
response = await self._jsonrpc_client.post(self.base_url, json=payload)
|
| 410 |
+
response.raise_for_status()
|
| 411 |
+
|
| 412 |
+
data = response.json()
|
| 413 |
+
|
| 414 |
+
if "error" in data:
|
| 415 |
+
error = data["error"]
|
| 416 |
+
error_msg = error.get("message", str(error))
|
| 417 |
+
raise RuntimeError(f"JSON-RPC error: {error_msg}")
|
| 418 |
+
|
| 419 |
+
return data.get("result", {})
|
| 420 |
+
|
| 421 |
+
except httpx.HTTPStatusError as e:
|
| 422 |
+
last_error = e
|
| 423 |
+
status_code = e.response.status_code
|
| 424 |
+
|
| 425 |
+
# Retry on 400 (Bad Request) and 5xx errors
|
| 426 |
+
# 400 can happen when MCP server is temporarily not ready
|
| 427 |
+
if status_code in (400, 500, 502, 503, 504) and attempt < max_retries - 1:
|
| 428 |
+
delay = retry_delay * (2 ** attempt)
|
| 429 |
+
logger.warning(
|
| 430 |
+
f"HTTP {status_code} error on {method}, retrying in {delay:.1f}s "
|
| 431 |
+
f"(attempt {attempt + 1}/{max_retries})"
|
| 432 |
+
)
|
| 433 |
+
await asyncio.sleep(delay)
|
| 434 |
+
continue
|
| 435 |
+
|
| 436 |
+
raise RuntimeError(f"HTTP error: {status_code}") from e
|
| 437 |
+
|
| 438 |
+
except httpx.RequestError as e:
|
| 439 |
+
last_error = e
|
| 440 |
+
# Retry on connection errors
|
| 441 |
+
if attempt < max_retries - 1:
|
| 442 |
+
delay = retry_delay * (2 ** attempt)
|
| 443 |
+
logger.warning(
|
| 444 |
+
f"Request error on {method}: {e}, retrying in {delay:.1f}s "
|
| 445 |
+
f"(attempt {attempt + 1}/{max_retries})"
|
| 446 |
+
)
|
| 447 |
+
await asyncio.sleep(delay)
|
| 448 |
+
continue
|
| 449 |
+
|
| 450 |
+
raise RuntimeError(f"Request error: {e}") from e
|
| 451 |
+
|
| 452 |
+
# Should not reach here, but just in case
|
| 453 |
+
raise RuntimeError(f"Max retries exceeded for {method}") from last_error
|
| 454 |
+
|
| 455 |
+
def _parse_tools_from_json(self, tools_data: List[Dict]) -> List[Tool]:
    """Parse raw tool dictionaries into Tool objects, skipping malformed entries."""
    parsed: List[Tool] = []
    for entry in tools_data:
        try:
            parsed.append(
                Tool(
                    name=entry.get("name", ""),
                    description=entry.get("description", ""),
                    inputSchema=entry.get("inputSchema", {}),
                )
            )
        except Exception as exc:
            # A single bad entry must not abort the whole listing.
            logger.warning(f"Failed to parse tool: {exc}")
    return parsed
|
| 469 |
+
|
| 470 |
+
def _parse_resources_from_json(self, resources_data: List[Dict]) -> List[Resource]:
    """Parse raw resource dictionaries into Resource objects, skipping malformed entries."""
    parsed: List[Resource] = []
    for entry in resources_data:
        try:
            parsed.append(
                Resource(
                    uri=entry.get("uri", ""),
                    name=entry.get("name", ""),
                    description=entry.get("description"),
                    mimeType=entry.get("mimeType"),
                )
            )
        except Exception as exc:
            # A single bad entry must not abort the whole listing.
            logger.warning(f"Failed to parse resource: {exc}")
    return parsed
|
| 485 |
+
|
| 486 |
+
def _parse_prompts_from_json(self, prompts_data: List[Dict]) -> List[Prompt]:
    """Parse raw prompt dictionaries into Prompt objects, skipping malformed entries."""
    parsed: List[Prompt] = []
    for entry in prompts_data:
        try:
            parsed.append(
                Prompt(
                    name=entry.get("name", ""),
                    description=entry.get("description"),
                    arguments=entry.get("arguments"),
                )
            )
        except Exception as exc:
            # A single bad entry must not abort the whole listing.
            logger.warning(f"Failed to parse prompt: {exc}")
    return parsed
|
| 500 |
+
|
| 501 |
+
# =====================
|
| 502 |
+
# Override MCP Methods for JSON-RPC Support
|
| 503 |
+
# =====================
|
| 504 |
+
|
| 505 |
+
async def initialize(self) -> Dict[str, Any]:
    """Initialize the MCP session and populate cached tools/resources/prompts.

    In JSON-RPC fallback mode this performs the ``initialize`` handshake and
    then eagerly lists tools, resources, and prompts. Otherwise the base-class
    implementation handles everything.

    Returns:
        The ``result`` payload of the ``initialize`` response.
    """
    if not self._use_jsonrpc:
        return await super().initialize()

    # JSON-RPC mode
    logger.debug("Initializing JSON-RPC HTTP MCP session")

    result = await self._send_jsonrpc_request("initialize", {
        "protocolVersion": "2024-11-05",
        "capabilities": {},
        "clientInfo": {"name": "OpenSpace", "version": "1.0.0"},
    })

    capabilities = result.get("capabilities", {})

    # tools/list is attempted unconditionally: some servers don't advertise
    # their capabilities correctly, so the capability flag is not trusted.
    # (Collapses what used to be two identical if/else branches.)
    try:
        tools_result = await self._send_jsonrpc_request("tools/list", {})
        self._tools = self._parse_tools_from_json(tools_result.get("tools", []))
    except Exception:
        self._tools = []

    # Resources are only queried when the server advertises them.
    self._resources = []
    if capabilities.get("resources"):
        try:
            resources_result = await self._send_jsonrpc_request("resources/list", {})
            self._resources = self._parse_resources_from_json(resources_result.get("resources", []))
        except Exception:
            self._resources = []

    # Prompts are only queried when the server advertises them.
    self._prompts = []
    if capabilities.get("prompts"):
        try:
            prompts_result = await self._send_jsonrpc_request("prompts/list", {})
            self._prompts = self._parse_prompts_from_json(prompts_result.get("prompts", []))
        except Exception:
            self._prompts = []

    logger.info(
        f"JSON-RPC HTTP MCP session initialized with {len(self._tools)} tools, "
        f"{len(self._resources)} resources, {len(self._prompts)} prompts"
    )

    return result
|
| 562 |
+
|
| 563 |
+
@property
def is_connected(self) -> bool:
    """Whether the connector currently holds a live connection."""
    if not self._use_jsonrpc:
        return super().is_connected
    # JSON-RPC mode: connected only while the HTTP client exists.
    return self._connected and self._jsonrpc_client is not None
|
| 569 |
+
|
| 570 |
+
async def _ensure_connected(self) -> None:
|
| 571 |
+
"""Ensure the connector is connected."""
|
| 572 |
+
if self._use_jsonrpc:
|
| 573 |
+
if not self._connected or not self._jsonrpc_client:
|
| 574 |
+
raise RuntimeError("JSON-RPC HTTP connector is not connected")
|
| 575 |
+
else:
|
| 576 |
+
await super()._ensure_connected()
|
| 577 |
+
|
| 578 |
+
async def list_tools(self) -> List[Tool]:
|
| 579 |
+
"""List all available tools."""
|
| 580 |
+
if not self._use_jsonrpc:
|
| 581 |
+
return await super().list_tools()
|
| 582 |
+
|
| 583 |
+
await self._ensure_connected()
|
| 584 |
+
try:
|
| 585 |
+
tools_result = await self._send_jsonrpc_request("tools/list", {})
|
| 586 |
+
self._tools = self._parse_tools_from_json(tools_result.get("tools", []))
|
| 587 |
+
return self._tools
|
| 588 |
+
except Exception as e:
|
| 589 |
+
logger.error(f"Error listing tools: {e}")
|
| 590 |
+
return []
|
| 591 |
+
|
| 592 |
+
async def call_tool(self, name: str, arguments: Dict[str, Any]) -> CallToolResult:
    """Call an MCP tool and normalize its response into a CallToolResult."""
    if not self._use_jsonrpc:
        return await super().call_tool(name, arguments)

    await self._ensure_connected()
    logger.debug(f"Calling tool '{name}' with arguments: {arguments}")

    raw = await self._send_jsonrpc_request("tools/call", {
        "name": name,
        "arguments": arguments,
    })

    def _convert(item: Dict[str, Any]):
        # Map one raw content item to its typed equivalent; unknown types
        # are dropped (returns None), matching the MCP content model.
        kind = item.get("type", "text")
        if kind == "text":
            return TextContent(type="text", text=item.get("text", ""))
        if kind == "image":
            return ImageContent(
                type="image",
                data=item.get("data", ""),
                mimeType=item.get("mimeType", "image/png"),
            )
        if kind == "resource":
            return EmbeddedResource(
                type="resource",
                resource=item.get("resource", {}),
            )
        return None

    blocks = [b for b in (_convert(i) for i in raw.get("content", [])) if b is not None]

    # Fall back to stringifying the whole result when no typed content came back.
    if not blocks and raw:
        blocks.append(TextContent(type="text", text=str(raw)))

    return CallToolResult(
        content=blocks,
        isError=raw.get("isError", False),
    )
|
| 630 |
+
|
| 631 |
+
async def list_resources(self) -> List[Resource]:
|
| 632 |
+
"""List all available resources."""
|
| 633 |
+
if not self._use_jsonrpc:
|
| 634 |
+
return await super().list_resources()
|
| 635 |
+
|
| 636 |
+
await self._ensure_connected()
|
| 637 |
+
try:
|
| 638 |
+
resources_result = await self._send_jsonrpc_request("resources/list", {})
|
| 639 |
+
self._resources = self._parse_resources_from_json(resources_result.get("resources", []))
|
| 640 |
+
return self._resources
|
| 641 |
+
except Exception as e:
|
| 642 |
+
logger.error(f"Error listing resources: {e}")
|
| 643 |
+
return []
|
| 644 |
+
|
| 645 |
+
async def read_resource(self, uri: str) -> ReadResourceResult:
    """Read a resource by URI."""
    if self._use_jsonrpc:
        await self._ensure_connected()
        payload = await self._send_jsonrpc_request("resources/read", {"uri": uri})
        return ReadResourceResult(**payload)
    return await super().read_resource(uri)
|
| 653 |
+
|
| 654 |
+
async def list_prompts(self) -> List[Prompt]:
|
| 655 |
+
"""List all available prompts."""
|
| 656 |
+
if not self._use_jsonrpc:
|
| 657 |
+
return await super().list_prompts()
|
| 658 |
+
|
| 659 |
+
await self._ensure_connected()
|
| 660 |
+
try:
|
| 661 |
+
prompts_result = await self._send_jsonrpc_request("prompts/list", {})
|
| 662 |
+
self._prompts = self._parse_prompts_from_json(prompts_result.get("prompts", []))
|
| 663 |
+
return self._prompts
|
| 664 |
+
except Exception as e:
|
| 665 |
+
logger.error(f"Error listing prompts: {e}")
|
| 666 |
+
return []
|
| 667 |
+
|
| 668 |
+
async def get_prompt(self, name: str, arguments: Dict[str, Any] | None = None) -> GetPromptResult:
    """Get a prompt by name."""
    if self._use_jsonrpc:
        await self._ensure_connected()
        payload = await self._send_jsonrpc_request("prompts/get", {
            "name": name,
            "arguments": arguments or {},
        })
        return GetPromptResult(**payload)
    return await super().get_prompt(name, arguments)
|
| 679 |
+
|
| 680 |
+
async def request(self, method: str, params: Dict[str, Any] | None = None) -> Any:
|
| 681 |
+
"""Send a raw request to the MCP implementation."""
|
| 682 |
+
if not self._use_jsonrpc:
|
| 683 |
+
return await super().request(method, params)
|
| 684 |
+
|
| 685 |
+
await self._ensure_connected()
|
| 686 |
+
return await self._send_jsonrpc_request(method, params or {})
|
| 687 |
+
|
| 688 |
+
async def invoke(self, name: str, params: Dict[str, Any]) -> Any:
    """Invoke a tool, or one of the ``__special__`` pseudo-methods."""
    if not self._use_jsonrpc:
        return await super().invoke(name, params)

    await self._ensure_connected()

    # Names without the dunder prefix are plain tool invocations.
    if not name.startswith("__"):
        return await self.call_tool(name, params)

    # Dispatch table for the recognized pseudo-methods.
    special = {
        "__read_resource__": lambda: self.read_resource(params["uri"]),
        "__list_prompts__": lambda: self.list_prompts(),
        "__get_prompt__": lambda: self.get_prompt(params["name"], params.get("args")),
    }
    handler = special.get(name)
    if handler is None:
        raise ValueError(f"Unsupported MCP invoke name: {name}")
    return await handler()
|
openspace/grounding/backends/mcp/transport/connectors/sandbox.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sandbox connector for MCP implementations.
|
| 3 |
+
|
| 4 |
+
This module provides a connector for communicating with MCP implementations
|
| 5 |
+
that are executed inside a sandbox environment (supports any BaseSandbox implementation).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
import aiohttp
|
| 13 |
+
from mcp import ClientSession
|
| 14 |
+
|
| 15 |
+
from openspace.utils.logging import Logger
|
| 16 |
+
from openspace.grounding.backends.mcp.transport.task_managers import SseConnectionManager
|
| 17 |
+
from openspace.grounding.core.security import BaseSandbox
|
| 18 |
+
from openspace.grounding.backends.mcp.transport.connectors.base import MCPBaseConnector
|
| 19 |
+
|
| 20 |
+
logger = Logger.get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SandboxConnector(MCPBaseConnector):
    """Connector for MCP implementations running in a sandbox environment.

    This connector runs a user-defined stdio command within a sandbox environment
    through a BaseSandbox implementation (e.g., E2BSandbox), potentially wrapped
    by a utility like 'supergateway' to expose its stdio over HTTP/SSE.
    """

    def __init__(
        self,
        sandbox: BaseSandbox,
        command: str,
        args: list[str],
        env: dict[str, str] | None = None,
        supergateway_command: str = "npx -y supergateway",
        port: int = 3000,
        timeout: float = 5,
        sse_read_timeout: float = 60 * 5,
    ):
        """Initialize a new sandbox connector.

        Args:
            sandbox: A BaseSandbox implementation (e.g., E2BSandbox) to run commands in.
            command: The user's MCP server command to execute in the sandbox.
            args: Command line arguments for the user's MCP server command.
            env: Environment variables for the user's MCP server command.
            supergateway_command: Command to run supergateway (default: "npx -y supergateway").
            port: Port number for the sandbox server (default: 3000).
            timeout: Timeout for the sandbox process in seconds.
            sse_read_timeout: Timeout for the SSE connection in seconds.
        """
        # Store user command configuration
        self.user_command = command
        self.user_args = args or []
        self.user_env = env or {}
        self.port = port

        # The real SseConnectionManager needs the sandbox host, which is only
        # known after the sandbox starts, so install a placeholder until
        # _before_connect() runs.
        from openspace.grounding.core.transport.task_managers import PlaceholderConnectionManager

        connection_manager = PlaceholderConnectionManager()
        super().__init__(connection_manager)

        # Sandbox configuration
        self._sandbox = sandbox
        self.supergateway_cmd_parts = supergateway_command

        # Runtime state
        self.process = None
        self.client_session: ClientSession | None = None
        self.errlog = sys.stderr
        self.base_url: str | None = None
        self._connected = False
        self._connection_manager: SseConnectionManager | None = None

        # SSE connection parameters
        self.headers: dict[str, str] = {}
        self.timeout = timeout
        self.sse_read_timeout = sse_read_timeout

        # Captured process output (useful for debugging startup failures)
        self.stdout_lines: list[str] = []
        self.stderr_lines: list[str] = []
        self._server_ready = asyncio.Event()

    def _handle_stdout(self, data: str) -> None:
        """Collect and log stdout data from the sandbox process."""
        self.stdout_lines.append(data)
        # Fixed: logger.debug() takes no print()-style kwargs (end=, flush=);
        # passing them raised TypeError on every line of output.
        logger.debug(f"[SANDBOX STDOUT] {data}")

    def _handle_stderr(self, data: str) -> None:
        """Collect and log stderr data from the sandbox process."""
        self.stderr_lines.append(data)
        # Fixed: logger.debug() takes no print()-style kwargs (file=, end=, flush=).
        logger.debug(f"[SANDBOX STDERR] {data}")

    async def wait_for_server_response(self, base_url: str, timeout: int = 30) -> bool:
        """Wait for the server to respond to HTTP requests.

        Args:
            base_url: The base URL to check for server readiness
            timeout: Maximum time to wait in seconds

        Returns:
            True if server is responding, raises TimeoutError otherwise
        """
        logger.info(f"Waiting for server at {base_url} to respond...")
        sys.stdout.flush()

        start_time = time.time()
        ping_url = f"{base_url}/sse"

        # Poll until the server answers or the deadline passes
        while time.time() - start_time < timeout:
            try:
                async with aiohttp.ClientSession() as session:
                    try:
                        # First try the SSE endpoint
                        async with session.get(ping_url, timeout=2) as response:
                            if response.status == 200:
                                elapsed = time.time() - start_time
                                logger.info(f"Server is ready! SSE endpoint responded with 200 after {elapsed:.1f}s")
                                return True
                    except Exception:
                        # If sse endpoint doesn't work, try the base URL
                        async with session.get(base_url, timeout=2) as response:
                            if response.status < 500:  # Accept any non-server error
                                elapsed = time.time() - start_time
                                logger.info(
                                    f"Server is ready! Base URL responded with {response.status} after {elapsed:.1f}s"
                                )
                                return True
            except Exception:
                # Wait a bit before trying again
                await asyncio.sleep(0.5)
                continue

            # If we get here, the request failed
            await asyncio.sleep(0.5)

            # Log status every 5 seconds
            elapsed = time.time() - start_time
            if int(elapsed) % 5 == 0:
                logger.info(f"Still waiting for server to respond... ({elapsed:.1f}s elapsed)")
                sys.stdout.flush()

        # If we get here, we timed out
        raise TimeoutError(f"Timeout waiting for server to respond (waited {timeout} seconds)")

    async def _before_connect(self) -> None:
        """Set up the sandbox and prepare the connection manager."""
        logger.debug("Connecting to MCP implementation in sandbox")

        # Start the sandbox if not already active
        if not self._sandbox.is_active:
            logger.debug("Starting sandbox...")
            await self._sandbox.start()

        # Get the host for the sandbox.
        # NOTE(review): assumes the sandbox implementation provides get_host()
        # (true for E2BSandbox) — confirm for other BaseSandbox implementations.
        host = self._sandbox.get_host(self.port)
        self.base_url = f"https://{host}".rstrip("/")

        # Append command with args
        command = f"{self.user_command} {' '.join(self.user_args)}"

        # Wrap the user's stdio server with supergateway so it is reachable
        # over HTTP/SSE from outside the sandbox.
        full_command = (
            f"{self.supergateway_cmd_parts} --base-url {self.base_url} "
            f'--port {self.port} --cors --stdio "{command}"'
        )

        logger.debug(f"Full command: {full_command}")

        # Execute the command in the sandbox
        self.process = await self._sandbox.execute_safe(
            full_command,
            envs=self.user_env,
            timeout=1000 * 60 * 10,  # 10 minutes timeout
            background=True,
            on_stdout=self._handle_stdout,
            on_stderr=self._handle_stderr,
        )

        # Wait for the server to be ready
        await self.wait_for_server_response(self.base_url, timeout=30)
        logger.debug("Initializing connection manager...")

        # Create the SSE connection URL
        sse_url = f"{self.base_url}/sse"

        # Create and set up the connection manager
        self._connection_manager = SseConnectionManager(sse_url, self.headers, self.timeout, self.sse_read_timeout)

    async def _after_connect(self) -> None:
        """Create ClientSession and log success."""
        await super()._after_connect()
        logger.debug(f"Successfully connected to MCP implementation via HTTP/SSE in sandbox: {self.base_url}")

    async def _before_disconnect(self) -> None:
        """Clean up sandbox-specific resources before disconnection."""
        logger.debug("Cleaning up sandbox resources")

        # Stop the sandbox (which will clean up processes)
        if self._sandbox and self._sandbox.is_active:
            try:
                logger.debug("Stopping sandbox instance")
                await self._sandbox.stop()
                logger.debug("Sandbox instance stopped successfully")
            except Exception as e:
                logger.warning(f"Error stopping sandbox: {e}")

        self.process = None

        # Call the parent method to clean up MCP resources
        await super()._before_disconnect()

        # Clear any collected output
        self.stdout_lines = []
        self.stderr_lines = []
        self.base_url = None

    async def _cleanup_on_connect_failure(self) -> None:
        """Clean up sandbox resources on connection failure."""
        # Stop the sandbox if it was started
        if self._sandbox and self._sandbox.is_active:
            try:
                await self._sandbox.stop()
            except Exception as e:
                logger.warning(f"Error stopping sandbox during cleanup: {e}")

        self.process = None
        self.stdout_lines = []
        self.stderr_lines = []
        self.base_url = None

        # Call parent cleanup
        await super()._cleanup_on_connect_failure()

    @property
    def sandbox(self) -> BaseSandbox:
        """Get the underlying sandbox instance."""
        return self._sandbox

    @property
    def public_identifier(self) -> dict[str, str | list[str]]:
        """Identifying metadata for this connector.

        (Annotation fixed: a dict has always been returned, not ``str``.)
        """
        return {"type": "sandbox", "command": self.user_command, "args": self.user_args}
|
openspace/grounding/backends/mcp/transport/connectors/stdio.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
StdIO connector for MCP implementations.
|
| 3 |
+
|
| 4 |
+
This module provides a connector for communicating with MCP implementations
|
| 5 |
+
through the standard input/output streams.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
from mcp import ClientSession, StdioServerParameters
|
| 11 |
+
|
| 12 |
+
from openspace.utils.logging import Logger
|
| 13 |
+
from ..task_managers import StdioConnectionManager
|
| 14 |
+
from .base import MCPBaseConnector
|
| 15 |
+
|
| 16 |
+
logger = Logger.get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class StdioConnector(MCPBaseConnector):
    """Connector for MCP implementations using stdio transport.

    This connector uses the stdio transport to communicate with MCP implementations
    that are executed as child processes. It uses a connection manager to handle
    the proper lifecycle management of the stdio client.
    """

    def __init__(
        self,
        command: str = "npx",
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
        errlog=None,
    ):
        """Initialize a new stdio connector.

        Args:
            command: The command to execute.
            args: Optional command line arguments.
            env: Optional environment variables. The mapping is copied, so the
                caller's dict is never mutated.
            errlog: Stream to write error output to (defaults to filtered stderr).
                StdioConnectionManager will wrap this to filter harmless errors.
        """
        self.command = command
        self.args = args or []  # Ensure args is never None

        # Copy the caller's env: the defaults added below must not leak back
        # into the dict the caller passed in (previously mutated in place).
        self.env = dict(env) if env else {}
        # Encourage MCP servers to suppress non-JSON output:
        # many Node.js-based servers respect NODE_ENV=production,
        # and some respect MCP_SILENT.
        self.env.setdefault("NODE_ENV", "production")
        self.env.setdefault("MCP_SILENT", "true")

        self.errlog = errlog

        # Create server parameters and connection manager.
        # StdioConnectionManager will wrap errlog in FilteredStderrWrapper.
        server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env)
        connection_manager = StdioConnectionManager(server_params, self.errlog)
        super().__init__(connection_manager)

    async def _before_connect(self) -> None:
        """Log connection attempt."""
        logger.debug(f"Connecting to MCP implementation: {self.command}")

    async def _after_connect(self) -> None:
        """Create ClientSession and log success."""
        # Call parent's _after_connect to create the ClientSession
        await super()._after_connect()
        logger.debug(f"Successfully connected to MCP implementation: {self.command}")

    @property
    def public_identifier(self) -> dict[str, str]:
        """Identifying metadata: transport type plus the launched command line."""
        return {"type": "stdio", "command&args": f"{self.command} {' '.join(self.args)}"}
|