repl / src /core /env_server /mcp_environment.py
burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
81b02bf verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
MCP Environment base class for OpenEnv.
This module provides the MCPEnvironment base class that integrates FastMCP servers
with OpenEnv's Gym-style Environment interface. It handles MCP tool discovery
and invocation through the step() API, following RFC 003.
Key features:
- Automatic routing of ListToolsAction and CallToolAction to MCP server
- Reserved tool name validation (reset, step, state, close are protected)
- Timeout handling for tool calls
- Proper error categorization (tool not found, execution errors, timeouts)
- Mode-aware tool registration (production vs simulation)
- Code mode support via get_callables() and execute_code()
Usage:
    from fastmcp import FastMCP
    from openenv.core.env_server.mcp_environment import MCPEnvironment

    class MyMCPEnv(MCPEnvironment):
        def __init__(self):
            mcp = FastMCP("my-server")
            # Call the base initializer first: self.tool() records tools in
            # state that MCPEnvironment.__init__ creates.
            super().__init__(mcp)

            # Register mode-specific tools
            @self.tool(mode="production")
            def my_tool(arg: str) -> str:
                return f"Production: {arg}"

            @self.tool(mode="simulation")
            def my_tool(arg: str) -> str:
                return f"Simulation: {arg}"

        def reset(self, seed=None, episode_id=None, **kwargs):
            # Reset logic here
            ...

        def _step_impl(self, action, timeout_s=None, **kwargs):
            # Handle non-MCP actions (step() forwards timeout_s and **kwargs)
            ...

        @property
        def state(self):
            # Return current state
            ...
"""
import asyncio
import inspect
from abc import abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, Optional
from fastmcp import Client
from fastmcp.client.client import CallToolResult
from mcp.types import TextContent
from ..utils import run_async_safely
from .interfaces import Environment
from .mcp_types import (
CallToolAction,
CallToolObservation,
ListToolsAction,
ListToolsObservation,
RESERVED_TOOL_NAMES,
Tool,
ToolError,
ToolErrorType,
)
from .types import Action, Observation
# Default timeout, in seconds, applied to an MCP tool call when the caller does
# not supply an explicit ``timeout_s``.
MCP_TOOL_CALL_TIMEOUT = 30.0
# Mode names accepted by ``MCPEnvironment.tool(mode=...)``. ``None`` (tool is
# available in all modes) is also accepted there, but is checked separately
# and is deliberately not a member of this set.
VALID_MODES = {"production", "simulation"}
def get_server_tools(mcp_server: Any) -> Dict[str, Any]:
    """
    Collect the tools registered on a FastMCP server, keyed by tool name.

    Works with both FastMCP 2.x and 3.x by probing the server for whichever
    tool-listing method it exposes.

    Args:
        mcp_server: The FastMCP server instance to inspect.

    Returns:
        Dictionary mapping tool names to tool objects. Empty when the server
        offers neither ``get_tools`` nor ``list_tools``.
    """
    # FastMCP 2.x: get_tools() yields a ready-made {name: Tool} mapping.
    if hasattr(mcp_server, "get_tools"):
        by_name = run_async_safely(mcp_server.get_tools())
        if isinstance(by_name, dict):
            return by_name

    # Nothing left to fall back on.
    if not hasattr(mcp_server, "list_tools"):
        return {}

    # FastMCP 3.x: list_tools() yields Tool objects; index them by name.
    listed = run_async_safely(mcp_server.list_tools())
    by_name = {}
    for entry in listed:
        by_name[entry.name] = entry
    return by_name
class MCPEnvironment(Environment):
"""
Base class for environments that expose tools via MCP (Model Context Protocol).
MCPEnvironment bridges FastMCP servers with OpenEnv's Gym-style API, allowing
agents to discover and invoke MCP tools through the standard step() interface.
The class automatically handles:
- ListToolsAction: Returns available tools from the MCP server
- CallToolAction: Invokes a specific tool with arguments
All other actions are delegated to the abstract _step_impl() method,
which subclasses must implement.
Args:
mcp_server: A FastMCP server instance containing tool definitions.
The server's tools will be validated against reserved names.
transform: Optional transform to apply to observations (inherited from Environment).
Raises:
ValueError: If any tool in the MCP server uses a reserved name
(reset, step, state, close).
Example:
>>> from fastmcp import FastMCP
>>> mcp = FastMCP("calculator")
>>> @mcp.tool()
... def add(a: int, b: int) -> int:
... return a + b
>>> env = MyMCPEnvironment(mcp)
>>> obs = env.step(ListToolsAction())
>>> obs.tools[0].name
'add'
"""
def __init__(self, mcp_server: Any, transform: Optional[Any] = None) -> None:
    """
    Set up the environment around a FastMCP server.

    Args:
        mcp_server: A FastMCP server instance with tool definitions.
        transform: Optional transform to apply to observations; forwarded
            to the parent initializer.

    Raises:
        ValueError: If any tool uses a reserved name (reset, step, state, close).
    """
    super().__init__(transform=transform)

    # Reject reserved tool names before keeping any reference to the server.
    self._validate_tool_names(mcp_server)

    self.mcp_server = mcp_server
    self.mcp_client = Client(mcp_server)

    # Implementations registered through self.tool():
    #   {tool_name: {mode: func}}
    # where mode is "production", "simulation", or None (all modes).
    self._mode_tools: Dict[str, Dict[Optional[str], Callable]] = defaultdict(dict)

    # Listing metadata for mode-specific tools:
    #   {tool_name: {mode: schema}}
    self._mode_tool_schemas: Dict[str, Dict[Optional[str], Dict[str, Any]]] = (
        defaultdict(dict)
    )
@property
def supports_code_mode(self) -> bool:
    """
    Whether this environment offers code mode (``execute_code``).

    Returns:
        Always ``True`` for this class.
    """
    return True
def _get_server_tools(self, mcp_server: Any) -> Dict[str, Any]:
    """
    Method-form wrapper around the module-level ``get_server_tools``.

    Args:
        mcp_server: The FastMCP server instance to inspect.

    Returns:
        Dictionary mapping tool names to tool objects.
    """
    tools_by_name = get_server_tools(mcp_server)
    return tools_by_name
def get_callables(self) -> Dict[str, Callable]:
    """
    Build the name -> function table used by code mode.

    Tools are exposed as plain Python callables so agent-written code can
    invoke them directly instead of going through JSON-RPC. Tools registered
    for a specific mode are included only when that mode is the current one.

    Returns:
        Dictionary mapping tool names to callables.
    """
    active_mode = getattr(self, "_mode", None)

    # Server-registered tools: keep those that carry a callable ``fn``.
    server_tools = self._get_server_tools(self.mcp_server)
    exposed: Dict[str, Callable] = {
        name: entry.fn
        for name, entry in server_tools.items()
        if callable(getattr(entry, "fn", None))
    }

    # Tools tracked by this environment via self.tool().
    for name, by_mode in self._mode_tools.items():
        if None in by_mode:
            # All-mode tool: only fills a gap, never replaces a server entry.
            exposed.setdefault(name, by_mode[None])
        elif active_mode in by_mode:
            # Implementation specific to the current mode takes precedence.
            exposed[name] = by_mode[active_mode]

    return exposed
def execute_code(self, code: str) -> Observation:
    """
    Execute Python code with tools available as callables.

    This enables the CodeAct pattern where agents write Python code
    that calls tools directly as functions, avoiding JSON-RPC overhead.

    The code runs in a single namespace that is pre-populated with the
    callables from ``get_callables()``. Because globals and locals are the
    same dict, functions and comprehensions defined by the code can see
    names the code assigned at its top level.

    Args:
        code: Python code to execute. Tools are available as functions
            in the execution namespace. Set a variable named 'result'
            to capture the return value.

    Returns:
        Observation with result in metadata["result"] (``None`` when the
        code never assigns ``result``) or error in metadata["error"].
    """
    # Copy so that exec() never pollutes a dict owned by get_callables().
    namespace: Dict[str, Any] = dict(self.get_callables())

    # A tool could itself be called "result"; remember it so an untouched
    # tool is not mistaken for a value produced by the code.
    missing = object()
    preexisting = namespace.get("result", missing)

    try:
        # NOTE(security): exec() runs arbitrary Python with the privileges of
        # this process. Only feed it code from a trusted or sandboxed source.
        #
        # One dict is passed on purpose: with separate globals/locals dicts,
        # exec() behaves like a class body and nested functions cannot
        # resolve names assigned at the top level of ``code``.
        exec(code, namespace)
    except SyntaxError as e:
        return Observation(
            done=False, reward=0.0, metadata={"error": f"Syntax error: {str(e)}"}
        )
    except Exception as e:
        return Observation(done=False, reward=0.0, metadata={"error": str(e)})

    result = namespace.get("result", missing)
    if result is missing or result is preexisting:
        result = None
    return Observation(done=False, reward=0.0, metadata={"result": result})
def _validate_tool_names(self, mcp_server: Any) -> None:
"""
Validate that no tools use reserved names.
Reserved names (reset, step, state, close) are protected to maintain
the dual API boundary between infrastructure and agent APIs.
Args:
mcp_server: The FastMCP server to validate.
Raises:
ValueError: If any tool uses a reserved name.
"""
tools_dict = self._get_server_tools(mcp_server)
if tools_dict:
tool_names = set(tools_dict.keys())
conflicts = tool_names & RESERVED_TOOL_NAMES
if conflicts:
raise ValueError(
f"MCP tools cannot use reserved names: {sorted(conflicts)}. "
f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}"
)
def tool(self, mode: Optional[str] = None) -> Callable:
    """
    Decorator factory for registering mode-aware tools.

    With ``mode=None`` the function is registered on the FastMCP server and
    recorded as available in all modes. With a specific mode the function is
    NOT registered on the FastMCP server; it is tracked by this environment
    together with an input schema derived from its signature.

    Args:
        mode: Optional mode for the tool ("production" or "simulation").
            If None, tool is available in all modes.

    Returns:
        A decorator function for registering tools. For ``mode=None`` the
        decorator returns whatever the FastMCP ``tool()`` decorator returns;
        otherwise it returns the original function unchanged.

    Raises:
        ValueError: If mode is not None, "production", or "simulation".
            The returned decorator also raises ValueError when the decorated
            function's name is reserved.
    """
    if mode is not None and mode not in VALID_MODES:
        raise ValueError(
            f"Invalid mode '{mode}'. Mode must be 'production', 'simulation', or None."
        )

    def decorator(func: Callable) -> Callable:
        tool_name = func.__name__

        # Validate tool name is not reserved
        if tool_name in RESERVED_TOOL_NAMES:
            raise ValueError(
                f"Tool name '{tool_name}' is reserved and cannot be used. "
                f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}"
            )

        # If mode is None, register with FastMCP as usual
        if mode is None:
            decorated_func = self.mcp_server.tool()(func)
            self._mode_tools[tool_name][None] = func
            return decorated_func

        # For mode-specific tools, don't register with FastMCP.
        # Instead, track them ourselves.
        self._mode_tools[tool_name][mode] = func

        # Derive a minimal JSON-schema-like description from the signature.
        schema = {
            "type": "object",
            "properties": {},
            "required": [],
        }
        for param_name, param in inspect.signature(func).parameters.items():
            # Map the annotation (real type or its string form) to a JSON
            # type name; anything unrecognized is described as a string.
            param_type = param.annotation
            json_type = "string"  # default
            if param_type in (int, "int"):
                json_type = "integer"
            elif param_type in (float, "float"):
                json_type = "number"
            elif param_type in (bool, "bool"):
                json_type = "boolean"
            schema["properties"][param_name] = {"type": json_type}

            # No default value means the parameter is required. Identity is
            # used because ``Parameter.empty`` is a sentinel; ``==`` would
            # invoke the default object's own __eq__, which may be overloaded
            # (e.g. array-like defaults) and give a wrong or ambiguous answer.
            if param.default is inspect.Parameter.empty:
                schema["required"].append(param_name)

        # Store the schema for this mode-specific tool
        self._mode_tool_schemas[tool_name][mode] = {
            "name": tool_name,
            "description": func.__doc__ or "",
            "input_schema": schema,
        }
        return func

    return decorator
def step(
    self,
    action: Action,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> Observation:
    """
    Dispatch an action to the matching handler.

    MCP actions are served by this class; everything else is handed to the
    subclass's ``_step_impl()``.

    Args:
        action: The action to execute. Can be:
            - ListToolsAction: Returns available MCP tools
            - CallToolAction: Invokes a specific MCP tool
            - Any other Action: Delegated to _step_impl()
        timeout_s: Optional timeout in seconds for the action.
            Defaults to MCP_TOOL_CALL_TIMEOUT (30s) for MCP actions.
        **kwargs: Additional arguments, forwarded only to ``_step_impl()``.

    Returns:
        Observation appropriate to the action type:
            - ListToolsObservation for ListToolsAction
            - CallToolObservation for CallToolAction
            - Subclass-defined Observation for other actions
    """
    # Tool discovery takes neither a timeout nor extra arguments.
    if isinstance(action, ListToolsAction):
        return self._handle_list_tools()

    # Tool invocation honours the timeout only.
    if isinstance(action, CallToolAction):
        return self._handle_call_tool(action, timeout_s=timeout_s)

    # Anything else is environment-specific.
    return self._step_impl(action, timeout_s=timeout_s, **kwargs)
def _handle_list_tools(self) -> ListToolsObservation:
    """
    Serve a ListToolsAction.

    Combines the tools reported by the MCP server with the mode-specific
    tools tracked by this environment, keeping only those usable in the
    current mode.

    Returns:
        ListToolsObservation containing the available tools with their
        names, descriptions, and input schemas. If anything fails, the
        observation has an empty tool list and the error text in metadata.
    """
    try:
        active_mode = getattr(self, "_mode", None)

        # Tools known to the FastMCP server.
        server_tools = run_async_safely(self._async_list_tools())

        # Server tools are listed unless a mode-specific schema of the same
        # name exists; that schema is handled in the loop below.
        listed = [
            Tool(
                name=entry.name,
                description=entry.description or "",
                input_schema=getattr(entry, "inputSchema", {}),
            )
            for entry in server_tools
            if entry.name not in self._mode_tool_schemas
        ]

        # Mode-specific tools: prefer an all-modes schema, otherwise the one
        # for the current mode; skip tools that have neither.
        for by_mode in self._mode_tool_schemas.values():
            if None in by_mode:
                spec = by_mode[None]
            elif active_mode in by_mode:
                spec = by_mode[active_mode]
            else:
                continue
            listed.append(
                Tool(
                    name=spec["name"],
                    description=spec["description"],
                    input_schema=spec["input_schema"],
                )
            )

        return ListToolsObservation(tools=listed)
    except Exception as e:
        # Report the failure through the observation rather than raising.
        failure = {
            "error": str(e),
            "error_type": "list_tools_failed",
        }
        return ListToolsObservation(tools=[], metadata=failure)
async def _async_list_tools(self) -> list:
    """
    Fetch the tool list through the MCP client.

    The client is entered as an async context manager for the duration of
    the request.

    Returns:
        List of tool objects from the MCP server.
    """
    client = self.mcp_client
    async with client:
        tools = await client.list_tools()
    return tools
def _handle_call_tool(
    self,
    action: CallToolAction,
    timeout_s: Optional[float] = None,
) -> CallToolObservation:
    """
    Handle a CallToolAction by invoking the specified tool.

    Tools registered through ``self.tool()`` are called directly as Python
    functions; all other tools are invoked through the FastMCP client.

    Args:
        action: The CallToolAction containing tool_name and arguments.
        timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s).
            The timeout is enforced for FastMCP calls and for async
            directly-called tools. A synchronous directly-called tool runs
            inline and cannot be interrupted.

    Returns:
        CallToolObservation with the tool's result or an error. Failures are
        reported through ``error`` rather than raised.
    """
    timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT

    # Check if this is a tool tracked by self.tool()
    tool_name = action.tool_name
    current_mode = getattr(self, "_mode", None)

    if tool_name in self._mode_tools:
        mode_info = self._mode_tools[tool_name]

        # Tool is available if:
        # 1. It has a None mode (available in all modes), OR
        # 2. It has an implementation for the current mode
        if None in mode_info:
            func = mode_info[None]
        elif current_mode in mode_info:
            func = mode_info[current_mode]
        else:
            # Tool not available in current mode
            return CallToolObservation(
                tool_name=tool_name,
                result=None,
                error=ToolError(
                    error_type=ToolErrorType.TOOL_NOT_FOUND,
                    message=f"Tool '{tool_name}' not available in {current_mode} mode",
                ),
            )

        # Call the function directly
        is_async = inspect.iscoroutinefunction(func)
        try:
            if is_async:
                # Bound the coroutine with the same timeout the FastMCP
                # path uses, so async tools cannot hang the step.
                result = run_async_safely(
                    asyncio.wait_for(func(**action.arguments), timeout=timeout)
                )
            else:
                result = func(**action.arguments)

            # Wrap result in CallToolResult format to match FastMCP behavior
            return CallToolObservation(
                tool_name=tool_name,
                result=CallToolResult(
                    content=[TextContent(type="text", text=str(result))],
                    structured_content={"result": result},
                    meta=None,
                    data=result,
                    is_error=False,
                ),
            )
        except Exception as e:
            # Only the async path is wrapped in wait_for, so a TimeoutError
            # from a sync tool is that tool's own error, not our timeout.
            if is_async and isinstance(e, asyncio.TimeoutError):
                error = ToolError(
                    error_type=ToolErrorType.TIMEOUT,
                    message=f"Tool '{tool_name}' timed out after {timeout} seconds",
                )
            else:
                error = ToolError(
                    error_type=ToolErrorType.EXECUTION_ERROR,
                    message=str(e),
                )
            return CallToolObservation(
                tool_name=tool_name,
                result=None,
                error=error,
            )

    # Not a tool tracked by self.tool(), use FastMCP
    try:
        # Run the async call_tool with timeout.
        # run_async_safely handles both sync and async calling contexts.
        result = run_async_safely(
            asyncio.wait_for(
                self._async_call_tool(action.tool_name, action.arguments),
                timeout=timeout,
            )
        )
        return CallToolObservation(
            tool_name=action.tool_name,
            result=result,
        )
    except asyncio.TimeoutError:
        return CallToolObservation(
            tool_name=action.tool_name,
            result=None,
            error=ToolError(
                error_type=ToolErrorType.TIMEOUT,
                message=f"Tool '{action.tool_name}' timed out after {timeout} seconds",
            ),
        )
    except Exception as e:
        error_message = str(e)
        lowered = error_message.lower()

        # Classify the failure from the exception text (best effort).
        if "not found" in lowered or "unknown tool" in lowered:
            error_type = ToolErrorType.TOOL_NOT_FOUND
        elif "invalid" in lowered or "argument" in lowered:
            error_type = ToolErrorType.INVALID_ARGS
        else:
            error_type = ToolErrorType.EXECUTION_ERROR

        return CallToolObservation(
            tool_name=action.tool_name,
            result=None,
            error=ToolError(
                error_type=error_type,
                message=error_message,
            ),
        )
async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any:
    """
    Invoke one tool through the MCP client.

    The client is entered as an async context manager for the duration of
    the call.

    Args:
        tool_name: Name of the tool to invoke.
        arguments: Dictionary of arguments to pass to the tool.

    Returns:
        The result from the tool execution.
    """
    client = self.mcp_client
    async with client:
        outcome = await client.call_tool(tool_name, arguments)
    return outcome
@abstractmethod
def _step_impl(
    self,
    action: Action,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> Observation:
    """
    Process an action that is not an MCP action.

    ``step()`` routes every action other than ListToolsAction and
    CallToolAction here, so subclasses put their environment-specific
    action handling in this method.

    Args:
        action: The action to execute (guaranteed not to be an MCP action).
        timeout_s: Optional timeout in seconds.
        **kwargs: Additional arguments.

    Returns:
        An Observation appropriate for the action.
    """
def close(self) -> None:
    """
    Release the environment's MCP resources.

    Drops the references to the MCP client and server.
    Subclasses should call super().close() if they override this method.
    """
    # The client is only ever opened inside ``async with`` blocks, which
    # clean up on exit, so dropping the references is all that is left.
    self.mcp_client = self.mcp_server = None