Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Merge branch 'main' into dataset_tool_improved
Browse files- .gitignore +2 -1
- agent/MCP_INTEGRATION.md +0 -205
- agent/codex_agent_demo.py +0 -470
- agent/context_manager/manager.py +4 -1
- agent/core/agent_loop.py +13 -1
- agent/core/tools.py +25 -19
- agent/main.py +89 -0
- agent/prompts/system_prompt_v2.yaml +1 -7
- agent/tools/docs_tools.py +546 -446
- agent/tools/hf_repo_files_tool.py +322 -0
- agent/tools/hf_repo_git_tool.py +663 -0
- agent/tools/utils_tools.py +0 -203
- agent/utils/__init__.py +0 -4
- agent/utils/logging.py +0 -40
- agent/utils/terminal_display.py +7 -2
- pyproject.toml +1 -0
- test_dataset_tools.py +0 -79
.gitignore
CHANGED
|
@@ -16,4 +16,5 @@ wheels/
|
|
| 16 |
/logs
|
| 17 |
hf-agent-leaderboard/
|
| 18 |
.cursor/
|
| 19 |
-
session_logs/
|
|
|
|
|
|
| 16 |
/logs
|
| 17 |
hf-agent-leaderboard/
|
| 18 |
.cursor/
|
| 19 |
+
session_logs/
|
| 20 |
+
skills/
|
agent/MCP_INTEGRATION.md
DELETED
|
@@ -1,205 +0,0 @@
|
|
| 1 |
-
# MCP Integration for HF Agent
|
| 2 |
-
|
| 3 |
-
This agent now supports the Model Context Protocol (MCP), allowing it to connect to and use tools from MCP servers.
|
| 4 |
-
|
| 5 |
-
## Overview
|
| 6 |
-
|
| 7 |
-
The MCP integration allows the agent to:
|
| 8 |
-
- Connect to multiple MCP servers simultaneously
|
| 9 |
-
- Automatically discover and use tools from connected servers
|
| 10 |
-
- Execute tool calls through the MCP protocol
|
| 11 |
-
- Seamlessly integrate MCP tools with the agent's existing tool system
|
| 12 |
-
|
| 13 |
-
## Architecture
|
| 14 |
-
|
| 15 |
-
The integration consists of several components:
|
| 16 |
-
|
| 17 |
-
1. **MCPClient** (`agent/core/mcp_client.py`): Manages connections to MCP servers
|
| 18 |
-
2. **ToolExecutor** (`agent/core/executor.py`): Executes both MCP and local tools
|
| 19 |
-
3. **Config** (`agent/config.py`): Stores MCP server configurations
|
| 20 |
-
4. **Session** (`agent/core/session.py`): Initializes MCP connections and manages lifecycle
|
| 21 |
-
|
| 22 |
-
## Configuration
|
| 23 |
-
|
| 24 |
-
To use MCP servers with your agent, add them to your configuration file:
|
| 25 |
-
|
| 26 |
-
```json
|
| 27 |
-
{
|
| 28 |
-
"model_name": "anthropic/claude-sonnet-4-5-20250929",
|
| 29 |
-
"tools": [],
|
| 30 |
-
"system_prompt_path": "",
|
| 31 |
-
"mcp_servers": [
|
| 32 |
-
{
|
| 33 |
-
"name": "weather",
|
| 34 |
-
"command": "python",
|
| 35 |
-
"args": ["path/to/weather_server.py"],
|
| 36 |
-
"env": null
|
| 37 |
-
},
|
| 38 |
-
{
|
| 39 |
-
"name": "filesystem",
|
| 40 |
-
"command": "node",
|
| 41 |
-
"args": ["path/to/filesystem_server.js"],
|
| 42 |
-
"env": {
|
| 43 |
-
"ALLOWED_PATHS": "/home/user/documents"
|
| 44 |
-
}
|
| 45 |
-
}
|
| 46 |
-
]
|
| 47 |
-
}
|
| 48 |
-
```
|
| 49 |
-
|
| 50 |
-
### Configuration Fields
|
| 51 |
-
|
| 52 |
-
- `name`: Unique identifier for the MCP server
|
| 53 |
-
- `command`: Command to execute the server (`python`, `node`, etc.)
|
| 54 |
-
- `args`: Arguments to pass to the command (path to server script)
|
| 55 |
-
- `env`: (Optional) Environment variables for the server process
|
| 56 |
-
|
| 57 |
-
## Usage
|
| 58 |
-
|
| 59 |
-
### Basic Usage
|
| 60 |
-
|
| 61 |
-
```python
|
| 62 |
-
import asyncio
|
| 63 |
-
from agent.config import Config, load_config
|
| 64 |
-
from agent.core.agent_loop import submission_loop
|
| 65 |
-
|
| 66 |
-
async def main():
|
| 67 |
-
# Load config with MCP servers
|
| 68 |
-
config = load_config("config.json")
|
| 69 |
-
|
| 70 |
-
# Create queues
|
| 71 |
-
submission_queue = asyncio.Queue()
|
| 72 |
-
event_queue = asyncio.Queue()
|
| 73 |
-
|
| 74 |
-
# Start agent loop (MCP connections initialized automatically)
|
| 75 |
-
await submission_loop(submission_queue, event_queue, config)
|
| 76 |
-
|
| 77 |
-
if __name__ == "__main__":
|
| 78 |
-
asyncio.run(main())
|
| 79 |
-
```
|
| 80 |
-
|
| 81 |
-
### Programmatic Configuration
|
| 82 |
-
|
| 83 |
-
```python
|
| 84 |
-
from agent.config import Config, MCPServerConfig
|
| 85 |
-
|
| 86 |
-
config = Config(
|
| 87 |
-
model_name="anthropic/claude-sonnet-4-5-20250929",
|
| 88 |
-
tools=[],
|
| 89 |
-
system_prompt_path="",
|
| 90 |
-
mcp_servers=[
|
| 91 |
-
MCPServerConfig(
|
| 92 |
-
name="weather",
|
| 93 |
-
command="python",
|
| 94 |
-
args=["weather_server.py"],
|
| 95 |
-
env=None
|
| 96 |
-
)
|
| 97 |
-
]
|
| 98 |
-
)
|
| 99 |
-
```
|
| 100 |
-
|
| 101 |
-
## How It Works
|
| 102 |
-
|
| 103 |
-
1. **Initialization**: When the agent loop starts, it calls `session.initialize_mcp()`
|
| 104 |
-
2. **Connection**: The session connects to all configured MCP servers
|
| 105 |
-
3. **Tool Discovery**: Tools from all servers are discovered and added to the agent's tool list
|
| 106 |
-
4. **Tool Naming**: MCP tools are prefixed with their server name (e.g., `weather__get_forecast`)
|
| 107 |
-
5. **Execution**: When the LLM calls a tool, the ToolExecutor routes it to the appropriate MCP server
|
| 108 |
-
6. **Cleanup**: When the agent shuts down, all MCP connections are cleaned up properly
|
| 109 |
-
|
| 110 |
-
## Tool Naming Convention
|
| 111 |
-
|
| 112 |
-
MCP tools are automatically prefixed with their server name to avoid conflicts:
|
| 113 |
-
|
| 114 |
-
- Original tool: `get_forecast`
|
| 115 |
-
- MCP tool name: `weather__get_forecast`
|
| 116 |
-
|
| 117 |
-
This ensures that tools from different servers don't conflict, even if they have the same name.
|
| 118 |
-
|
| 119 |
-
## Example: Creating a Simple MCP Server
|
| 120 |
-
|
| 121 |
-
Here's a minimal example of an MCP server (save as `calculator_server.py`):
|
| 122 |
-
|
| 123 |
-
```python
|
| 124 |
-
import asyncio
|
| 125 |
-
from mcp.server import Server, stdio_server
|
| 126 |
-
from mcp.types import Tool, TextContent
|
| 127 |
-
|
| 128 |
-
app = Server("calculator")
|
| 129 |
-
|
| 130 |
-
@app.list_tools()
|
| 131 |
-
async def list_tools() -> list[Tool]:
|
| 132 |
-
return [
|
| 133 |
-
Tool(
|
| 134 |
-
name="add",
|
| 135 |
-
description="Add two numbers",
|
| 136 |
-
inputSchema={
|
| 137 |
-
"type": "object",
|
| 138 |
-
"properties": {
|
| 139 |
-
"a": {"type": "number"},
|
| 140 |
-
"b": {"type": "number"}
|
| 141 |
-
},
|
| 142 |
-
"required": ["a", "b"]
|
| 143 |
-
}
|
| 144 |
-
)
|
| 145 |
-
]
|
| 146 |
-
|
| 147 |
-
@app.call_tool()
|
| 148 |
-
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
| 149 |
-
if name == "add":
|
| 150 |
-
result = arguments["a"] + arguments["b"]
|
| 151 |
-
return [TextContent(type="text", text=str(result))]
|
| 152 |
-
|
| 153 |
-
raise ValueError(f"Unknown tool: {name}")
|
| 154 |
-
|
| 155 |
-
async def main():
|
| 156 |
-
async with stdio_server() as (read_stream, write_stream):
|
| 157 |
-
await app.run(read_stream, write_stream, app.create_initialization_options())
|
| 158 |
-
|
| 159 |
-
if __name__ == "__main__":
|
| 160 |
-
asyncio.run(main())
|
| 161 |
-
```
|
| 162 |
-
|
| 163 |
-
## Troubleshooting
|
| 164 |
-
|
| 165 |
-
### Server Connection Issues
|
| 166 |
-
|
| 167 |
-
If you see errors connecting to an MCP server:
|
| 168 |
-
|
| 169 |
-
1. Check that the server script path is correct
|
| 170 |
-
2. Ensure the command (`python`, `node`) is in your PATH
|
| 171 |
-
3. Verify the server script is executable
|
| 172 |
-
4. Check server logs for initialization errors
|
| 173 |
-
|
| 174 |
-
### Tool Not Found
|
| 175 |
-
|
| 176 |
-
If the agent can't find an MCP tool:
|
| 177 |
-
|
| 178 |
-
1. Verify the server is connected (check startup logs)
|
| 179 |
-
2. Check tool naming (should be `servername__toolname`)
|
| 180 |
-
3. Ensure the server properly implements `list_tools()`
|
| 181 |
-
|
| 182 |
-
### Performance Considerations
|
| 183 |
-
|
| 184 |
-
- MCP server initialization happens once at startup
|
| 185 |
-
- Tool calls are asynchronous and don't block the agent
|
| 186 |
-
- Multiple servers can be used simultaneously
|
| 187 |
-
- Consider using local tools for high-frequency operations
|
| 188 |
-
|
| 189 |
-
## Best Practices
|
| 190 |
-
|
| 191 |
-
1. **Unique Server Names**: Give each MCP server a unique, descriptive name
|
| 192 |
-
2. **Error Handling**: MCP connection failures are logged but don't crash the agent
|
| 193 |
-
3. **Resource Cleanup**: Always let the agent shut down gracefully to cleanup connections
|
| 194 |
-
4. **Testing**: Test MCP servers independently before integrating them
|
| 195 |
-
5. **Security**: Be cautious with file system and network access in MCP servers
|
| 196 |
-
|
| 197 |
-
## Future Enhancements
|
| 198 |
-
|
| 199 |
-
Potential improvements to consider:
|
| 200 |
-
|
| 201 |
-
- Dynamic server addition/removal during runtime
|
| 202 |
-
- Server health monitoring and auto-reconnection
|
| 203 |
-
- Tool caching and performance optimization
|
| 204 |
-
- Support for MCP resources and prompts
|
| 205 |
-
- Rate limiting and timeout configuration
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent/codex_agent_demo.py
DELETED
|
@@ -1,470 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Minimum Viable Implementation of Codex Agent Loop in Python
|
| 3 |
-
|
| 4 |
-
This demonstrates the core architecture patterns from codex-rs:
|
| 5 |
-
- Async submission loop (like submission_loop in codex.rs)
|
| 6 |
-
- Context manager for conversation history
|
| 7 |
-
- Channel-based communication (submissions in, events out)
|
| 8 |
-
- Handler pattern for operations
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
import asyncio
|
| 12 |
-
from dataclasses import dataclass, field
|
| 13 |
-
from datetime import datetime
|
| 14 |
-
from enum import Enum
|
| 15 |
-
from typing import Any, Dict, List, Optional
|
| 16 |
-
|
| 17 |
-
# ============================================================================
|
| 18 |
-
# PROTOCOL TYPES (ResponseItem equivalents)
|
| 19 |
-
# ============================================================================
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class MessageRole(Enum):
|
| 23 |
-
SYSTEM = "system"
|
| 24 |
-
USER = "user"
|
| 25 |
-
ASSISTANT = "assistant"
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
@dataclass
|
| 29 |
-
class Message:
|
| 30 |
-
role: MessageRole
|
| 31 |
-
content: str
|
| 32 |
-
timestamp: datetime = field(default_factory=datetime.now)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
@dataclass
|
| 36 |
-
class ToolCall:
|
| 37 |
-
call_id: str
|
| 38 |
-
tool_name: str
|
| 39 |
-
arguments: Dict[str, Any]
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
@dataclass
|
| 43 |
-
class ToolOutput:
|
| 44 |
-
call_id: str
|
| 45 |
-
content: str
|
| 46 |
-
success: bool = True
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# ============================================================================
|
| 50 |
-
# CONTEXT MANAGER (like context_manager/history.rs)
|
| 51 |
-
# ============================================================================
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
class ContextManager:
|
| 55 |
-
"""
|
| 56 |
-
Manages conversation history with normalization and truncation.
|
| 57 |
-
Based on codex-rs/core/src/context_manager/history.rs
|
| 58 |
-
"""
|
| 59 |
-
|
| 60 |
-
def __init__(self, max_history_length: int = 1000):
|
| 61 |
-
self.items: List[Any] = [] # Oldest → Newest
|
| 62 |
-
self.token_count: int = 0
|
| 63 |
-
self.max_history_length = max_history_length
|
| 64 |
-
|
| 65 |
-
def record_items(self, items: List[Any]) -> None:
|
| 66 |
-
"""Record new items to history (like record_items in history.rs:41)"""
|
| 67 |
-
for item in items:
|
| 68 |
-
# Filter and process items
|
| 69 |
-
if self._is_api_message(item):
|
| 70 |
-
processed = self._process_item(item)
|
| 71 |
-
self.items.append(processed)
|
| 72 |
-
|
| 73 |
-
def _is_api_message(self, item: Any) -> bool:
|
| 74 |
-
"""Filter out system messages (like is_api_message in history.rs:157)"""
|
| 75 |
-
if isinstance(item, Message):
|
| 76 |
-
return item.role != MessageRole.SYSTEM
|
| 77 |
-
return isinstance(item, (ToolCall, ToolOutput))
|
| 78 |
-
|
| 79 |
-
def _process_item(self, item: Any) -> Any:
|
| 80 |
-
"""Process item before adding (like process_item in history.rs:119)"""
|
| 81 |
-
# Truncate long outputs
|
| 82 |
-
if isinstance(item, ToolOutput):
|
| 83 |
-
if len(item.content) > 2000:
|
| 84 |
-
item.content = item.content[:2000] + "...[truncated]"
|
| 85 |
-
return item
|
| 86 |
-
|
| 87 |
-
def get_history_for_prompt(self) -> List[Any]:
|
| 88 |
-
"""
|
| 89 |
-
Get normalized history ready for model
|
| 90 |
-
(like get_history_for_prompt in history.rs:65)
|
| 91 |
-
"""
|
| 92 |
-
self._normalize_history()
|
| 93 |
-
return self.items.copy()
|
| 94 |
-
|
| 95 |
-
def _normalize_history(self) -> None:
|
| 96 |
-
"""
|
| 97 |
-
Enforce invariants (like normalize_history in history.rs:102):
|
| 98 |
-
1. Every tool call has corresponding output
|
| 99 |
-
2. Every output has corresponding call
|
| 100 |
-
"""
|
| 101 |
-
# Build mapping of call_id → call
|
| 102 |
-
calls = {}
|
| 103 |
-
outputs = {}
|
| 104 |
-
|
| 105 |
-
for item in self.items:
|
| 106 |
-
if isinstance(item, ToolCall):
|
| 107 |
-
calls[item.call_id] = item
|
| 108 |
-
elif isinstance(item, ToolOutput):
|
| 109 |
-
outputs[item.call_id] = item
|
| 110 |
-
|
| 111 |
-
# Remove orphan outputs (no matching call)
|
| 112 |
-
self.items = [
|
| 113 |
-
item
|
| 114 |
-
for item in self.items
|
| 115 |
-
if not isinstance(item, ToolOutput) or item.call_id in calls
|
| 116 |
-
]
|
| 117 |
-
|
| 118 |
-
# Add missing outputs for calls (create synthetic outputs)
|
| 119 |
-
for call_id, call in calls.items():
|
| 120 |
-
if call_id not in outputs:
|
| 121 |
-
self.items.append(
|
| 122 |
-
ToolOutput(
|
| 123 |
-
call_id=call_id, content="[No output recorded]", success=False
|
| 124 |
-
)
|
| 125 |
-
)
|
| 126 |
-
|
| 127 |
-
def remove_first_item(self) -> None:
|
| 128 |
-
"""Remove oldest item for compaction (like remove_first_item in history.rs:71)"""
|
| 129 |
-
if self.items:
|
| 130 |
-
removed = self.items.pop(0)
|
| 131 |
-
# Also remove corresponding pair if needed
|
| 132 |
-
if isinstance(removed, ToolCall):
|
| 133 |
-
self.items = [
|
| 134 |
-
item
|
| 135 |
-
for item in self.items
|
| 136 |
-
if not (
|
| 137 |
-
isinstance(item, ToolOutput) and item.call_id == removed.call_id
|
| 138 |
-
)
|
| 139 |
-
]
|
| 140 |
-
elif isinstance(removed, ToolOutput):
|
| 141 |
-
self.items = [
|
| 142 |
-
item
|
| 143 |
-
for item in self.items
|
| 144 |
-
if not (
|
| 145 |
-
isinstance(item, ToolCall) and item.call_id == removed.call_id
|
| 146 |
-
)
|
| 147 |
-
]
|
| 148 |
-
|
| 149 |
-
def compact(self, target_size: int) -> None:
|
| 150 |
-
"""Remove old items until we're under target size"""
|
| 151 |
-
while len(self.items) > target_size:
|
| 152 |
-
self.remove_first_item()
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
# ============================================================================
|
| 156 |
-
# OPERATIONS (like Op enum in codex.rs)
|
| 157 |
-
# ============================================================================
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
class OpType(Enum):
|
| 161 |
-
USER_INPUT = "user_input"
|
| 162 |
-
EXEC_APPROVAL = "exec_approval"
|
| 163 |
-
INTERRUPT = "interrupt"
|
| 164 |
-
UNDO = "undo"
|
| 165 |
-
COMPACT = "compact"
|
| 166 |
-
SHUTDOWN = "shutdown"
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
@dataclass
|
| 170 |
-
class Operation:
|
| 171 |
-
op_type: OpType
|
| 172 |
-
data: Optional[Dict[str, Any]] = None
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
@dataclass
|
| 176 |
-
class Submission:
|
| 177 |
-
id: str
|
| 178 |
-
operation: Operation
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
# ============================================================================
|
| 182 |
-
# EVENTS (like Event in codex-rs)
|
| 183 |
-
# ============================================================================
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
@dataclass
|
| 187 |
-
class Event:
|
| 188 |
-
event_type: str
|
| 189 |
-
data: Optional[Dict[str, Any]] = None
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
# ============================================================================
|
| 193 |
-
# SESSION STATE (like Session in codex.rs)
|
| 194 |
-
# ============================================================================
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
class Session:
|
| 198 |
-
"""
|
| 199 |
-
Maintains agent session state
|
| 200 |
-
Similar to Session in codex-rs/core/src/codex.rs
|
| 201 |
-
"""
|
| 202 |
-
|
| 203 |
-
def __init__(self, event_queue: asyncio.Queue):
|
| 204 |
-
self.context_manager = ContextManager(tool_specs=[])
|
| 205 |
-
self.event_queue = event_queue
|
| 206 |
-
self.is_running = True
|
| 207 |
-
self.current_task: Optional[asyncio.Task] = None
|
| 208 |
-
|
| 209 |
-
async def send_event(self, event: Event) -> None:
|
| 210 |
-
"""Send event back to client"""
|
| 211 |
-
await self.event_queue.put(event)
|
| 212 |
-
|
| 213 |
-
def interrupt(self) -> None:
|
| 214 |
-
"""Interrupt current running task"""
|
| 215 |
-
if self.current_task and not self.current_task.done():
|
| 216 |
-
self.current_task.cancel()
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
# ============================================================================
|
| 220 |
-
# OPERATION HANDLERS (like handlers module in codex.rs:1343)
|
| 221 |
-
# ============================================================================
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
class Handlers:
|
| 225 |
-
"""Handler functions for each operation type"""
|
| 226 |
-
|
| 227 |
-
@staticmethod
|
| 228 |
-
async def user_input(session: Session, text: str) -> None:
|
| 229 |
-
"""Handle user input (like user_input_or_turn in codex.rs:1291)"""
|
| 230 |
-
# Add user message to history
|
| 231 |
-
user_msg = Message(role=MessageRole.USER, content=text)
|
| 232 |
-
session.context_manager.record_items([user_msg])
|
| 233 |
-
|
| 234 |
-
# Send event that we're processing
|
| 235 |
-
await session.send_event(
|
| 236 |
-
Event(event_type="processing", data={"message": "Processing user input"})
|
| 237 |
-
)
|
| 238 |
-
|
| 239 |
-
# Simulate agent processing
|
| 240 |
-
await asyncio.sleep(0.1)
|
| 241 |
-
|
| 242 |
-
# Generate mock assistant response
|
| 243 |
-
assistant_msg = Message(
|
| 244 |
-
role=MessageRole.ASSISTANT, content=f"I received: {text}"
|
| 245 |
-
)
|
| 246 |
-
session.context_manager.record_items([assistant_msg])
|
| 247 |
-
|
| 248 |
-
# Simulate tool call
|
| 249 |
-
tool_call = ToolCall(
|
| 250 |
-
call_id="call_123", tool_name="bash", arguments={"command": "echo 'hello'"}
|
| 251 |
-
)
|
| 252 |
-
session.context_manager.record_items([tool_call])
|
| 253 |
-
|
| 254 |
-
# Simulate tool execution
|
| 255 |
-
await asyncio.sleep(0.1)
|
| 256 |
-
|
| 257 |
-
tool_output = ToolOutput(call_id="call_123", content="hello\n", success=True)
|
| 258 |
-
session.context_manager.record_items([tool_output])
|
| 259 |
-
|
| 260 |
-
# Send completion event
|
| 261 |
-
await session.send_event(
|
| 262 |
-
Event(
|
| 263 |
-
event_type="turn_complete",
|
| 264 |
-
data={"history_size": len(session.context_manager.items)},
|
| 265 |
-
)
|
| 266 |
-
)
|
| 267 |
-
|
| 268 |
-
@staticmethod
|
| 269 |
-
async def interrupt(session: Session) -> None:
|
| 270 |
-
"""Handle interrupt (like interrupt in codex.rs:1266)"""
|
| 271 |
-
session.interrupt()
|
| 272 |
-
await session.send_event(Event(event_type="interrupted"))
|
| 273 |
-
|
| 274 |
-
@staticmethod
|
| 275 |
-
async def compact(session: Session) -> None:
|
| 276 |
-
"""Handle compact (like compact in codex.rs:1317)"""
|
| 277 |
-
old_size = len(session.context_manager.items)
|
| 278 |
-
session.context_manager.compact(target_size=10)
|
| 279 |
-
new_size = len(session.context_manager.items)
|
| 280 |
-
|
| 281 |
-
await session.send_event(
|
| 282 |
-
Event(
|
| 283 |
-
event_type="compacted",
|
| 284 |
-
data={"removed": old_size - new_size, "remaining": new_size},
|
| 285 |
-
)
|
| 286 |
-
)
|
| 287 |
-
|
| 288 |
-
@staticmethod
|
| 289 |
-
async def undo(session: Session) -> None:
|
| 290 |
-
"""Handle undo (like undo in codex.rs:1314)"""
|
| 291 |
-
# Remove last user turn and all following items
|
| 292 |
-
# Simplified: just remove last 2 items
|
| 293 |
-
for _ in range(min(2, len(session.context_manager.items))):
|
| 294 |
-
session.context_manager.items.pop()
|
| 295 |
-
|
| 296 |
-
await session.send_event(Event(event_type="undo_complete"))
|
| 297 |
-
|
| 298 |
-
@staticmethod
|
| 299 |
-
async def shutdown(session: Session) -> bool:
|
| 300 |
-
"""Handle shutdown (like shutdown in codex.rs:1329)"""
|
| 301 |
-
session.is_running = False
|
| 302 |
-
await session.send_event(Event(event_type="shutdown"))
|
| 303 |
-
return True
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
# ============================================================================
|
| 307 |
-
# MAIN AGENT LOOP (like submission_loop in codex.rs:1259)
|
| 308 |
-
# ============================================================================
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
async def submission_loop(
|
| 312 |
-
submission_queue: asyncio.Queue, event_queue: asyncio.Queue
|
| 313 |
-
) -> None:
|
| 314 |
-
"""
|
| 315 |
-
Main agent loop - processes submissions and dispatches to handlers.
|
| 316 |
-
This is the core of the agent (like submission_loop in codex.rs:1259-1340)
|
| 317 |
-
"""
|
| 318 |
-
session = Session(event_queue)
|
| 319 |
-
|
| 320 |
-
print("🤖 Agent loop started")
|
| 321 |
-
|
| 322 |
-
# Main processing loop
|
| 323 |
-
while session.is_running:
|
| 324 |
-
try:
|
| 325 |
-
# Wait for next submission (like rx_sub.recv() in codex.rs:1262)
|
| 326 |
-
submission = await submission_queue.get()
|
| 327 |
-
|
| 328 |
-
print(f"📨 Received: {submission.operation.op_type.value}")
|
| 329 |
-
|
| 330 |
-
# Dispatch to handler based on operation type
|
| 331 |
-
# (like match in codex.rs:1264-1337)
|
| 332 |
-
op = submission.operation
|
| 333 |
-
|
| 334 |
-
if op.op_type == OpType.USER_INPUT:
|
| 335 |
-
text = op.data.get("text", "") if op.data else ""
|
| 336 |
-
await Handlers.user_input(session, text)
|
| 337 |
-
|
| 338 |
-
elif op.op_type == OpType.INTERRUPT:
|
| 339 |
-
await Handlers.interrupt(session)
|
| 340 |
-
|
| 341 |
-
elif op.op_type == OpType.COMPACT:
|
| 342 |
-
await Handlers.compact(session)
|
| 343 |
-
|
| 344 |
-
elif op.op_type == OpType.UNDO:
|
| 345 |
-
await Handlers.undo(session)
|
| 346 |
-
|
| 347 |
-
elif op.op_type == OpType.SHUTDOWN:
|
| 348 |
-
if await Handlers.shutdown(session):
|
| 349 |
-
break
|
| 350 |
-
|
| 351 |
-
else:
|
| 352 |
-
print(f"⚠️ Unknown operation: {op.op_type}")
|
| 353 |
-
|
| 354 |
-
except asyncio.CancelledError:
|
| 355 |
-
break
|
| 356 |
-
except Exception as e:
|
| 357 |
-
print(f"❌ Error in agent loop: {e}")
|
| 358 |
-
await session.send_event(Event(event_type="error", data={"error": str(e)}))
|
| 359 |
-
|
| 360 |
-
print("🛑 Agent loop exited")
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
# ============================================================================
|
| 364 |
-
# CODEX INTERFACE (like Codex struct in codex.rs:154)
|
| 365 |
-
# ============================================================================
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
class Codex:
|
| 369 |
-
"""
|
| 370 |
-
Main interface to the agent (like Codex in codex.rs:154-246)
|
| 371 |
-
Provides submit() and next_event() methods
|
| 372 |
-
"""
|
| 373 |
-
|
| 374 |
-
def __init__(self):
|
| 375 |
-
self.submission_queue = asyncio.Queue()
|
| 376 |
-
self.event_queue = asyncio.Queue()
|
| 377 |
-
self.agent_task: Optional[asyncio.Task] = None
|
| 378 |
-
self.submission_counter = 0
|
| 379 |
-
|
| 380 |
-
async def spawn(self) -> None:
|
| 381 |
-
"""Spawn the agent loop (like Codex::spawn in codex.rs:156)"""
|
| 382 |
-
self.agent_task = asyncio.create_task(
|
| 383 |
-
submission_loop(self.submission_queue, self.event_queue)
|
| 384 |
-
)
|
| 385 |
-
|
| 386 |
-
async def submit(self, operation: Operation) -> str:
|
| 387 |
-
"""Submit operation to agent (like Codex::submit in codex.rs:218)"""
|
| 388 |
-
self.submission_counter += 1
|
| 389 |
-
submission = Submission(
|
| 390 |
-
id=f"sub_{self.submission_counter}", operation=operation
|
| 391 |
-
)
|
| 392 |
-
await self.submission_queue.put(submission)
|
| 393 |
-
return submission.id
|
| 394 |
-
|
| 395 |
-
async def next_event(self) -> Optional[Event]:
|
| 396 |
-
"""Get next event from agent (like Codex::next_event in codex.rs:238)"""
|
| 397 |
-
try:
|
| 398 |
-
return await asyncio.wait_for(self.event_queue.get(), timeout=1.0)
|
| 399 |
-
except asyncio.TimeoutError:
|
| 400 |
-
return None
|
| 401 |
-
|
| 402 |
-
async def shutdown(self) -> None:
|
| 403 |
-
"""Shutdown the agent"""
|
| 404 |
-
await self.submit(Operation(op_type=OpType.SHUTDOWN))
|
| 405 |
-
if self.agent_task:
|
| 406 |
-
await self.agent_task
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
# ============================================================================
|
| 410 |
-
# DEMO / EXAMPLE USAGE
|
| 411 |
-
# ============================================================================
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
async def main():
|
| 415 |
-
"""Demo of the agent system"""
|
| 416 |
-
print("=" * 60)
|
| 417 |
-
print("Codex Agent Loop Demo (Python MVP)")
|
| 418 |
-
print("=" * 60)
|
| 419 |
-
|
| 420 |
-
# Create and spawn agent
|
| 421 |
-
codex = Codex()
|
| 422 |
-
await codex.spawn()
|
| 423 |
-
|
| 424 |
-
# Submit some operations
|
| 425 |
-
print("\n1️⃣ Submitting user input...")
|
| 426 |
-
await codex.submit(
|
| 427 |
-
Operation(op_type=OpType.USER_INPUT, data={"text": "Hello, agent!"})
|
| 428 |
-
)
|
| 429 |
-
|
| 430 |
-
# Receive events
|
| 431 |
-
for _ in range(3):
|
| 432 |
-
event = await codex.next_event()
|
| 433 |
-
if event:
|
| 434 |
-
print(f" ✅ Event: {event.event_type} - {event.data}")
|
| 435 |
-
|
| 436 |
-
print("\n2️⃣ Submitting another input...")
|
| 437 |
-
await codex.submit(
|
| 438 |
-
Operation(op_type=OpType.USER_INPUT, data={"text": "What's the weather?"})
|
| 439 |
-
)
|
| 440 |
-
|
| 441 |
-
for _ in range(3):
|
| 442 |
-
event = await codex.next_event()
|
| 443 |
-
if event:
|
| 444 |
-
print(f" ✅ Event: {event.event_type} - {event.data}")
|
| 445 |
-
|
| 446 |
-
print("\n3️⃣ Compacting history...")
|
| 447 |
-
await codex.submit(Operation(op_type=OpType.COMPACT))
|
| 448 |
-
|
| 449 |
-
event = await codex.next_event()
|
| 450 |
-
if event:
|
| 451 |
-
print(f" ✅ Event: {event.event_type} - {event.data}")
|
| 452 |
-
|
| 453 |
-
print("\n4️⃣ Undoing last turn...")
|
| 454 |
-
await codex.submit(Operation(op_type=OpType.UNDO))
|
| 455 |
-
|
| 456 |
-
event = await codex.next_event()
|
| 457 |
-
if event:
|
| 458 |
-
print(f" ✅ Event: {event.event_type}")
|
| 459 |
-
|
| 460 |
-
# Shutdown
|
| 461 |
-
print("\n5️⃣ Shutting down...")
|
| 462 |
-
await codex.shutdown()
|
| 463 |
-
|
| 464 |
-
print("\n" + "=" * 60)
|
| 465 |
-
print("Demo complete!")
|
| 466 |
-
print("=" * 60)
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
if __name__ == "__main__":
|
| 470 |
-
asyncio.run(main())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent/context_manager/manager.py
CHANGED
|
@@ -8,6 +8,7 @@ from pathlib import Path
|
|
| 8 |
from typing import Any
|
| 9 |
|
| 10 |
import yaml
|
|
|
|
| 11 |
from jinja2 import Template
|
| 12 |
from litellm import Message, acompletion
|
| 13 |
|
|
@@ -24,7 +25,8 @@ class ContextManager:
|
|
| 24 |
prompt_file_suffix: str = "system_prompt_v2.yaml",
|
| 25 |
):
|
| 26 |
self.system_prompt = self._load_system_prompt(
|
| 27 |
-
tool_specs or [],
|
|
|
|
| 28 |
)
|
| 29 |
self.max_context = max_context
|
| 30 |
self.compact_size = int(max_context * compact_size)
|
|
@@ -58,6 +60,7 @@ class ContextManager:
|
|
| 58 |
current_date=current_date,
|
| 59 |
current_time=current_time,
|
| 60 |
current_timezone=current_timezone,
|
|
|
|
| 61 |
)
|
| 62 |
|
| 63 |
def add_message(self, message: Message, token_count: int = None) -> None:
|
|
|
|
| 8 |
from typing import Any
|
| 9 |
|
| 10 |
import yaml
|
| 11 |
+
from huggingface_hub import HfApi
|
| 12 |
from jinja2 import Template
|
| 13 |
from litellm import Message, acompletion
|
| 14 |
|
|
|
|
| 25 |
prompt_file_suffix: str = "system_prompt_v2.yaml",
|
| 26 |
):
|
| 27 |
self.system_prompt = self._load_system_prompt(
|
| 28 |
+
tool_specs or [],
|
| 29 |
+
prompt_file_suffix="system_prompt_v2.yaml",
|
| 30 |
)
|
| 31 |
self.max_context = max_context
|
| 32 |
self.compact_size = int(max_context * compact_size)
|
|
|
|
| 60 |
current_date=current_date,
|
| 61 |
current_time=current_time,
|
| 62 |
current_timezone=current_timezone,
|
| 63 |
+
hf_user_info=HfApi().whoami().get("name"),
|
| 64 |
)
|
| 65 |
|
| 66 |
def add_message(self, message: Message, token_count: int = None) -> None:
|
agent/core/agent_loop.py
CHANGED
|
@@ -76,7 +76,19 @@ def _needs_approval(tool_name: str, tool_args: dict, config: Config | None = Non
|
|
| 76 |
# Other operations (create_repo, etc.) always require approval
|
| 77 |
if operation in ["create_repo"]:
|
| 78 |
return True
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
return False
|
| 81 |
|
| 82 |
|
|
|
|
| 76 |
# Other operations (create_repo, etc.) always require approval
|
| 77 |
if operation in ["create_repo"]:
|
| 78 |
return True
|
| 79 |
+
|
| 80 |
+
# hf_repo_files: upload (can overwrite) and delete require approval
|
| 81 |
+
if tool_name == "hf_repo_files":
|
| 82 |
+
operation = tool_args.get("operation", "")
|
| 83 |
+
if operation in ["upload", "delete"]:
|
| 84 |
+
return True
|
| 85 |
+
|
| 86 |
+
# hf_repo_git: destructive operations require approval
|
| 87 |
+
if tool_name == "hf_repo_git":
|
| 88 |
+
operation = tool_args.get("operation", "")
|
| 89 |
+
if operation in ["delete_branch", "delete_tag", "merge_pr", "create_repo", "update_repo"]:
|
| 90 |
+
return True
|
| 91 |
+
|
| 92 |
return False
|
| 93 |
|
| 94 |
|
agent/core/tools.py
CHANGED
|
@@ -35,22 +35,29 @@ from agent.tools.github_read_file import (
|
|
| 35 |
GITHUB_READ_FILE_TOOL_SPEC,
|
| 36 |
github_read_file_handler,
|
| 37 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
|
| 39 |
from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
|
| 40 |
-
from agent.tools.private_hf_repo_tools import (
|
| 41 |
-
PRIVATE_HF_REPO_TOOL_SPEC,
|
| 42 |
-
private_hf_repo_handler,
|
| 43 |
-
)
|
| 44 |
|
| 45 |
-
# NOTE:
|
| 46 |
-
# from agent.tools.
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
# Suppress aiohttp deprecation warning
|
| 49 |
warnings.filterwarnings(
|
| 50 |
"ignore", category=DeprecationWarning, module="aiohttp.connector"
|
| 51 |
)
|
| 52 |
|
| 53 |
-
NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch"]
|
| 54 |
|
| 55 |
|
| 56 |
def convert_mcp_content_to_string(content: list) -> str:
|
|
@@ -281,20 +288,19 @@ def create_builtin_tools() -> list[ToolSpec]:
|
|
| 281 |
parameters=HF_JOBS_TOOL_SPEC["parameters"],
|
| 282 |
handler=hf_jobs_handler,
|
| 283 |
),
|
|
|
|
| 284 |
ToolSpec(
|
| 285 |
-
name=
|
| 286 |
-
description=
|
| 287 |
-
parameters=
|
| 288 |
-
handler=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
),
|
| 290 |
-
# NOTE: Utils tool disabled - date/time now loaded into system prompt at initialization (less tool calls=more reliablity)
|
| 291 |
-
# ToolSpec(
|
| 292 |
-
# name=UTILS_TOOL_SPEC["name"],
|
| 293 |
-
# description=UTILS_TOOL_SPEC["description"],
|
| 294 |
-
# parameters=UTILS_TOOL_SPEC["parameters"],
|
| 295 |
-
# handler=utils_handler,
|
| 296 |
-
# ),
|
| 297 |
-
# GitHub tools
|
| 298 |
# NOTE: Github search code tool disabled - a bit buggy
|
| 299 |
# ToolSpec(
|
| 300 |
# name=GITHUB_SEARCH_CODE_TOOL_SPEC["name"],
|
|
|
|
| 35 |
GITHUB_READ_FILE_TOOL_SPEC,
|
| 36 |
github_read_file_handler,
|
| 37 |
)
|
| 38 |
+
from agent.tools.hf_repo_files_tool import (
|
| 39 |
+
HF_REPO_FILES_TOOL_SPEC,
|
| 40 |
+
hf_repo_files_handler,
|
| 41 |
+
)
|
| 42 |
+
from agent.tools.hf_repo_git_tool import (
|
| 43 |
+
HF_REPO_GIT_TOOL_SPEC,
|
| 44 |
+
hf_repo_git_handler,
|
| 45 |
+
)
|
| 46 |
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
|
| 47 |
from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
+
# NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git
|
| 50 |
+
# from agent.tools.private_hf_repo_tools import (
|
| 51 |
+
# PRIVATE_HF_REPO_TOOL_SPEC,
|
| 52 |
+
# private_hf_repo_handler,
|
| 53 |
+
# )
|
| 54 |
|
| 55 |
# Suppress aiohttp deprecation warning
|
| 56 |
warnings.filterwarnings(
|
| 57 |
"ignore", category=DeprecationWarning, module="aiohttp.connector"
|
| 58 |
)
|
| 59 |
|
| 60 |
+
NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch", "hf_whoami"]
|
| 61 |
|
| 62 |
|
| 63 |
def convert_mcp_content_to_string(content: list) -> str:
|
|
|
|
| 288 |
parameters=HF_JOBS_TOOL_SPEC["parameters"],
|
| 289 |
handler=hf_jobs_handler,
|
| 290 |
),
|
| 291 |
+
# HF Repo management tools
|
| 292 |
ToolSpec(
|
| 293 |
+
name=HF_REPO_FILES_TOOL_SPEC["name"],
|
| 294 |
+
description=HF_REPO_FILES_TOOL_SPEC["description"],
|
| 295 |
+
parameters=HF_REPO_FILES_TOOL_SPEC["parameters"],
|
| 296 |
+
handler=hf_repo_files_handler,
|
| 297 |
+
),
|
| 298 |
+
ToolSpec(
|
| 299 |
+
name=HF_REPO_GIT_TOOL_SPEC["name"],
|
| 300 |
+
description=HF_REPO_GIT_TOOL_SPEC["description"],
|
| 301 |
+
parameters=HF_REPO_GIT_TOOL_SPEC["parameters"],
|
| 302 |
+
handler=hf_repo_git_handler,
|
| 303 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
# NOTE: Github search code tool disabled - a bit buggy
|
| 305 |
# ToolSpec(
|
| 306 |
# name=GITHUB_SEARCH_CODE_TOOL_SPEC["name"],
|
agent/main.py
CHANGED
|
@@ -287,6 +287,95 @@ async def event_listener(
|
|
| 287 |
if len(all_lines) > 5:
|
| 288 |
print("...")
|
| 289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
# Get user decision for this item
|
| 291 |
response = await prompt_session.prompt_async(
|
| 292 |
f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): "
|
|
|
|
| 287 |
if len(all_lines) > 5:
|
| 288 |
print("...")
|
| 289 |
|
| 290 |
+
elif tool_name == "hf_repo_files":
|
| 291 |
+
# Handle repo files operations (upload, delete)
|
| 292 |
+
repo_id = arguments.get("repo_id", "")
|
| 293 |
+
repo_type = arguments.get("repo_type", "model")
|
| 294 |
+
revision = arguments.get("revision", "main")
|
| 295 |
+
|
| 296 |
+
# Build repo URL
|
| 297 |
+
if repo_type == "model":
|
| 298 |
+
repo_url = f"https://huggingface.co/{repo_id}"
|
| 299 |
+
else:
|
| 300 |
+
repo_url = f"https://huggingface.co/{repo_type}s/{repo_id}"
|
| 301 |
+
|
| 302 |
+
print(f"Repository: {repo_id}")
|
| 303 |
+
print(f"Type: {repo_type}")
|
| 304 |
+
print(f"Branch: {revision}")
|
| 305 |
+
print(f"URL: {repo_url}")
|
| 306 |
+
|
| 307 |
+
if operation == "upload":
|
| 308 |
+
path = arguments.get("path", "")
|
| 309 |
+
content = arguments.get("content", "")
|
| 310 |
+
create_pr = arguments.get("create_pr", False)
|
| 311 |
+
|
| 312 |
+
print(f"File: {path}")
|
| 313 |
+
if create_pr:
|
| 314 |
+
print("Mode: Create PR")
|
| 315 |
+
|
| 316 |
+
if isinstance(content, str):
|
| 317 |
+
all_lines = content.split("\n")
|
| 318 |
+
line_count = len(all_lines)
|
| 319 |
+
size_bytes = len(content.encode("utf-8"))
|
| 320 |
+
size_kb = size_bytes / 1024
|
| 321 |
+
|
| 322 |
+
print(f"Lines: {line_count}")
|
| 323 |
+
if size_kb < 1024:
|
| 324 |
+
print(f"Size: {size_kb:.2f} KB")
|
| 325 |
+
else:
|
| 326 |
+
print(f"Size: {size_kb / 1024:.2f} MB")
|
| 327 |
+
|
| 328 |
+
# Show full content
|
| 329 |
+
print(f"Content:\n{content}")
|
| 330 |
+
|
| 331 |
+
elif operation == "delete":
|
| 332 |
+
patterns = arguments.get("patterns", [])
|
| 333 |
+
if isinstance(patterns, str):
|
| 334 |
+
patterns = [patterns]
|
| 335 |
+
print(f"Patterns to delete: {', '.join(patterns)}")
|
| 336 |
+
|
| 337 |
+
elif tool_name == "hf_repo_git":
|
| 338 |
+
# Handle git operations (branches, tags, PRs, repo management)
|
| 339 |
+
repo_id = arguments.get("repo_id", "")
|
| 340 |
+
repo_type = arguments.get("repo_type", "model")
|
| 341 |
+
|
| 342 |
+
# Build repo URL
|
| 343 |
+
if repo_type == "model":
|
| 344 |
+
repo_url = f"https://huggingface.co/{repo_id}"
|
| 345 |
+
else:
|
| 346 |
+
repo_url = f"https://huggingface.co/{repo_type}s/{repo_id}"
|
| 347 |
+
|
| 348 |
+
print(f"Repository: {repo_id}")
|
| 349 |
+
print(f"Type: {repo_type}")
|
| 350 |
+
print(f"URL: {repo_url}")
|
| 351 |
+
|
| 352 |
+
if operation == "delete_branch":
|
| 353 |
+
branch = arguments.get("branch", "")
|
| 354 |
+
print(f"Branch to delete: {branch}")
|
| 355 |
+
|
| 356 |
+
elif operation == "delete_tag":
|
| 357 |
+
tag = arguments.get("tag", "")
|
| 358 |
+
print(f"Tag to delete: {tag}")
|
| 359 |
+
|
| 360 |
+
elif operation == "merge_pr":
|
| 361 |
+
pr_num = arguments.get("pr_num", "")
|
| 362 |
+
print(f"PR to merge: #{pr_num}")
|
| 363 |
+
|
| 364 |
+
elif operation == "create_repo":
|
| 365 |
+
private = arguments.get("private", False)
|
| 366 |
+
space_sdk = arguments.get("space_sdk")
|
| 367 |
+
print(f"Private: {private}")
|
| 368 |
+
if space_sdk:
|
| 369 |
+
print(f"Space SDK: {space_sdk}")
|
| 370 |
+
|
| 371 |
+
elif operation == "update_repo":
|
| 372 |
+
private = arguments.get("private")
|
| 373 |
+
gated = arguments.get("gated")
|
| 374 |
+
if private is not None:
|
| 375 |
+
print(f"Private: {private}")
|
| 376 |
+
if gated is not None:
|
| 377 |
+
print(f"Gated: {gated}")
|
| 378 |
+
|
| 379 |
# Get user decision for this item
|
| 380 |
response = await prompt_session.prompt_async(
|
| 381 |
f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): "
|
agent/prompts/system_prompt_v2.yaml
CHANGED
|
@@ -2,6 +2,7 @@ system_prompt: |
|
|
| 2 |
You are Hugging Face Agent, a skilled AI assistant for machine learning engineering with deep expertise in the Hugging Face ecosystem. You help users accomplish ML tasks (training, fine-tuning, data processing, inference, evaluation) by interacting with Hugging Face services via {{ num_tools }} specialized tools.
|
| 3 |
|
| 4 |
_Current Time: **{{ current_date }} {{ current_time }} ({{ current_timezone }})**_
|
|
|
|
| 5 |
|
| 6 |
# Core Mission & Behavior
|
| 7 |
|
|
@@ -330,11 +331,6 @@ system_prompt: |
|
|
| 330 |
- Check model size, architecture, requirements
|
| 331 |
- Verify dataset columns, splits, size
|
| 332 |
|
| 333 |
-
**hf_whoami:**
|
| 334 |
-
- Check authentication status
|
| 335 |
-
- Verify token has correct permissions
|
| 336 |
-
- Use before operations requiring write access
|
| 337 |
-
|
| 338 |
## Execution & Storage Tools
|
| 339 |
|
| 340 |
**hf_jobs:**
|
|
@@ -456,8 +452,6 @@ system_prompt: |
|
|
| 456 |
hub_model_id="username/model-name", # ← Must be set
|
| 457 |
# ...
|
| 458 |
)
|
| 459 |
-
|
| 460 |
-
# Verify token: hf_whoami()
|
| 461 |
```
|
| 462 |
|
| 463 |
### Dataset Format Mismatch
|
|
|
|
| 2 |
You are Hugging Face Agent, a skilled AI assistant for machine learning engineering with deep expertise in the Hugging Face ecosystem. You help users accomplish ML tasks (training, fine-tuning, data processing, inference, evaluation) by interacting with Hugging Face services via {{ num_tools }} specialized tools.
|
| 3 |
|
| 4 |
_Current Time: **{{ current_date }} {{ current_time }} ({{ current_timezone }})**_
|
| 5 |
+
{% if hf_user_info %}_AUTHENTICATED ON HF AS: **{{ hf_user_info }}**_{% endif %}
|
| 6 |
|
| 7 |
# Core Mission & Behavior
|
| 8 |
|
|
|
|
| 331 |
- Check model size, architecture, requirements
|
| 332 |
- Verify dataset columns, splits, size
|
| 333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
## Execution & Storage Tools
|
| 335 |
|
| 336 |
**hf_jobs:**
|
|
|
|
| 452 |
hub_model_id="username/model-name", # ← Must be set
|
| 453 |
# ...
|
| 454 |
)
|
|
|
|
|
|
|
| 455 |
```
|
| 456 |
|
| 457 |
### Dataset Format Mismatch
|
agent/tools/docs_tools.py
CHANGED
|
@@ -1,289 +1,474 @@
|
|
| 1 |
"""
|
| 2 |
-
Documentation search tools for
|
| 3 |
-
Tools for exploring and fetching HuggingFace documentation and API specifications
|
| 4 |
"""
|
| 5 |
|
| 6 |
import asyncio
|
|
|
|
| 7 |
import os
|
| 8 |
from typing import Any
|
| 9 |
|
| 10 |
import httpx
|
| 11 |
from bs4 import BeautifulSoup
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
|
|
|
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
url = f"{base_url}/{endpoint}"
|
| 21 |
-
headers = {"Authorization": f"Bearer {hf_token}"}
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
return response.text
|
| 28 |
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
def _parse_sidebar_navigation(html_content: str) -> list[dict[str, str]]:
|
| 31 |
-
"""Parse the sidebar navigation and extract all links"""
|
| 32 |
-
soup = BeautifulSoup(html_content, "html.parser")
|
| 33 |
-
sidebar = soup.find("nav", class_=lambda x: x and "flex-auto" in x)
|
| 34 |
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
-
async def
|
| 53 |
-
|
| 54 |
-
) -> dict[str, str]:
|
| 55 |
-
"""
|
| 56 |
-
|
| 57 |
-
headers = {"Authorization": f"Bearer {hf_token}"}
|
| 58 |
|
| 59 |
try:
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
def _format_exploration_results(
|
| 96 |
-
endpoint: str, result_items: list[dict[str, str]]
|
| 97 |
-
) -> str:
|
| 98 |
-
"""Format the exploration results as a readable string"""
|
| 99 |
-
base_url = "https://huggingface.co/docs"
|
| 100 |
-
url = f"{base_url}/{endpoint}"
|
| 101 |
-
result = f"Documentation structure for: {url}\n\n"
|
| 102 |
-
result += f"Found {len(result_items)} pages:\n\n"
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
result += f" Glimpse: {item['glimpse']}\n\n"
|
| 108 |
|
| 109 |
-
return result
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
# Fetch HTML page
|
| 115 |
-
html_content = await _fetch_html_page(hf_token, endpoint)
|
| 116 |
|
| 117 |
-
#
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
-
|
| 133 |
-
"""
|
| 134 |
-
Explore the documentation structure for a given endpoint by parsing the sidebar navigation
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
-
Returns:
|
| 140 |
-
Tuple of (structured_navigation_with_glimpses, success)
|
| 141 |
-
"""
|
| 142 |
-
endpoint = arguments.get("endpoint", "")
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
-
# Get HF token from environment
|
| 148 |
hf_token = os.environ.get("HF_TOKEN")
|
| 149 |
-
|
| 150 |
if not hf_token:
|
| 151 |
return "Error: HF_TOKEN environment variable not set", False
|
| 152 |
|
| 153 |
-
|
|
|
|
| 154 |
|
| 155 |
try:
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
| 159 |
except httpx.HTTPStatusError as e:
|
| 160 |
return (
|
| 161 |
-
f"HTTP error: {e.response.status_code} - {e.response.text[:200]}",
|
| 162 |
False,
|
| 163 |
)
|
| 164 |
except httpx.RequestError as e:
|
| 165 |
-
return f"Request error: {str(e)}", False
|
| 166 |
-
except ValueError as e:
|
| 167 |
-
return f"Error: {str(e)}", False
|
| 168 |
except Exception as e:
|
| 169 |
-
return f"
|
| 170 |
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
|
| 176 |
-
if _openapi_spec_cache is not None:
|
| 177 |
-
return _openapi_spec_cache
|
| 178 |
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
| 182 |
-
|
| 183 |
-
|
| 184 |
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
return spec
|
| 189 |
|
| 190 |
|
| 191 |
def _extract_all_tags(spec: dict[str, Any]) -> list[str]:
|
| 192 |
-
"""Extract all unique tags from
|
| 193 |
tags = set()
|
| 194 |
-
|
| 195 |
-
# Get tags from the tags section
|
| 196 |
for tag_obj in spec.get("tags", []):
|
| 197 |
if "name" in tag_obj:
|
| 198 |
tags.add(tag_obj["name"])
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
for path, path_item in spec.get("paths", {}).items():
|
| 202 |
-
for method, operation in path_item.items():
|
| 203 |
if method in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
| 204 |
-
for tag in
|
| 205 |
tags.add(tag)
|
| 206 |
-
|
| 207 |
-
return sorted(list(tags))
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
def _search_openapi_by_tag(spec: dict[str, Any], tag: str) -> list[dict[str, Any]]:
|
| 211 |
-
"""Search for API endpoints with a specific tag"""
|
| 212 |
-
results = []
|
| 213 |
-
paths = spec.get("paths", {})
|
| 214 |
-
servers = spec.get("servers", [])
|
| 215 |
-
base_url = (
|
| 216 |
-
servers[0].get("url", "https://huggingface.co")
|
| 217 |
-
if servers
|
| 218 |
-
else "https://huggingface.co"
|
| 219 |
-
)
|
| 220 |
-
|
| 221 |
-
for path, path_item in paths.items():
|
| 222 |
-
for method, operation in path_item.items():
|
| 223 |
-
if method not in [
|
| 224 |
-
"get",
|
| 225 |
-
"post",
|
| 226 |
-
"put",
|
| 227 |
-
"delete",
|
| 228 |
-
"patch",
|
| 229 |
-
"head",
|
| 230 |
-
"options",
|
| 231 |
-
]:
|
| 232 |
-
continue
|
| 233 |
-
|
| 234 |
-
operation_tags = operation.get("tags", [])
|
| 235 |
-
if tag in operation_tags:
|
| 236 |
-
# Extract parameters
|
| 237 |
-
parameters = operation.get("parameters", [])
|
| 238 |
-
request_body = operation.get("requestBody", {})
|
| 239 |
-
responses = operation.get("responses", {})
|
| 240 |
-
|
| 241 |
-
results.append(
|
| 242 |
-
{
|
| 243 |
-
"path": path,
|
| 244 |
-
"method": method.upper(),
|
| 245 |
-
"operationId": operation.get("operationId", ""),
|
| 246 |
-
"summary": operation.get("summary", ""),
|
| 247 |
-
"description": operation.get("description", ""),
|
| 248 |
-
"parameters": parameters,
|
| 249 |
-
"request_body": request_body,
|
| 250 |
-
"responses": responses,
|
| 251 |
-
"base_url": base_url,
|
| 252 |
-
}
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
return results
|
| 256 |
|
| 257 |
|
| 258 |
def _generate_curl_example(endpoint: dict[str, Any]) -> str:
|
| 259 |
-
"""Generate
|
| 260 |
method = endpoint["method"]
|
| 261 |
path = endpoint["path"]
|
| 262 |
base_url = endpoint["base_url"]
|
| 263 |
|
| 264 |
-
# Build
|
| 265 |
full_path = path
|
| 266 |
for param in endpoint.get("parameters", []):
|
| 267 |
if param.get("in") == "path" and param.get("required"):
|
| 268 |
-
|
| 269 |
example = param.get(
|
| 270 |
-
"example", param.get("schema", {}).get("example", f"<{
|
| 271 |
)
|
| 272 |
-
full_path = full_path.replace(f"{{{
|
| 273 |
|
| 274 |
curl = f"curl -X {method} \\\n '{base_url}{full_path}'"
|
| 275 |
|
| 276 |
-
# Add query parameters
|
| 277 |
query_params = [p for p in endpoint.get("parameters", []) if p.get("in") == "query"]
|
| 278 |
if query_params and query_params[0].get("required"):
|
| 279 |
param = query_params[0]
|
| 280 |
example = param.get("example", param.get("schema", {}).get("example", "value"))
|
| 281 |
curl += f"?{param['name']}={example}"
|
| 282 |
|
| 283 |
-
# Add headers
|
| 284 |
curl += " \\\n -H 'Authorization: Bearer $HF_TOKEN'"
|
| 285 |
|
| 286 |
-
# Add request body
|
| 287 |
if method in ["POST", "PUT", "PATCH"] and endpoint.get("request_body"):
|
| 288 |
content = endpoint["request_body"].get("content", {})
|
| 289 |
if "application/json" in content:
|
|
@@ -291,8 +476,6 @@ def _generate_curl_example(endpoint: dict[str, Any]) -> str:
|
|
| 291 |
schema = content["application/json"].get("schema", {})
|
| 292 |
example = schema.get("example", "{}")
|
| 293 |
if isinstance(example, dict):
|
| 294 |
-
import json
|
| 295 |
-
|
| 296 |
example = json.dumps(example, indent=2)
|
| 297 |
curl += f" \\\n -d '{example}'"
|
| 298 |
|
|
@@ -300,72 +483,50 @@ def _generate_curl_example(endpoint: dict[str, Any]) -> str:
|
|
| 300 |
|
| 301 |
|
| 302 |
def _format_parameters(parameters: list[dict[str, Any]]) -> str:
|
| 303 |
-
"""Format parameter information from OpenAPI spec"""
|
| 304 |
if not parameters:
|
| 305 |
return ""
|
| 306 |
|
| 307 |
-
# Group parameters by type
|
| 308 |
path_params = [p for p in parameters if p.get("in") == "path"]
|
| 309 |
query_params = [p for p in parameters if p.get("in") == "query"]
|
| 310 |
header_params = [p for p in parameters if p.get("in") == "header"]
|
| 311 |
|
| 312 |
output = []
|
| 313 |
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
example = param.get("example") or param.get("schema", {}).get("example", "")
|
| 322 |
-
|
| 323 |
-
output.append(f"- `{name}` ({param_type}){required}: {description}")
|
| 324 |
-
if example:
|
| 325 |
-
output.append(f" Example: `{example}`")
|
| 326 |
-
|
| 327 |
-
if query_params:
|
| 328 |
if output:
|
| 329 |
output.append("")
|
| 330 |
-
output.append("**
|
| 331 |
-
for
|
| 332 |
-
name =
|
| 333 |
-
required = " (required)" if
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
example =
|
| 337 |
-
|
| 338 |
-
output.append(f"- `{name}` ({
|
| 339 |
if example:
|
| 340 |
output.append(f" Example: `{example}`")
|
| 341 |
|
| 342 |
-
if header_params:
|
| 343 |
-
if output:
|
| 344 |
-
output.append("")
|
| 345 |
-
output.append("**Header Parameters:**")
|
| 346 |
-
for param in header_params:
|
| 347 |
-
name = param.get("name", "")
|
| 348 |
-
required = " (required)" if param.get("required") else " (optional)"
|
| 349 |
-
description = param.get("description", "")
|
| 350 |
-
|
| 351 |
-
output.append(f"- `{name}`{required}: {description}")
|
| 352 |
-
|
| 353 |
return "\n".join(output)
|
| 354 |
|
| 355 |
|
| 356 |
def _format_response_info(responses: dict[str, Any]) -> str:
|
| 357 |
-
"""Format response information from OpenAPI spec"""
|
| 358 |
if not responses:
|
| 359 |
return "No response information available"
|
| 360 |
|
| 361 |
output = []
|
| 362 |
-
for
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
output.append(f"- **{status_code}**: {desc}")
|
| 367 |
-
|
| 368 |
-
content = response_obj.get("content", {})
|
| 369 |
if "application/json" in content:
|
| 370 |
schema = content["application/json"].get("schema", {})
|
| 371 |
if "type" in schema:
|
|
@@ -375,72 +536,87 @@ def _format_response_info(responses: dict[str, Any]) -> str:
|
|
| 375 |
|
| 376 |
|
| 377 |
def _format_openapi_results(results: list[dict[str, Any]], tag: str) -> str:
|
| 378 |
-
"""Format OpenAPI search results
|
| 379 |
if not results:
|
| 380 |
return f"No API endpoints found with tag '{tag}'"
|
| 381 |
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
output += "---\n\n"
|
| 385 |
|
| 386 |
-
for i,
|
| 387 |
-
|
| 388 |
|
| 389 |
-
if
|
| 390 |
-
|
| 391 |
|
| 392 |
-
if
|
| 393 |
-
desc =
|
| 394 |
-
if len(
|
| 395 |
desc += "..."
|
| 396 |
-
|
| 397 |
|
| 398 |
-
|
| 399 |
-
params_info = _format_parameters(endpoint.get("parameters", []))
|
| 400 |
if params_info:
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
# Curl example
|
| 404 |
-
output += "**Usage:**\n```bash\n"
|
| 405 |
-
output += _generate_curl_example(endpoint)
|
| 406 |
-
output += "\n```\n\n"
|
| 407 |
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
output += "\n\n"
|
| 412 |
|
| 413 |
-
|
|
|
|
|
|
|
| 414 |
|
| 415 |
-
return
|
| 416 |
|
| 417 |
|
| 418 |
async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
|
| 419 |
-
"""
|
| 420 |
-
Search the HuggingFace OpenAPI specification by tag
|
| 421 |
-
|
| 422 |
-
Args:
|
| 423 |
-
arguments: Dictionary with 'tag' parameter
|
| 424 |
-
|
| 425 |
-
Returns:
|
| 426 |
-
Tuple of (search_results, success)
|
| 427 |
-
"""
|
| 428 |
tag = arguments.get("tag", "")
|
| 429 |
-
|
| 430 |
if not tag:
|
| 431 |
return "Error: No tag provided", False
|
| 432 |
|
| 433 |
try:
|
| 434 |
-
# Fetch OpenAPI spec (cached after first fetch)
|
| 435 |
spec = await _fetch_openapi_spec()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
|
| 437 |
-
|
| 438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
|
| 440 |
-
|
| 441 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
|
| 443 |
-
return
|
| 444 |
|
| 445 |
except httpx.HTTPStatusError as e:
|
| 446 |
return f"HTTP error fetching OpenAPI spec: {e.response.status_code}", False
|
|
@@ -450,66 +626,86 @@ async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
|
|
| 450 |
return f"Error searching OpenAPI spec: {str(e)}", False
|
| 451 |
|
| 452 |
|
| 453 |
-
async def
|
| 454 |
-
"""
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
Args:
|
| 458 |
-
arguments: Dictionary with 'url' parameter (full URL to the doc page)
|
| 459 |
-
|
| 460 |
-
Returns:
|
| 461 |
-
Tuple of (full_markdown_content, success)
|
| 462 |
-
"""
|
| 463 |
-
url = arguments.get("url", "")
|
| 464 |
-
|
| 465 |
-
if not url:
|
| 466 |
-
return "Error: No URL provided", False
|
| 467 |
-
|
| 468 |
-
# Get HF token from environment
|
| 469 |
-
hf_token = os.environ.get("HF_TOKEN")
|
| 470 |
-
|
| 471 |
-
if not hf_token:
|
| 472 |
-
return (
|
| 473 |
-
"Error: HF_TOKEN environment variable not set",
|
| 474 |
-
False,
|
| 475 |
-
)
|
| 476 |
-
|
| 477 |
-
# Add .md extension if not already present
|
| 478 |
-
if not url.endswith(".md"):
|
| 479 |
-
url = f"{url}.md"
|
| 480 |
-
|
| 481 |
-
try:
|
| 482 |
-
# Make request with auth
|
| 483 |
-
headers = {"Authorization": f"Bearer {hf_token}"}
|
| 484 |
-
|
| 485 |
-
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
| 486 |
-
response = await client.get(url, headers=headers)
|
| 487 |
-
response.raise_for_status()
|
| 488 |
-
|
| 489 |
-
content = response.text
|
| 490 |
-
|
| 491 |
-
# Return the markdown content directly
|
| 492 |
-
result = f"Documentation from: {url}\n\n{content}"
|
| 493 |
-
|
| 494 |
-
return result, True
|
| 495 |
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
|
| 506 |
|
| 507 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
| 509 |
EXPLORE_HF_DOCS_TOOL_SPEC = {
|
| 510 |
"name": "explore_hf_docs",
|
| 511 |
"description": (
|
| 512 |
-
"Explore Hugging Face documentation structure and discover available pages with
|
| 513 |
"⚠️ MANDATORY: ALWAYS use this BEFORE implementing any ML task (training, fine-tuning, data processing, inference). "
|
| 514 |
"Your training data may be outdated - current documentation is the source of truth. "
|
| 515 |
"**Use when:** (1) Starting any implementation task, (2) User asks 'how to' questions, "
|
|
@@ -519,77 +715,22 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
|
|
| 519 |
"Returns: Sidebar navigation with titles, URLs, and glimpses of all pages in the selected documentation. "
|
| 520 |
"**Then:** Use fetch_hf_docs with specific URLs from results to get full content. "
|
| 521 |
"**Critical for reliability:** Never implement based on internal knowledge without checking current docs first - APIs change frequently."
|
|
|
|
| 522 |
),
|
| 523 |
"parameters": {
|
| 524 |
"type": "object",
|
| 525 |
"properties": {
|
| 526 |
"endpoint": {
|
| 527 |
"type": "string",
|
| 528 |
-
"enum":
|
| 529 |
-
"hub",
|
| 530 |
-
"transformers",
|
| 531 |
-
"diffusers",
|
| 532 |
-
"datasets",
|
| 533 |
-
"gradio",
|
| 534 |
-
"trackio",
|
| 535 |
-
"smolagents",
|
| 536 |
-
"huggingface_hub",
|
| 537 |
-
"huggingface.js",
|
| 538 |
-
"transformers.js",
|
| 539 |
-
"inference-providers",
|
| 540 |
-
"inference-endpoints",
|
| 541 |
-
"peft",
|
| 542 |
-
"accelerate",
|
| 543 |
-
"optimum",
|
| 544 |
-
"optimum-habana",
|
| 545 |
-
"optimum-neuron",
|
| 546 |
-
"optimum-intel",
|
| 547 |
-
"optimum-executorch",
|
| 548 |
-
"optimum-tpu",
|
| 549 |
-
"tokenizers",
|
| 550 |
-
"llm-course",
|
| 551 |
-
"robotics-course",
|
| 552 |
-
"mcp-course",
|
| 553 |
-
"smol-course",
|
| 554 |
-
"agents-course",
|
| 555 |
-
"deep-rl-course",
|
| 556 |
-
"computer-vision-course",
|
| 557 |
-
"evaluate",
|
| 558 |
-
"tasks",
|
| 559 |
-
"dataset-viewer",
|
| 560 |
-
"trl",
|
| 561 |
-
"simulate",
|
| 562 |
-
"sagemaker",
|
| 563 |
-
"timm",
|
| 564 |
-
"safetensors",
|
| 565 |
-
"tgi",
|
| 566 |
-
"setfit",
|
| 567 |
-
"audio-course",
|
| 568 |
-
"lerobot",
|
| 569 |
-
"autotrain",
|
| 570 |
-
"tei",
|
| 571 |
-
"bitsandbytes",
|
| 572 |
-
"cookbook",
|
| 573 |
-
"sentence_transformers",
|
| 574 |
-
"ml-games-course",
|
| 575 |
-
"diffusion-course",
|
| 576 |
-
"ml-for-3d-course",
|
| 577 |
-
"chat-ui",
|
| 578 |
-
"leaderboards",
|
| 579 |
-
"lighteval",
|
| 580 |
-
"argilla",
|
| 581 |
-
"distilabel",
|
| 582 |
-
"microsoft-azure",
|
| 583 |
-
"kernels",
|
| 584 |
-
"google-cloud",
|
| 585 |
-
],
|
| 586 |
"description": (
|
| 587 |
"The documentation endpoint to explore. Each endpoint corresponds to a major section of the Hugging Face documentation:\n\n"
|
|
|
|
| 588 |
"• hub — Find answers to questions about models/datasets/spaces, auth, versioning, metadata.\n"
|
| 589 |
"• transformers — Core model library: architectures, configs, tokenizers, training & inference APIs.\n"
|
| 590 |
"• diffusers — Diffusion pipelines, schedulers, fine-tuning, training, and deployment patterns.\n"
|
| 591 |
"• datasets — Dataset loading, streaming, processing, Arrow format, Hub integration.\n"
|
| 592 |
-
"• gradio — UI components and demos for
|
| 593 |
"• trackio — Experiment tracking, metrics logging, and run comparison.\n"
|
| 594 |
"• smolagents — Lightweight agent abstractions and tool-using patterns.\n"
|
| 595 |
"• huggingface_hub — Python client for Hub operations (auth, upload/download, repo management).\n"
|
|
@@ -599,20 +740,8 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
|
|
| 599 |
"• inference-endpoints — Managed, scalable model deployments on HF infrastructure.\n"
|
| 600 |
"• peft — Parameter-efficient fine-tuning methods (LoRA, adapters, etc.).\n"
|
| 601 |
"• accelerate — Hardware-agnostic, distributed and mixed-precision training orchestration.\n"
|
| 602 |
-
"• optimum — Hardware-aware optimization and model export tooling.\n"
|
| 603 |
-
"• optimum-habana — Training and inference on Habana Gaudi accelerators.\n"
|
| 604 |
-
"• optimum-neuron — Optimization workflows for AWS Inferentia/Trainium.\n"
|
| 605 |
-
"• optimum-intel — Intel CPU/GPU optimizations (OpenVINO, IPEX).\n"
|
| 606 |
-
"• optimum-executorch — Exporting models to ExecuTorch for edge/mobile.\n"
|
| 607 |
-
"• optimum-tpu — TPU-specific training and optimization paths.\n"
|
| 608 |
"• tokenizers — Fast tokenizer internals, training, and low-level APIs.\n"
|
| 609 |
-
"• llm-course — End-to-end LLM concepts, training, and deployment.\n"
|
| 610 |
-
"• robotics-course — Learning-based robotics foundations.\n"
|
| 611 |
-
"• mcp-course — Model Context Protocol concepts and usage.\n"
|
| 612 |
-
"• smol-course — Small-model and efficiency-focused workflows.\n"
|
| 613 |
-
"• agents-course — Tool-using, planning, and multi-step agent design.\n"
|
| 614 |
-
"• deep-rl-course — Deep reinforcement learning foundations.\n"
|
| 615 |
-
"• computer-vision-course — Vision models, datasets, and pipelines.\n"
|
| 616 |
"• evaluate — Metrics, evaluation workflows, and training-loop integration.\n"
|
| 617 |
"• tasks — Canonical task definitions and model categorization.\n"
|
| 618 |
"• dataset-viewer — Dataset preview, streaming views, and viewer internals.\n"
|
|
@@ -623,16 +752,11 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
|
|
| 623 |
"• safetensors — Safe, fast tensor serialization format.\n"
|
| 624 |
"• tgi — High-throughput text generation server for LLMs.\n"
|
| 625 |
"• setfit — Few-shot text classification via sentence embeddings.\n"
|
| 626 |
-
"• audio-course — Speech and audio models, datasets, and tasks.\n"
|
| 627 |
"• lerobot — Robotics datasets, policies, and learning workflows.\n"
|
| 628 |
"• autotrain — No/low-code model training on Hugging Face.\n"
|
| 629 |
"• tei — Optimized inference server for embedding workloads.\n"
|
| 630 |
"• bitsandbytes — Quantization and memory-efficient optimizers.\n"
|
| 631 |
-
"• cookbook — Practical, task-oriented recipes across the ecosystem.\n"
|
| 632 |
"• sentence_transformers — Embedding models, training recipes, similarity/search workflows.\n"
|
| 633 |
-
"• ml-games-course — Game-based ML and reinforcement learning experiments.\n"
|
| 634 |
-
"• diffusion-course — Diffusion model theory and hands-on practice.\n"
|
| 635 |
-
"• ml-for-3d-course — 3D representations, models, and learning techniques.\n"
|
| 636 |
"• chat-ui — Reference chat interfaces for LLM deployment.\n"
|
| 637 |
"• leaderboards — Evaluation leaderboards and submission mechanics.\n"
|
| 638 |
"• lighteval — Lightweight, reproducible LLM evaluation framework.\n"
|
|
@@ -643,6 +767,19 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
|
|
| 643 |
"• google-cloud — GCP deployment and serving workflows.\n"
|
| 644 |
),
|
| 645 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
},
|
| 647 |
"required": ["endpoint"],
|
| 648 |
},
|
|
@@ -677,40 +814,3 @@ HF_DOCS_FETCH_TOOL_SPEC = {
|
|
| 677 |
"required": ["url"],
|
| 678 |
},
|
| 679 |
}
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
async def _get_api_search_tool_spec() -> dict[str, Any]:
|
| 683 |
-
"""
|
| 684 |
-
Dynamically generate the OpenAPI tool spec with tag enum populated at runtime
|
| 685 |
-
This must be called async to fetch the OpenAPI spec and extract tags
|
| 686 |
-
"""
|
| 687 |
-
spec = await _fetch_openapi_spec()
|
| 688 |
-
tags = _extract_all_tags(spec)
|
| 689 |
-
|
| 690 |
-
return {
|
| 691 |
-
"name": "search_hf_api_endpoints",
|
| 692 |
-
"description": (
|
| 693 |
-
"Search HuggingFace OpenAPI specification by tag to find API endpoints with curl examples. "
|
| 694 |
-
"**Use when:** (1) Need to interact with HF Hub API directly, (2) Building scripts for repo operations, "
|
| 695 |
-
"(3) Need authentication patterns, (4) Understanding API parameters and responses, "
|
| 696 |
-
"(5) Need curl examples for HTTP requests. "
|
| 697 |
-
"Returns: Endpoint paths, methods, parameters, curl examples with authentication, and response schemas. "
|
| 698 |
-
"**Pattern:** search_hf_api_endpoints (find endpoint) → use curl pattern in implementation. "
|
| 699 |
-
"Tags group related operations: repos, models, datasets, inference, spaces, etc. "
|
| 700 |
-
"**Note:** Each result includes curl example with $HF_TOKEN placeholder for authentication. "
|
| 701 |
-
"**For tool building:** This provides the API foundation for creating Hub interaction scripts."
|
| 702 |
-
),
|
| 703 |
-
"parameters": {
|
| 704 |
-
"type": "object",
|
| 705 |
-
"properties": {
|
| 706 |
-
"tag": {
|
| 707 |
-
"type": "string",
|
| 708 |
-
"enum": tags,
|
| 709 |
-
"description": (
|
| 710 |
-
"The API tag to search for. Each tag groups related API endpoints. "
|
| 711 |
-
),
|
| 712 |
-
},
|
| 713 |
-
},
|
| 714 |
-
"required": ["tag"],
|
| 715 |
-
},
|
| 716 |
-
}
|
|
|
|
| 1 |
"""
|
| 2 |
+
Documentation search tools for exploring HuggingFace and Gradio documentation.
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import asyncio
|
| 6 |
+
import json
|
| 7 |
import os
|
| 8 |
from typing import Any
|
| 9 |
|
| 10 |
import httpx
|
| 11 |
from bs4 import BeautifulSoup
|
| 12 |
+
from whoosh.analysis import StemmingAnalyzer
|
| 13 |
+
from whoosh.fields import ID, TEXT, Schema
|
| 14 |
+
from whoosh.filedb.filestore import RamStorage
|
| 15 |
+
from whoosh.qparser import MultifieldParser, OrGroup
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
# Configuration
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
DEFAULT_MAX_RESULTS = 20
|
| 22 |
+
MAX_RESULTS_CAP = 50
|
| 23 |
+
|
| 24 |
+
GRADIO_LLMS_TXT_URL = "https://gradio.app/llms.txt"
|
| 25 |
+
GRADIO_SEARCH_URL = "https://playground-worker.pages.dev/api/prompt"
|
| 26 |
+
|
| 27 |
+
COMPOSITE_ENDPOINTS: dict[str, list[str]] = {
|
| 28 |
+
"optimum": [
|
| 29 |
+
"optimum",
|
| 30 |
+
"optimum-habana",
|
| 31 |
+
"optimum-neuron",
|
| 32 |
+
"optimum-intel",
|
| 33 |
+
"optimum-executorch",
|
| 34 |
+
"optimum-tpu",
|
| 35 |
+
],
|
| 36 |
+
"courses": [
|
| 37 |
+
"llm-course",
|
| 38 |
+
"robotics-course",
|
| 39 |
+
"mcp-course",
|
| 40 |
+
"smol-course",
|
| 41 |
+
"agents-course",
|
| 42 |
+
"deep-rl-course",
|
| 43 |
+
"computer-vision-course",
|
| 44 |
+
"audio-course",
|
| 45 |
+
"ml-games-course",
|
| 46 |
+
"diffusion-course",
|
| 47 |
+
"ml-for-3d-course",
|
| 48 |
+
"cookbook",
|
| 49 |
+
],
|
| 50 |
+
}
|
| 51 |
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# Caches
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
|
| 56 |
+
_docs_cache: dict[str, list[dict[str, str]]] = {}
|
| 57 |
+
_index_cache: dict[str, tuple[Any, MultifieldParser]] = {}
|
| 58 |
+
_cache_lock = asyncio.Lock()
|
| 59 |
+
_openapi_cache: dict[str, Any] | None = None
|
| 60 |
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
# Gradio Documentation
|
| 63 |
+
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
| 64 |
|
| 65 |
+
|
| 66 |
+
async def _fetch_gradio_docs(query: str | None = None) -> str:
|
| 67 |
+
"""
|
| 68 |
+
Fetch Gradio documentation.
|
| 69 |
+
Without query: Get full documentation from llms.txt
|
| 70 |
+
With query: Run embedding search on guides/demos for relevant content
|
| 71 |
+
"""
|
| 72 |
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
| 73 |
+
if not query:
|
| 74 |
+
resp = await client.get(GRADIO_LLMS_TXT_URL)
|
| 75 |
+
resp.raise_for_status()
|
| 76 |
+
return resp.text
|
| 77 |
+
|
| 78 |
+
resp = await client.post(
|
| 79 |
+
GRADIO_SEARCH_URL,
|
| 80 |
+
headers={
|
| 81 |
+
"Content-Type": "application/json",
|
| 82 |
+
"Origin": "https://gradio-docs-mcp.up.railway.app",
|
| 83 |
+
},
|
| 84 |
+
json={
|
| 85 |
+
"prompt_to_embed": query,
|
| 86 |
+
"SYSTEM_PROMPT": "$INSERT_GUIDES_DOCS_DEMOS",
|
| 87 |
+
"FALLBACK_PROMPT": "No results found",
|
| 88 |
+
},
|
| 89 |
+
)
|
| 90 |
+
resp.raise_for_status()
|
| 91 |
+
return resp.json().get("SYS_PROMPT", "No results found")
|
| 92 |
|
|
|
|
| 93 |
|
| 94 |
+
# ---------------------------------------------------------------------------
|
| 95 |
+
# HF Documentation - Fetching
|
| 96 |
+
# ---------------------------------------------------------------------------
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
async def _fetch_endpoint_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]:
|
| 100 |
+
"""Fetch all docs for an endpoint by parsing sidebar and fetching each page."""
|
| 101 |
+
url = f"https://huggingface.co/docs/{endpoint}"
|
| 102 |
+
headers = {"Authorization": f"Bearer {hf_token}"}
|
| 103 |
|
| 104 |
+
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
| 105 |
+
resp = await client.get(url, headers=headers)
|
| 106 |
+
resp.raise_for_status()
|
| 107 |
+
|
| 108 |
+
soup = BeautifulSoup(resp.text, "html.parser")
|
| 109 |
+
sidebar = soup.find("nav", class_=lambda x: x and "flex-auto" in x)
|
| 110 |
+
if not sidebar:
|
| 111 |
+
raise ValueError(f"Could not find navigation sidebar for '{endpoint}'")
|
| 112 |
+
|
| 113 |
+
nav_items = []
|
| 114 |
+
for link in sidebar.find_all("a", href=True):
|
| 115 |
+
href = link["href"]
|
| 116 |
+
page_url = f"https://huggingface.co{href}" if href.startswith("/") else href
|
| 117 |
+
nav_items.append({"title": link.get_text(strip=True), "url": page_url})
|
| 118 |
+
|
| 119 |
+
if not nav_items:
|
| 120 |
+
raise ValueError(f"No navigation links found for '{endpoint}'")
|
| 121 |
+
|
| 122 |
+
async def fetch_page(item: dict[str, str]) -> dict[str, str]:
|
| 123 |
+
md_url = f"{item['url']}.md"
|
| 124 |
+
try:
|
| 125 |
+
r = await client.get(md_url, headers=headers)
|
| 126 |
+
r.raise_for_status()
|
| 127 |
+
content = r.text.strip()
|
| 128 |
+
glimpse = content[:200] + "..." if len(content) > 200 else content
|
| 129 |
+
except Exception as e:
|
| 130 |
+
content, glimpse = "", f"[Could not fetch: {str(e)[:50]}]"
|
| 131 |
+
return {
|
| 132 |
+
"title": item["title"],
|
| 133 |
+
"url": item["url"],
|
| 134 |
+
"md_url": md_url,
|
| 135 |
+
"glimpse": glimpse,
|
| 136 |
+
"content": content,
|
| 137 |
+
"section": endpoint,
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
return list(await asyncio.gather(*[fetch_page(item) for item in nav_items]))
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
async def _get_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]:
|
| 144 |
+
"""Get docs for endpoint with caching. Expands composite endpoints."""
|
| 145 |
+
async with _cache_lock:
|
| 146 |
+
if endpoint in _docs_cache:
|
| 147 |
+
return _docs_cache[endpoint]
|
| 148 |
+
|
| 149 |
+
sub_endpoints = COMPOSITE_ENDPOINTS.get(endpoint, [endpoint])
|
| 150 |
+
all_docs: list[dict[str, str]] = []
|
| 151 |
+
|
| 152 |
+
for sub in sub_endpoints:
|
| 153 |
+
async with _cache_lock:
|
| 154 |
+
if sub in _docs_cache:
|
| 155 |
+
all_docs.extend(_docs_cache[sub])
|
| 156 |
+
continue
|
| 157 |
|
| 158 |
+
docs = await _fetch_endpoint_docs(hf_token, sub)
|
| 159 |
+
async with _cache_lock:
|
| 160 |
+
_docs_cache[sub] = docs
|
| 161 |
+
all_docs.extend(docs)
|
| 162 |
+
|
| 163 |
+
async with _cache_lock:
|
| 164 |
+
_docs_cache[endpoint] = all_docs
|
| 165 |
+
return all_docs
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# ---------------------------------------------------------------------------
|
| 169 |
+
# HF Documentation - Search
|
| 170 |
+
# ---------------------------------------------------------------------------
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
async def _build_search_index(
|
| 174 |
+
endpoint: str, docs: list[dict[str, str]]
|
| 175 |
+
) -> tuple[Any, MultifieldParser]:
|
| 176 |
+
"""Build or retrieve cached Whoosh search index."""
|
| 177 |
+
async with _cache_lock:
|
| 178 |
+
if endpoint in _index_cache:
|
| 179 |
+
return _index_cache[endpoint]
|
| 180 |
+
|
| 181 |
+
analyzer = StemmingAnalyzer()
|
| 182 |
+
schema = Schema(
|
| 183 |
+
title=TEXT(stored=True, analyzer=analyzer),
|
| 184 |
+
url=ID(stored=True, unique=True),
|
| 185 |
+
md_url=ID(stored=True),
|
| 186 |
+
section=ID(stored=True),
|
| 187 |
+
glimpse=TEXT(stored=True, analyzer=analyzer),
|
| 188 |
+
content=TEXT(stored=False, analyzer=analyzer),
|
| 189 |
+
)
|
| 190 |
+
storage = RamStorage()
|
| 191 |
+
index = storage.create_index(schema)
|
| 192 |
+
writer = index.writer()
|
| 193 |
+
for doc in docs:
|
| 194 |
+
writer.add_document(
|
| 195 |
+
title=doc.get("title", ""),
|
| 196 |
+
url=doc.get("url", ""),
|
| 197 |
+
md_url=doc.get("md_url", ""),
|
| 198 |
+
section=doc.get("section", endpoint),
|
| 199 |
+
glimpse=doc.get("glimpse", ""),
|
| 200 |
+
content=doc.get("content", ""),
|
| 201 |
+
)
|
| 202 |
+
writer.commit()
|
| 203 |
|
| 204 |
+
parser = MultifieldParser(
|
| 205 |
+
["title", "content"],
|
| 206 |
+
schema=schema,
|
| 207 |
+
fieldboosts={"title": 2.0, "content": 1.0},
|
| 208 |
+
group=OrGroup,
|
| 209 |
+
)
|
| 210 |
|
| 211 |
+
async with _cache_lock:
|
| 212 |
+
_index_cache[endpoint] = (index, parser)
|
| 213 |
+
return index, parser
|
| 214 |
|
| 215 |
|
| 216 |
+
async def _search_docs(
|
| 217 |
+
endpoint: str, docs: list[dict[str, str]], query: str, limit: int
|
| 218 |
+
) -> tuple[list[dict[str, Any]], str | None]:
|
| 219 |
+
"""Search docs using Whoosh. Returns (results, fallback_message)."""
|
| 220 |
+
index, parser = await _build_search_index(endpoint, docs)
|
|
|
|
| 221 |
|
| 222 |
try:
|
| 223 |
+
query_obj = parser.parse(query)
|
| 224 |
+
except Exception:
|
| 225 |
+
return [], "Query contained unsupported syntax; showing default ordering."
|
| 226 |
+
|
| 227 |
+
with index.searcher() as searcher:
|
| 228 |
+
results = searcher.search(query_obj, limit=limit)
|
| 229 |
+
matches = [
|
| 230 |
+
{
|
| 231 |
+
"title": hit["title"],
|
| 232 |
+
"url": hit["url"],
|
| 233 |
+
"md_url": hit.get("md_url", ""),
|
| 234 |
+
"section": hit.get("section", endpoint),
|
| 235 |
+
"glimpse": hit["glimpse"],
|
| 236 |
+
"score": round(hit.score, 2),
|
| 237 |
+
}
|
| 238 |
+
for hit in results
|
| 239 |
+
]
|
| 240 |
+
|
| 241 |
+
if not matches:
|
| 242 |
+
return [], "No strong matches found; showing default ordering."
|
| 243 |
+
return matches, None
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# ---------------------------------------------------------------------------
|
| 247 |
+
# HF Documentation - Formatting
|
| 248 |
+
# ---------------------------------------------------------------------------
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _format_results(
|
| 252 |
+
endpoint: str,
|
| 253 |
+
items: list[dict[str, Any]],
|
| 254 |
+
total: int,
|
| 255 |
+
query: str | None = None,
|
| 256 |
+
note: str | None = None,
|
| 257 |
+
) -> str:
|
| 258 |
+
"""Format search results as readable text."""
|
| 259 |
+
base_url = f"https://huggingface.co/docs/{endpoint}"
|
| 260 |
+
out = f"Documentation structure for: {base_url}\n\n"
|
| 261 |
|
| 262 |
+
if query:
|
| 263 |
+
out += f"Query: '{query}' → showing {len(items)} result(s) out of {total} pages"
|
| 264 |
+
if note:
|
| 265 |
+
out += f" ({note})"
|
| 266 |
+
out += "\n\n"
|
| 267 |
+
else:
|
| 268 |
+
out += f"Found {len(items)} page(s) (total available: {total}).\n"
|
| 269 |
+
if note:
|
| 270 |
+
out += f"({note})\n"
|
| 271 |
+
out += "\n"
|
| 272 |
|
| 273 |
+
for i, item in enumerate(items, 1):
|
| 274 |
+
out += f"{i}. **{item['title']}**\n"
|
| 275 |
+
out += f" URL: {item['url']}\n"
|
| 276 |
+
out += f" Section: {item.get('section', endpoint)}\n"
|
| 277 |
+
if query and "score" in item:
|
| 278 |
+
out += f" Relevance score: {item['score']:.2f}\n"
|
| 279 |
+
out += f" Glimpse: {item['glimpse']}\n\n"
|
| 280 |
+
|
| 281 |
+
return out
|
| 282 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
+
# ---------------------------------------------------------------------------
|
| 285 |
+
# Handlers
|
| 286 |
+
# ---------------------------------------------------------------------------
|
|
|
|
| 287 |
|
|
|
|
| 288 |
|
| 289 |
+
async def explore_hf_docs_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
|
| 290 |
+
"""Explore documentation structure with optional search query."""
|
| 291 |
+
endpoint = arguments.get("endpoint", "").lstrip("/")
|
| 292 |
+
query = arguments.get("query")
|
| 293 |
+
max_results = arguments.get("max_results")
|
| 294 |
|
| 295 |
+
if not endpoint:
|
| 296 |
+
return "Error: No endpoint provided", False
|
|
|
|
|
|
|
| 297 |
|
| 298 |
+
# Gradio uses its own API
|
| 299 |
+
if endpoint.lower() == "gradio":
|
| 300 |
+
try:
|
| 301 |
+
clean_query = (
|
| 302 |
+
query.strip() if isinstance(query, str) and query.strip() else None
|
| 303 |
+
)
|
| 304 |
+
content = await _fetch_gradio_docs(clean_query)
|
| 305 |
+
header = "# Gradio Documentation\n\n"
|
| 306 |
+
if clean_query:
|
| 307 |
+
header += f"Query: '{clean_query}'\n\n"
|
| 308 |
+
header += "Source: https://gradio.app/docs\n\n---\n\n"
|
| 309 |
+
return header + content, True
|
| 310 |
+
except httpx.HTTPStatusError as e:
|
| 311 |
+
return f"HTTP error fetching Gradio docs: {e.response.status_code}", False
|
| 312 |
+
except httpx.RequestError as e:
|
| 313 |
+
return f"Request error fetching Gradio docs: {str(e)}", False
|
| 314 |
+
except Exception as e:
|
| 315 |
+
return f"Error fetching Gradio docs: {str(e)}", False
|
| 316 |
+
|
| 317 |
+
# HF docs
|
| 318 |
+
hf_token = os.environ.get("HF_TOKEN")
|
| 319 |
+
if not hf_token:
|
| 320 |
+
return "Error: HF_TOKEN environment variable not set", False
|
| 321 |
|
| 322 |
+
try:
|
| 323 |
+
max_results_int = int(max_results) if max_results is not None else None
|
| 324 |
+
except (TypeError, ValueError):
|
| 325 |
+
return "Error: max_results must be an integer", False
|
| 326 |
|
| 327 |
+
if max_results_int is not None and max_results_int <= 0:
|
| 328 |
+
return "Error: max_results must be greater than zero", False
|
| 329 |
|
| 330 |
+
try:
|
| 331 |
+
docs = await _get_docs(hf_token, endpoint)
|
| 332 |
+
total = len(docs)
|
| 333 |
+
|
| 334 |
+
# Determine limit
|
| 335 |
+
if max_results_int is None:
|
| 336 |
+
limit = DEFAULT_MAX_RESULTS
|
| 337 |
+
limit_note = f"Showing top {DEFAULT_MAX_RESULTS} results (set max_results to adjust)."
|
| 338 |
+
elif max_results_int > MAX_RESULTS_CAP:
|
| 339 |
+
limit = MAX_RESULTS_CAP
|
| 340 |
+
limit_note = f"Requested {max_results_int} but showing top {MAX_RESULTS_CAP} (maximum)."
|
| 341 |
+
else:
|
| 342 |
+
limit = max_results_int
|
| 343 |
+
limit_note = None
|
| 344 |
+
|
| 345 |
+
# Search or paginate
|
| 346 |
+
clean_query = (
|
| 347 |
+
query.strip() if isinstance(query, str) and query.strip() else None
|
| 348 |
+
)
|
| 349 |
+
fallback_msg = None
|
| 350 |
|
| 351 |
+
if clean_query:
|
| 352 |
+
results, fallback_msg = await _search_docs(
|
| 353 |
+
endpoint, docs, clean_query, limit
|
| 354 |
+
)
|
| 355 |
+
if not results:
|
| 356 |
+
results = docs[:limit]
|
| 357 |
+
else:
|
| 358 |
+
results = docs[:limit]
|
| 359 |
|
| 360 |
+
# Combine notes
|
| 361 |
+
notes = []
|
| 362 |
+
if fallback_msg:
|
| 363 |
+
notes.append(fallback_msg)
|
| 364 |
+
if limit_note:
|
| 365 |
+
notes.append(limit_note)
|
| 366 |
+
note = "; ".join(notes) if notes else None
|
| 367 |
|
| 368 |
+
return _format_results(endpoint, results, total, clean_query, note), True
|
|
|
|
|
|
|
| 369 |
|
| 370 |
+
except httpx.HTTPStatusError as e:
|
| 371 |
+
return f"HTTP error: {e.response.status_code} - {e.response.text[:200]}", False
|
| 372 |
+
except httpx.RequestError as e:
|
| 373 |
+
return f"Request error: {str(e)}", False
|
| 374 |
+
except ValueError as e:
|
| 375 |
+
return f"Error: {str(e)}", False
|
| 376 |
+
except Exception as e:
|
| 377 |
+
return f"Unexpected error: {str(e)}", False
|
| 378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
|
| 380 |
+
async def hf_docs_fetch_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
|
| 381 |
+
"""Fetch full markdown content of a documentation page."""
|
| 382 |
+
url = arguments.get("url", "")
|
| 383 |
+
if not url:
|
| 384 |
+
return "Error: No URL provided", False
|
| 385 |
|
|
|
|
| 386 |
hf_token = os.environ.get("HF_TOKEN")
|
|
|
|
| 387 |
if not hf_token:
|
| 388 |
return "Error: HF_TOKEN environment variable not set", False
|
| 389 |
|
| 390 |
+
if not url.endswith(".md"):
|
| 391 |
+
url = f"{url}.md"
|
| 392 |
|
| 393 |
try:
|
| 394 |
+
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
| 395 |
+
resp = await client.get(
|
| 396 |
+
url, headers={"Authorization": f"Bearer {hf_token}"}
|
| 397 |
+
)
|
| 398 |
+
resp.raise_for_status()
|
| 399 |
+
return f"Documentation from: {url}\n\n{resp.text}", True
|
| 400 |
except httpx.HTTPStatusError as e:
|
| 401 |
return (
|
| 402 |
+
f"HTTP error fetching {url}: {e.response.status_code} - {e.response.text[:200]}",
|
| 403 |
False,
|
| 404 |
)
|
| 405 |
except httpx.RequestError as e:
|
| 406 |
+
return f"Request error fetching {url}: {str(e)}", False
|
|
|
|
|
|
|
| 407 |
except Exception as e:
|
| 408 |
+
return f"Error fetching documentation: {str(e)}", False
|
| 409 |
|
| 410 |
|
| 411 |
+
# ---------------------------------------------------------------------------
|
| 412 |
+
# OpenAPI Search
|
| 413 |
+
# ---------------------------------------------------------------------------
|
| 414 |
|
|
|
|
|
|
|
| 415 |
|
| 416 |
+
async def _fetch_openapi_spec() -> dict[str, Any]:
|
| 417 |
+
"""Fetch and cache HuggingFace OpenAPI specification."""
|
| 418 |
+
global _openapi_cache
|
| 419 |
+
if _openapi_cache is not None:
|
| 420 |
+
return _openapi_cache
|
| 421 |
|
| 422 |
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
| 423 |
+
resp = await client.get("https://huggingface.co/.well-known/openapi.json")
|
| 424 |
+
resp.raise_for_status()
|
| 425 |
|
| 426 |
+
_openapi_cache = resp.json()
|
| 427 |
+
return _openapi_cache
|
|
|
|
|
|
|
| 428 |
|
| 429 |
|
| 430 |
def _extract_all_tags(spec: dict[str, Any]) -> list[str]:
|
| 431 |
+
"""Extract all unique tags from OpenAPI spec."""
|
| 432 |
tags = set()
|
|
|
|
|
|
|
| 433 |
for tag_obj in spec.get("tags", []):
|
| 434 |
if "name" in tag_obj:
|
| 435 |
tags.add(tag_obj["name"])
|
| 436 |
+
for path_item in spec.get("paths", {}).values():
|
| 437 |
+
for method, op in path_item.items():
|
|
|
|
|
|
|
| 438 |
if method in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
| 439 |
+
for tag in op.get("tags", []):
|
| 440 |
tags.add(tag)
|
| 441 |
+
return sorted(tags)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
|
| 443 |
|
| 444 |
def _generate_curl_example(endpoint: dict[str, Any]) -> str:
|
| 445 |
+
"""Generate curl command example for an endpoint."""
|
| 446 |
method = endpoint["method"]
|
| 447 |
path = endpoint["path"]
|
| 448 |
base_url = endpoint["base_url"]
|
| 449 |
|
| 450 |
+
# Build URL with path parameters
|
| 451 |
full_path = path
|
| 452 |
for param in endpoint.get("parameters", []):
|
| 453 |
if param.get("in") == "path" and param.get("required"):
|
| 454 |
+
name = param["name"]
|
| 455 |
example = param.get(
|
| 456 |
+
"example", param.get("schema", {}).get("example", f"<{name}>")
|
| 457 |
)
|
| 458 |
+
full_path = full_path.replace(f"{{{name}}}", str(example))
|
| 459 |
|
| 460 |
curl = f"curl -X {method} \\\n '{base_url}{full_path}'"
|
| 461 |
|
| 462 |
+
# Add query parameters
|
| 463 |
query_params = [p for p in endpoint.get("parameters", []) if p.get("in") == "query"]
|
| 464 |
if query_params and query_params[0].get("required"):
|
| 465 |
param = query_params[0]
|
| 466 |
example = param.get("example", param.get("schema", {}).get("example", "value"))
|
| 467 |
curl += f"?{param['name']}={example}"
|
| 468 |
|
|
|
|
| 469 |
curl += " \\\n -H 'Authorization: Bearer $HF_TOKEN'"
|
| 470 |
|
| 471 |
+
# Add request body
|
| 472 |
if method in ["POST", "PUT", "PATCH"] and endpoint.get("request_body"):
|
| 473 |
content = endpoint["request_body"].get("content", {})
|
| 474 |
if "application/json" in content:
|
|
|
|
| 476 |
schema = content["application/json"].get("schema", {})
|
| 477 |
example = schema.get("example", "{}")
|
| 478 |
if isinstance(example, dict):
|
|
|
|
|
|
|
| 479 |
example = json.dumps(example, indent=2)
|
| 480 |
curl += f" \\\n -d '{example}'"
|
| 481 |
|
|
|
|
| 483 |
|
| 484 |
|
| 485 |
def _format_parameters(parameters: list[dict[str, Any]]) -> str:
|
| 486 |
+
"""Format parameter information from OpenAPI spec."""
|
| 487 |
if not parameters:
|
| 488 |
return ""
|
| 489 |
|
|
|
|
| 490 |
path_params = [p for p in parameters if p.get("in") == "path"]
|
| 491 |
query_params = [p for p in parameters if p.get("in") == "query"]
|
| 492 |
header_params = [p for p in parameters if p.get("in") == "header"]
|
| 493 |
|
| 494 |
output = []
|
| 495 |
|
| 496 |
+
for label, params in [
|
| 497 |
+
("Path Parameters", path_params),
|
| 498 |
+
("Query Parameters", query_params),
|
| 499 |
+
("Header Parameters", header_params),
|
| 500 |
+
]:
|
| 501 |
+
if not params:
|
| 502 |
+
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
if output:
|
| 504 |
output.append("")
|
| 505 |
+
output.append(f"**{label}:**")
|
| 506 |
+
for p in params:
|
| 507 |
+
name = p.get("name", "")
|
| 508 |
+
required = " (required)" if p.get("required") else " (optional)"
|
| 509 |
+
desc = p.get("description", "")
|
| 510 |
+
ptype = p.get("schema", {}).get("type", "string")
|
| 511 |
+
example = p.get("example") or p.get("schema", {}).get("example", "")
|
| 512 |
+
|
| 513 |
+
output.append(f"- `{name}` ({ptype}){required}: {desc}")
|
| 514 |
if example:
|
| 515 |
output.append(f" Example: `{example}`")
|
| 516 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
return "\n".join(output)
|
| 518 |
|
| 519 |
|
| 520 |
def _format_response_info(responses: dict[str, Any]) -> str:
|
| 521 |
+
"""Format response information from OpenAPI spec."""
|
| 522 |
if not responses:
|
| 523 |
return "No response information available"
|
| 524 |
|
| 525 |
output = []
|
| 526 |
+
for status, resp_obj in list(responses.items())[:3]:
|
| 527 |
+
desc = resp_obj.get("description", "")
|
| 528 |
+
output.append(f"- **{status}**: {desc}")
|
| 529 |
+
content = resp_obj.get("content", {})
|
|
|
|
|
|
|
|
|
|
| 530 |
if "application/json" in content:
|
| 531 |
schema = content["application/json"].get("schema", {})
|
| 532 |
if "type" in schema:
|
|
|
|
| 536 |
|
| 537 |
|
| 538 |
def _format_openapi_results(results: list[dict[str, Any]], tag: str) -> str:
|
| 539 |
+
"""Format OpenAPI search results with curl examples."""
|
| 540 |
if not results:
|
| 541 |
return f"No API endpoints found with tag '{tag}'"
|
| 542 |
|
| 543 |
+
out = f"# API Endpoints for tag: `{tag}`\n\n"
|
| 544 |
+
out += f"Found {len(results)} endpoint(s)\n\n---\n\n"
|
|
|
|
| 545 |
|
| 546 |
+
for i, ep in enumerate(results, 1):
|
| 547 |
+
out += f"## {i}. {ep['method']} {ep['path']}\n\n"
|
| 548 |
|
| 549 |
+
if ep["summary"]:
|
| 550 |
+
out += f"**Summary:** {ep['summary']}\n\n"
|
| 551 |
|
| 552 |
+
if ep["description"]:
|
| 553 |
+
desc = ep["description"][:300]
|
| 554 |
+
if len(ep["description"]) > 300:
|
| 555 |
desc += "..."
|
| 556 |
+
out += f"**Description:** {desc}\n\n"
|
| 557 |
|
| 558 |
+
params_info = _format_parameters(ep.get("parameters", []))
|
|
|
|
| 559 |
if params_info:
|
| 560 |
+
out += params_info + "\n\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 561 |
|
| 562 |
+
out += "**Usage:**\n```bash\n"
|
| 563 |
+
out += _generate_curl_example(ep)
|
| 564 |
+
out += "\n```\n\n"
|
|
|
|
| 565 |
|
| 566 |
+
out += "**Returns:**\n"
|
| 567 |
+
out += _format_response_info(ep["responses"])
|
| 568 |
+
out += "\n\n---\n\n"
|
| 569 |
|
| 570 |
+
return out
|
| 571 |
|
| 572 |
|
| 573 |
async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
|
| 574 |
+
"""Search HuggingFace OpenAPI specification by tag."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 575 |
tag = arguments.get("tag", "")
|
|
|
|
| 576 |
if not tag:
|
| 577 |
return "Error: No tag provided", False
|
| 578 |
|
| 579 |
try:
|
|
|
|
| 580 |
spec = await _fetch_openapi_spec()
|
| 581 |
+
paths = spec.get("paths", {})
|
| 582 |
+
servers = spec.get("servers", [])
|
| 583 |
+
base_url = (
|
| 584 |
+
servers[0].get("url", "https://huggingface.co")
|
| 585 |
+
if servers
|
| 586 |
+
else "https://huggingface.co"
|
| 587 |
+
)
|
| 588 |
|
| 589 |
+
results = []
|
| 590 |
+
for path, path_item in paths.items():
|
| 591 |
+
for method, op in path_item.items():
|
| 592 |
+
if method not in [
|
| 593 |
+
"get",
|
| 594 |
+
"post",
|
| 595 |
+
"put",
|
| 596 |
+
"delete",
|
| 597 |
+
"patch",
|
| 598 |
+
"head",
|
| 599 |
+
"options",
|
| 600 |
+
]:
|
| 601 |
+
continue
|
| 602 |
+
if tag not in op.get("tags", []):
|
| 603 |
+
continue
|
| 604 |
|
| 605 |
+
results.append(
|
| 606 |
+
{
|
| 607 |
+
"path": path,
|
| 608 |
+
"method": method.upper(),
|
| 609 |
+
"operationId": op.get("operationId", ""),
|
| 610 |
+
"summary": op.get("summary", ""),
|
| 611 |
+
"description": op.get("description", ""),
|
| 612 |
+
"parameters": op.get("parameters", []),
|
| 613 |
+
"request_body": op.get("requestBody", {}),
|
| 614 |
+
"responses": op.get("responses", {}),
|
| 615 |
+
"base_url": base_url,
|
| 616 |
+
}
|
| 617 |
+
)
|
| 618 |
|
| 619 |
+
return _format_openapi_results(results, tag), True
|
| 620 |
|
| 621 |
except httpx.HTTPStatusError as e:
|
| 622 |
return f"HTTP error fetching OpenAPI spec: {e.response.status_code}", False
|
|
|
|
| 626 |
return f"Error searching OpenAPI spec: {str(e)}", False
|
| 627 |
|
| 628 |
|
| 629 |
+
async def _get_api_search_tool_spec() -> dict[str, Any]:
    """Generate OpenAPI tool spec with tags populated at runtime.

    The tag enum is derived from the live spec so the tool always offers
    the current set of API groupings.
    """
    spec = await _fetch_openapi_spec()
    available_tags = _extract_all_tags(spec)

    description = (
        "Search HuggingFace OpenAPI specification by tag to find API endpoints with curl examples. "
        "**Use when:** (1) Need to interact with HF Hub API directly, (2) Building scripts for repo operations, "
        "(3) Need authentication patterns, (4) Understanding API parameters and responses, "
        "(5) Need curl examples for HTTP requests. "
        "Returns: Endpoint paths, methods, parameters, curl examples with authentication, and response schemas. "
        "Tags group related operations: repos, models, datasets, inference, spaces, etc."
    )
    tag_property = {
        "type": "string",
        "enum": available_tags,
        "description": "The API tag to search for. Each tag groups related API endpoints.",
    }
    return {
        "name": "search_hf_api_endpoints",
        "description": description,
        "parameters": {
            "type": "object",
            "properties": {"tag": tag_property},
            "required": ["tag"],
        },
    }
|
| 656 |
|
| 657 |
|
| 658 |
+
# ---------------------------------------------------------------------------
# Tool Specifications
# ---------------------------------------------------------------------------

# Documentation sections that explore_hf_docs can browse; membership drives
# the tool's enum, and order is preserved as the model sees it.
DOC_ENDPOINTS = [
    "hub", "transformers", "diffusers", "datasets", "gradio", "trackio",
    "smolagents", "huggingface_hub", "huggingface.js", "transformers.js",
    "inference-providers", "inference-endpoints", "peft", "accelerate",
    "optimum", "tokenizers", "courses", "evaluate", "tasks", "dataset-viewer",
    "trl", "simulate", "sagemaker", "timm", "safetensors", "tgi", "setfit",
    "lerobot", "autotrain", "tei", "bitsandbytes", "sentence_transformers",
    "chat-ui", "leaderboards", "lighteval", "argilla", "distilabel",
    "microsoft-azure", "kernels", "google-cloud",
]
|
| 704 |
|
| 705 |
EXPLORE_HF_DOCS_TOOL_SPEC = {
|
| 706 |
"name": "explore_hf_docs",
|
| 707 |
"description": (
|
| 708 |
+
"Explore Hugging Face documentation structure and discover available pages with 200-character previews. "
|
| 709 |
"⚠️ MANDATORY: ALWAYS use this BEFORE implementing any ML task (training, fine-tuning, data processing, inference). "
|
| 710 |
"Your training data may be outdated - current documentation is the source of truth. "
|
| 711 |
"**Use when:** (1) Starting any implementation task, (2) User asks 'how to' questions, "
|
|
|
|
| 715 |
"Returns: Sidebar navigation with titles, URLs, and glimpses of all pages in the selected documentation. "
|
| 716 |
"**Then:** Use fetch_hf_docs with specific URLs from results to get full content. "
|
| 717 |
"**Critical for reliability:** Never implement based on internal knowledge without checking current docs first - APIs change frequently."
|
| 718 |
+
" By default returns the top 20 results; set max_results (max 50) to adjust."
|
| 719 |
),
|
| 720 |
"parameters": {
|
| 721 |
"type": "object",
|
| 722 |
"properties": {
|
| 723 |
"endpoint": {
|
| 724 |
"type": "string",
|
| 725 |
+
"enum": DOC_ENDPOINTS,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 726 |
"description": (
|
| 727 |
"The documentation endpoint to explore. Each endpoint corresponds to a major section of the Hugging Face documentation:\n\n"
|
| 728 |
+
"• courses — All Hugging Face courses (LLM, robotics, MCP, smol (llm training), agents, deep RL, computer vision, games, diffusion, 3D, audio) and the cookbook recipes. Probably the best place for examples.\n"
|
| 729 |
"• hub — Find answers to questions about models/datasets/spaces, auth, versioning, metadata.\n"
|
| 730 |
"• transformers — Core model library: architectures, configs, tokenizers, training & inference APIs.\n"
|
| 731 |
"• diffusers — Diffusion pipelines, schedulers, fine-tuning, training, and deployment patterns.\n"
|
| 732 |
"• datasets — Dataset loading, streaming, processing, Arrow format, Hub integration.\n"
|
| 733 |
+
"• gradio — UI components and demos for ML models. Uses Gradio's native API: without query returns full docs (llms.txt), with query uses embedding search for precise results.\n"
|
| 734 |
"• trackio — Experiment tracking, metrics logging, and run comparison.\n"
|
| 735 |
"• smolagents — Lightweight agent abstractions and tool-using patterns.\n"
|
| 736 |
"• huggingface_hub — Python client for Hub operations (auth, upload/download, repo management).\n"
|
|
|
|
| 740 |
"• inference-endpoints — Managed, scalable model deployments on HF infrastructure.\n"
|
| 741 |
"• peft — Parameter-efficient fine-tuning methods (LoRA, adapters, etc.).\n"
|
| 742 |
"• accelerate — Hardware-agnostic, distributed and mixed-precision training orchestration.\n"
|
| 743 |
+
"• optimum — Hardware-aware optimization and model export tooling, including Habana, Neuron, Intel, ExecuTorch, and TPU variants.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 744 |
"• tokenizers — Fast tokenizer internals, training, and low-level APIs.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 745 |
"• evaluate — Metrics, evaluation workflows, and training-loop integration.\n"
|
| 746 |
"• tasks — Canonical task definitions and model categorization.\n"
|
| 747 |
"• dataset-viewer — Dataset preview, streaming views, and viewer internals.\n"
|
|
|
|
| 752 |
"• safetensors — Safe, fast tensor serialization format.\n"
|
| 753 |
"• tgi — High-throughput text generation server for LLMs.\n"
|
| 754 |
"• setfit — Few-shot text classification via sentence embeddings.\n"
|
|
|
|
| 755 |
"• lerobot — Robotics datasets, policies, and learning workflows.\n"
|
| 756 |
"• autotrain — No/low-code model training on Hugging Face.\n"
|
| 757 |
"• tei — Optimized inference server for embedding workloads.\n"
|
| 758 |
"• bitsandbytes — Quantization and memory-efficient optimizers.\n"
|
|
|
|
| 759 |
"• sentence_transformers — Embedding models, training recipes, similarity/search workflows.\n"
|
|
|
|
|
|
|
|
|
|
| 760 |
"• chat-ui — Reference chat interfaces for LLM deployment.\n"
|
| 761 |
"• leaderboards — Evaluation leaderboards and submission mechanics.\n"
|
| 762 |
"• lighteval — Lightweight, reproducible LLM evaluation framework.\n"
|
|
|
|
| 767 |
"• google-cloud — GCP deployment and serving workflows.\n"
|
| 768 |
),
|
| 769 |
},
|
| 770 |
+
"query": {
|
| 771 |
+
"type": "string",
|
| 772 |
+
"description": (
|
| 773 |
+
"Optional keyword query to rank and filter documentation pages. "
|
| 774 |
+
"For Gradio, use concise queries like 'how to use the image component' or 'audio component demo'."
|
| 775 |
+
),
|
| 776 |
+
},
|
| 777 |
+
"max_results": {
|
| 778 |
+
"type": "integer",
|
| 779 |
+
"description": "Max results (default 20, max 50). Ignored for Gradio.",
|
| 780 |
+
"minimum": 1,
|
| 781 |
+
"maximum": 50,
|
| 782 |
+
},
|
| 783 |
},
|
| 784 |
"required": ["endpoint"],
|
| 785 |
},
|
|
|
|
| 814 |
"required": ["url"],
|
| 815 |
},
|
| 816 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent/tools/hf_repo_files_tool.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HF Repo Files Tool - File operations on Hugging Face repositories
|
| 3 |
+
|
| 4 |
+
Operations: list, read, upload, delete
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
from typing import Any, Dict, Literal, Optional
|
| 9 |
+
|
| 10 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 11 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
| 12 |
+
|
| 13 |
+
from agent.tools.types import ToolResult
|
| 14 |
+
|
| 15 |
+
OperationType = Literal["list", "read", "upload", "delete"]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
async def _async_call(func, *args, **kwargs):
|
| 19 |
+
"""Wrap synchronous HfApi calls for async context."""
|
| 20 |
+
return await asyncio.to_thread(func, *args, **kwargs)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
|
| 24 |
+
"""Build the Hub URL for a repository."""
|
| 25 |
+
if repo_type == "model":
|
| 26 |
+
return f"https://huggingface.co/{repo_id}"
|
| 27 |
+
return f"https://huggingface.co/{repo_type}s/{repo_id}"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _format_size(size_bytes: int) -> str:
|
| 31 |
+
"""Format file size in human-readable form."""
|
| 32 |
+
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
| 33 |
+
if size_bytes < 1024:
|
| 34 |
+
return f"{size_bytes:.1f}{unit}"
|
| 35 |
+
size_bytes /= 1024
|
| 36 |
+
return f"{size_bytes:.1f}PB"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class HfRepoFilesTool:
    """Tool for file operations on HF repos.

    Thin async wrapper around :class:`huggingface_hub.HfApi` exposing
    list/read/upload/delete operations with formatted string output.
    """

    def __init__(self, hf_token: Optional[str] = None):
        # Token may be None, in which case HfApi falls back to locally
        # cached credentials / the HF_TOKEN environment variable.
        self.api = HfApi(token=hf_token)

    async def execute(self, args: Dict[str, Any]) -> ToolResult:
        """Execute the specified operation.

        Dispatches on ``args["operation"]``; a missing operation shows help
        and an unknown one yields an error result instead of raising.
        """
        operation = args.get("operation")

        if not operation:
            return self._help()

        try:
            handlers = {
                "list": self._list,
                "read": self._read,
                "upload": self._upload,
                "delete": self._delete,
            }

            handler = handlers.get(operation)
            if handler:
                return await handler(args)
            return self._error(f"Unknown operation: {operation}. Valid: list, read, upload, delete")

        except RepositoryNotFoundError:
            return self._error(f"Repository not found: {args.get('repo_id')}")
        except EntryNotFoundError:
            return self._error(f"File not found: {args.get('path')}")
        except Exception as e:
            return self._error(f"Error: {str(e)}")

    def _help(self) -> ToolResult:
        """Show usage instructions."""
        return {
            "formatted": """**hf_repo_files** - File operations on HF repos

**Operations:**
- `list` - List files: `{"operation": "list", "repo_id": "gpt2"}`
- `read` - Read file: `{"operation": "read", "repo_id": "gpt2", "path": "config.json"}`
- `upload` - Upload: `{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "..."}`
- `delete` - Delete: `{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp"]}`

**Common params:** repo_id (required), repo_type (model/dataset/space), revision (default: main)""",
            "totalResults": 1,
            "resultsShared": 1,
        }

    async def _list(self, args: Dict[str, Any]) -> ToolResult:
        """List files in a repository (recursive), with sizes and a tree URL."""
        repo_id = args.get("repo_id")
        if not repo_id:
            return self._error("repo_id is required")

        repo_type = args.get("repo_type", "model")
        revision = args.get("revision", "main")
        path = args.get("path", "")

        items = list(await _async_call(
            self.api.list_repo_tree,
            repo_id=repo_id,
            repo_type=repo_type,
            revision=revision,
            path_in_repo=path,
            recursive=True,
        ))

        if not items:
            return {"formatted": f"No files in {repo_id}", "totalResults": 0, "resultsShared": 0}

        lines = []
        total_size = 0
        for item in sorted(items, key=lambda x: x.path):
            # RepoFile entries carry a size; folder entries do not. Check for
            # presence (not truthiness) so zero-byte files are still rendered
            # as files rather than misreported as directories.
            size = getattr(item, "size", None)
            if size is not None:
                total_size += size
                lines.append(f"{item.path} ({_format_size(size)})")
            else:
                lines.append(f"{item.path}/")

        url = _build_repo_url(repo_id, repo_type)
        response = f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" + "\n".join(lines)

        return {"formatted": response, "totalResults": len(items), "resultsShared": len(items)}

    async def _read(self, args: Dict[str, Any]) -> ToolResult:
        """Read file content from a repository.

        Downloads via hf_hub_download (locally cached) and returns up to
        ``max_chars`` characters; binary files report their size instead.
        """
        repo_id = args.get("repo_id")
        path = args.get("path")

        if not repo_id:
            return self._error("repo_id is required")
        if not path:
            return self._error("path is required")

        repo_type = args.get("repo_type", "model")
        revision = args.get("revision", "main")
        max_chars = args.get("max_chars", 50000)

        file_path = await _async_call(
            hf_hub_download,
            repo_id=repo_id,
            filename=path,
            repo_type=repo_type,
            revision=revision,
            token=self.api.token,
        )

        try:
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()

            truncated = len(content) > max_chars
            if truncated:
                content = content[:max_chars]

            url = f"{_build_repo_url(repo_id, repo_type)}/blob/{revision}/{path}"
            response = f"**{path}**{' (truncated)' if truncated else ''}\n{url}\n\n```\n{content}\n```"

            return {"formatted": response, "totalResults": 1, "resultsShared": 1}

        except UnicodeDecodeError:
            # Not valid UTF-8: report size only rather than dumping bytes.
            import os
            size = os.path.getsize(file_path)
            return {"formatted": f"Binary file ({_format_size(size)})", "totalResults": 1, "resultsShared": 1}

    async def _upload(self, args: Dict[str, Any]) -> ToolResult:
        """Upload content to a repository, optionally as a pull request."""
        repo_id = args.get("repo_id")
        path = args.get("path")
        content = args.get("content")

        if not repo_id:
            return self._error("repo_id is required")
        if not path:
            return self._error("path is required")
        # Explicit None check: an empty string is a legitimate file body.
        if content is None:
            return self._error("content is required")

        repo_type = args.get("repo_type", "model")
        revision = args.get("revision", "main")
        create_pr = args.get("create_pr", False)
        commit_message = args.get("commit_message", f"Upload {path}")

        file_bytes = content.encode("utf-8") if isinstance(content, str) else content

        result = await _async_call(
            self.api.upload_file,
            path_or_fileobj=file_bytes,
            path_in_repo=path,
            repo_id=repo_id,
            repo_type=repo_type,
            revision=revision,
            commit_message=commit_message,
            create_pr=create_pr,
        )

        url = _build_repo_url(repo_id, repo_type)
        if create_pr and hasattr(result, "pr_url"):
            response = f"**Uploaded as PR**\n{result.pr_url}"
        else:
            response = f"**Uploaded:** {path}\n{url}/blob/{revision}/{path}"

        return {"formatted": response, "totalResults": 1, "resultsShared": 1}

    async def _delete(self, args: Dict[str, Any]) -> ToolResult:
        """Delete files from a repository; patterns may include wildcards."""
        repo_id = args.get("repo_id")
        patterns = args.get("patterns")

        if not repo_id:
            return self._error("repo_id is required")
        if not patterns:
            return self._error("patterns is required (list of paths/wildcards)")

        # Accept a bare string for convenience; normalize to a list.
        if isinstance(patterns, str):
            patterns = [patterns]

        repo_type = args.get("repo_type", "model")
        revision = args.get("revision", "main")
        create_pr = args.get("create_pr", False)
        commit_message = args.get("commit_message", f"Delete {', '.join(patterns)}")

        await _async_call(
            self.api.delete_files,
            repo_id=repo_id,
            delete_patterns=patterns,
            repo_type=repo_type,
            revision=revision,
            commit_message=commit_message,
            create_pr=create_pr,
        )

        response = f"**Deleted:** {', '.join(patterns)} from {repo_id}"
        return {"formatted": response, "totalResults": 1, "resultsShared": 1}

    def _error(self, message: str) -> ToolResult:
        """Return an error result."""
        return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# Tool specification
|
| 242 |
+
# Tool specification
HF_REPO_FILES_TOOL_SPEC = {
    "name": "hf_repo_files",
    # Description follows the house style: operations, triggers, examples, caveats.
    "description": (
        "Read and write files in HF repos (models/datasets/spaces).\n\n"
        "## Operations\n"
        "- **list**: List files with sizes and structure\n"
        "- **read**: Read file content (text files only)\n"
        "- **upload**: Upload content to repo (can create PR)\n"
        "- **delete**: Delete files/folders (supports wildcards like *.tmp)\n\n"
        "## Use when\n"
        "- Need to see what files exist in a repo\n"
        "- Want to read config.json, README.md, or other text files\n"
        "- Uploading training scripts, configs, or results to a repo\n"
        "- Cleaning up temporary files from a repo\n\n"
        "## Examples\n"
        '{"operation": "list", "repo_id": "meta-llama/Llama-2-7b"}\n'
        '{"operation": "read", "repo_id": "gpt2", "path": "config.json"}\n'
        '{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "# My Model"}\n'
        '{"operation": "upload", "repo_id": "org/model", "path": "fix.py", "content": "...", "create_pr": true}\n'
        '{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp", "logs/"]}\n\n'
        "## Notes\n"
        "- For binary files (safetensors, bin), use list to see them but can't read content\n"
        "- upload/delete require approval (can overwrite/destroy data)\n"
        "- Use create_pr=true to propose changes instead of direct commit\n"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": ["list", "read", "upload", "delete"],
                "description": "Operation: list, read, upload, delete",
            },
            "repo_id": {
                "type": "string",
                "description": "Repository ID (e.g., 'username/repo-name')",
            },
            "repo_type": {
                "type": "string",
                "enum": ["model", "dataset", "space"],
                "description": "Repository type (default: model)",
            },
            "revision": {
                "type": "string",
                "description": "Branch/tag/commit (default: main)",
            },
            "path": {
                "type": "string",
                "description": "File path for read/upload",
            },
            "content": {
                "type": "string",
                "description": "File content for upload",
            },
            "patterns": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Patterns to delete (e.g., ['*.tmp', 'logs/'])",
            },
            "create_pr": {
                "type": "boolean",
                "description": "Create PR instead of direct commit",
            },
            "commit_message": {
                "type": "string",
                "description": "Custom commit message",
            },
        },
        "required": ["operation"],
    },
}
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
async def hf_repo_files_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
    """Handler for agent tool router."""
    try:
        outcome = await HfRepoFilesTool().execute(arguments)
        # Success flag is the inverse of the tool's isError marker.
        return outcome["formatted"], not outcome.get("isError", False)
    except Exception as e:
        return f"Error: {str(e)}", False
|
agent/tools/hf_repo_git_tool.py
ADDED
|
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HF Repo Git Tool - Git-like operations on Hugging Face repositories
|
| 3 |
+
|
| 4 |
+
Operations: branches, tags, PRs, repo management
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
from typing import Any, Dict, Literal, Optional
|
| 9 |
+
|
| 10 |
+
from huggingface_hub import HfApi
|
| 11 |
+
from huggingface_hub.utils import RepositoryNotFoundError
|
| 12 |
+
|
| 13 |
+
from agent.tools.types import ToolResult
|
| 14 |
+
|
| 15 |
+
# All operations understood by HfRepoGitTool, grouped by concern.
OperationType = Literal[
    # branches
    "create_branch",
    "delete_branch",
    # tags
    "create_tag",
    "delete_tag",
    # refs
    "list_refs",
    # pull requests
    "create_pr",
    "list_prs",
    "get_pr",
    "merge_pr",
    "close_pr",
    "comment_pr",
    "change_pr_status",
    # repo management
    "create_repo",
    "update_repo",
]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
async def _async_call(func, *args, **kwargs):
|
| 25 |
+
"""Wrap synchronous HfApi calls for async context."""
|
| 26 |
+
return await asyncio.to_thread(func, *args, **kwargs)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
|
| 30 |
+
"""Build the Hub URL for a repository."""
|
| 31 |
+
if repo_type == "model":
|
| 32 |
+
return f"https://huggingface.co/{repo_id}"
|
| 33 |
+
return f"https://huggingface.co/{repo_type}s/{repo_id}"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class HfRepoGitTool:
|
| 37 |
+
"""Tool for git-like operations on HF repos."""
|
| 38 |
+
|
| 39 |
+
def __init__(self, hf_token: Optional[str] = None):
|
| 40 |
+
self.api = HfApi(token=hf_token)
|
| 41 |
+
|
| 42 |
+
async def execute(self, args: Dict[str, Any]) -> ToolResult:
|
| 43 |
+
"""Execute the specified operation."""
|
| 44 |
+
operation = args.get("operation")
|
| 45 |
+
|
| 46 |
+
if not operation:
|
| 47 |
+
return self._help()
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
handlers = {
|
| 51 |
+
"create_branch": self._create_branch,
|
| 52 |
+
"delete_branch": self._delete_branch,
|
| 53 |
+
"create_tag": self._create_tag,
|
| 54 |
+
"delete_tag": self._delete_tag,
|
| 55 |
+
"list_refs": self._list_refs,
|
| 56 |
+
"create_pr": self._create_pr,
|
| 57 |
+
"list_prs": self._list_prs,
|
| 58 |
+
"get_pr": self._get_pr,
|
| 59 |
+
"merge_pr": self._merge_pr,
|
| 60 |
+
"close_pr": self._close_pr,
|
| 61 |
+
"comment_pr": self._comment_pr,
|
| 62 |
+
"change_pr_status": self._change_pr_status,
|
| 63 |
+
"create_repo": self._create_repo,
|
| 64 |
+
"update_repo": self._update_repo,
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
handler = handlers.get(operation)
|
| 68 |
+
if handler:
|
| 69 |
+
return await handler(args)
|
| 70 |
+
else:
|
| 71 |
+
ops = ", ".join(handlers.keys())
|
| 72 |
+
return self._error(f"Unknown operation: {operation}. Valid: {ops}")
|
| 73 |
+
|
| 74 |
+
except RepositoryNotFoundError:
|
| 75 |
+
return self._error(f"Repository not found: {args.get('repo_id')}")
|
| 76 |
+
except Exception as e:
|
| 77 |
+
return self._error(f"Error: {str(e)}")
|
| 78 |
+
|
| 79 |
+
def _help(self) -> ToolResult:
|
| 80 |
+
"""Show usage instructions."""
|
| 81 |
+
return {
|
| 82 |
+
"formatted": """**hf_repo_git** - Git-like operations on HF repos
|
| 83 |
+
|
| 84 |
+
**Branch/Tag:**
|
| 85 |
+
- `create_branch`: `{"operation": "create_branch", "repo_id": "...", "branch": "dev"}`
|
| 86 |
+
- `delete_branch`: `{"operation": "delete_branch", "repo_id": "...", "branch": "dev"}`
|
| 87 |
+
- `create_tag`: `{"operation": "create_tag", "repo_id": "...", "tag": "v1.0"}`
|
| 88 |
+
- `delete_tag`: `{"operation": "delete_tag", "repo_id": "...", "tag": "v1.0"}`
|
| 89 |
+
- `list_refs`: `{"operation": "list_refs", "repo_id": "..."}`
|
| 90 |
+
|
| 91 |
+
**PRs:**
|
| 92 |
+
- `create_pr`: `{"operation": "create_pr", "repo_id": "...", "title": "..."}` (creates draft PR)
|
| 93 |
+
- `list_prs`: `{"operation": "list_prs", "repo_id": "..."}` (shows status: draft/open/merged/closed)
|
| 94 |
+
- `get_pr`: `{"operation": "get_pr", "repo_id": "...", "pr_num": 1}` (shows status)
|
| 95 |
+
- `change_pr_status`: `{"operation": "change_pr_status", "repo_id": "...", "pr_num": 1, "new_status": "open"}` (change draft to open)
|
| 96 |
+
- `merge_pr`: `{"operation": "merge_pr", "repo_id": "...", "pr_num": 1}`
|
| 97 |
+
- `close_pr`: `{"operation": "close_pr", "repo_id": "...", "pr_num": 1}`
|
| 98 |
+
- `comment_pr`: `{"operation": "comment_pr", "repo_id": "...", "pr_num": 1, "comment": "..."}`
|
| 99 |
+
|
| 100 |
+
**Repo:**
|
| 101 |
+
- `create_repo`: `{"operation": "create_repo", "repo_id": "my-model", "private": true}`
|
| 102 |
+
- `update_repo`: `{"operation": "update_repo", "repo_id": "...", "private": false}`""",
|
| 103 |
+
"totalResults": 1,
|
| 104 |
+
"resultsShared": 1,
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
# =========================================================================
|
| 108 |
+
# BRANCH OPERATIONS
|
| 109 |
+
# =========================================================================
|
| 110 |
+
|
| 111 |
+
async def _create_branch(self, args: Dict[str, Any]) -> ToolResult:
|
| 112 |
+
"""Create a new branch."""
|
| 113 |
+
repo_id = args.get("repo_id")
|
| 114 |
+
branch = args.get("branch")
|
| 115 |
+
|
| 116 |
+
if not repo_id:
|
| 117 |
+
return self._error("repo_id is required")
|
| 118 |
+
if not branch:
|
| 119 |
+
return self._error("branch is required")
|
| 120 |
+
|
| 121 |
+
repo_type = args.get("repo_type", "model")
|
| 122 |
+
from_rev = args.get("from_rev", "main")
|
| 123 |
+
|
| 124 |
+
await _async_call(
|
| 125 |
+
self.api.create_branch,
|
| 126 |
+
repo_id=repo_id,
|
| 127 |
+
branch=branch,
|
| 128 |
+
revision=from_rev,
|
| 129 |
+
repo_type=repo_type,
|
| 130 |
+
exist_ok=args.get("exist_ok", False),
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}"
|
| 134 |
+
return {"formatted": f"**Branch created:** {branch}\n{url}", "totalResults": 1, "resultsShared": 1}
|
| 135 |
+
|
| 136 |
+
async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult:
|
| 137 |
+
"""Delete a branch."""
|
| 138 |
+
repo_id = args.get("repo_id")
|
| 139 |
+
branch = args.get("branch")
|
| 140 |
+
|
| 141 |
+
if not repo_id:
|
| 142 |
+
return self._error("repo_id is required")
|
| 143 |
+
if not branch:
|
| 144 |
+
return self._error("branch is required")
|
| 145 |
+
|
| 146 |
+
repo_type = args.get("repo_type", "model")
|
| 147 |
+
|
| 148 |
+
await _async_call(
|
| 149 |
+
self.api.delete_branch,
|
| 150 |
+
repo_id=repo_id,
|
| 151 |
+
branch=branch,
|
| 152 |
+
repo_type=repo_type,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
return {"formatted": f"**Branch deleted:** {branch}", "totalResults": 1, "resultsShared": 1}
|
| 156 |
+
|
| 157 |
+
# =========================================================================
|
| 158 |
+
# TAG OPERATIONS
|
| 159 |
+
# =========================================================================
|
| 160 |
+
|
| 161 |
+
async def _create_tag(self, args: Dict[str, Any]) -> ToolResult:
|
| 162 |
+
"""Create a tag."""
|
| 163 |
+
repo_id = args.get("repo_id")
|
| 164 |
+
tag = args.get("tag")
|
| 165 |
+
|
| 166 |
+
if not repo_id:
|
| 167 |
+
return self._error("repo_id is required")
|
| 168 |
+
if not tag:
|
| 169 |
+
return self._error("tag is required")
|
| 170 |
+
|
| 171 |
+
repo_type = args.get("repo_type", "model")
|
| 172 |
+
revision = args.get("revision", "main")
|
| 173 |
+
tag_message = args.get("tag_message", "")
|
| 174 |
+
|
| 175 |
+
await _async_call(
|
| 176 |
+
self.api.create_tag,
|
| 177 |
+
repo_id=repo_id,
|
| 178 |
+
tag=tag,
|
| 179 |
+
revision=revision,
|
| 180 |
+
tag_message=tag_message,
|
| 181 |
+
repo_type=repo_type,
|
| 182 |
+
exist_ok=args.get("exist_ok", False),
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}"
|
| 186 |
+
return {"formatted": f"**Tag created:** {tag}\n{url}", "totalResults": 1, "resultsShared": 1}
|
| 187 |
+
|
| 188 |
+
async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult:
|
| 189 |
+
"""Delete a tag."""
|
| 190 |
+
repo_id = args.get("repo_id")
|
| 191 |
+
tag = args.get("tag")
|
| 192 |
+
|
| 193 |
+
if not repo_id:
|
| 194 |
+
return self._error("repo_id is required")
|
| 195 |
+
if not tag:
|
| 196 |
+
return self._error("tag is required")
|
| 197 |
+
|
| 198 |
+
repo_type = args.get("repo_type", "model")
|
| 199 |
+
|
| 200 |
+
await _async_call(
|
| 201 |
+
self.api.delete_tag,
|
| 202 |
+
repo_id=repo_id,
|
| 203 |
+
tag=tag,
|
| 204 |
+
repo_type=repo_type,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
return {"formatted": f"**Tag deleted:** {tag}", "totalResults": 1, "resultsShared": 1}
|
| 208 |
+
|
| 209 |
+
# =========================================================================
|
| 210 |
+
# LIST REFS
|
| 211 |
+
# =========================================================================
|
| 212 |
+
|
| 213 |
+
async def _list_refs(self, args: Dict[str, Any]) -> ToolResult:
|
| 214 |
+
"""List branches and tags."""
|
| 215 |
+
repo_id = args.get("repo_id")
|
| 216 |
+
|
| 217 |
+
if not repo_id:
|
| 218 |
+
return self._error("repo_id is required")
|
| 219 |
+
|
| 220 |
+
repo_type = args.get("repo_type", "model")
|
| 221 |
+
|
| 222 |
+
refs = await _async_call(
|
| 223 |
+
self.api.list_repo_refs,
|
| 224 |
+
repo_id=repo_id,
|
| 225 |
+
repo_type=repo_type,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
branches = [b.name for b in refs.branches] if refs.branches else []
|
| 229 |
+
tags = [t.name for t in refs.tags] if hasattr(refs, 'tags') and refs.tags else []
|
| 230 |
+
|
| 231 |
+
url = _build_repo_url(repo_id, repo_type)
|
| 232 |
+
lines = [f"**{repo_id}**", url, ""]
|
| 233 |
+
|
| 234 |
+
if branches:
|
| 235 |
+
lines.append(f"**Branches ({len(branches)}):** " + ", ".join(branches))
|
| 236 |
+
else:
|
| 237 |
+
lines.append("**Branches:** none")
|
| 238 |
+
|
| 239 |
+
if tags:
|
| 240 |
+
lines.append(f"**Tags ({len(tags)}):** " + ", ".join(tags))
|
| 241 |
+
else:
|
| 242 |
+
lines.append("**Tags:** none")
|
| 243 |
+
|
| 244 |
+
return {"formatted": "\n".join(lines), "totalResults": len(branches) + len(tags), "resultsShared": len(branches) + len(tags)}
|
| 245 |
+
|
| 246 |
+
# =========================================================================
|
| 247 |
+
# PR OPERATIONS
|
| 248 |
+
# =========================================================================
|
| 249 |
+
|
| 250 |
+
async def _create_pr(self, args: Dict[str, Any]) -> ToolResult:
|
| 251 |
+
"""Create a pull request."""
|
| 252 |
+
repo_id = args.get("repo_id")
|
| 253 |
+
title = args.get("title")
|
| 254 |
+
|
| 255 |
+
if not repo_id:
|
| 256 |
+
return self._error("repo_id is required")
|
| 257 |
+
if not title:
|
| 258 |
+
return self._error("title is required")
|
| 259 |
+
|
| 260 |
+
repo_type = args.get("repo_type", "model")
|
| 261 |
+
description = args.get("description", "")
|
| 262 |
+
|
| 263 |
+
result = await _async_call(
|
| 264 |
+
self.api.create_pull_request,
|
| 265 |
+
repo_id=repo_id,
|
| 266 |
+
title=title,
|
| 267 |
+
description=description,
|
| 268 |
+
repo_type=repo_type,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}"
|
| 272 |
+
return {
|
| 273 |
+
"formatted": f"**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision=\"refs/pr/{result.num}\"",
|
| 274 |
+
"totalResults": 1,
|
| 275 |
+
"resultsShared": 1,
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
async def _list_prs(self, args: Dict[str, Any]) -> ToolResult:
|
| 279 |
+
"""List PRs and discussions."""
|
| 280 |
+
repo_id = args.get("repo_id")
|
| 281 |
+
|
| 282 |
+
if not repo_id:
|
| 283 |
+
return self._error("repo_id is required")
|
| 284 |
+
|
| 285 |
+
repo_type = args.get("repo_type", "model")
|
| 286 |
+
status = args.get("status", "all") # open, closed, all
|
| 287 |
+
|
| 288 |
+
discussions = list(self.api.get_repo_discussions(
|
| 289 |
+
repo_id=repo_id,
|
| 290 |
+
repo_type=repo_type,
|
| 291 |
+
discussion_status=status if status != "all" else None,
|
| 292 |
+
))
|
| 293 |
+
|
| 294 |
+
if not discussions:
|
| 295 |
+
return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0}
|
| 296 |
+
|
| 297 |
+
url = _build_repo_url(repo_id, repo_type)
|
| 298 |
+
lines = [f"**{repo_id}** - {len(discussions)} discussions", f"{url}/discussions", ""]
|
| 299 |
+
|
| 300 |
+
for d in discussions[:20]:
|
| 301 |
+
if d.status == "draft":
|
| 302 |
+
status_label = "[DRAFT]"
|
| 303 |
+
elif d.status == "open":
|
| 304 |
+
status_label = "[OPEN]"
|
| 305 |
+
elif d.status == "merged":
|
| 306 |
+
status_label = "[MERGED]"
|
| 307 |
+
else:
|
| 308 |
+
status_label = "[CLOSED]"
|
| 309 |
+
type_label = "PR" if d.is_pull_request else "D"
|
| 310 |
+
lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}")
|
| 311 |
+
|
| 312 |
+
return {"formatted": "\n".join(lines), "totalResults": len(discussions), "resultsShared": min(20, len(discussions))}
|
| 313 |
+
|
| 314 |
+
async def _get_pr(self, args: Dict[str, Any]) -> ToolResult:
|
| 315 |
+
"""Get PR details."""
|
| 316 |
+
repo_id = args.get("repo_id")
|
| 317 |
+
pr_num = args.get("pr_num")
|
| 318 |
+
|
| 319 |
+
if not repo_id:
|
| 320 |
+
return self._error("repo_id is required")
|
| 321 |
+
if not pr_num:
|
| 322 |
+
return self._error("pr_num is required")
|
| 323 |
+
|
| 324 |
+
repo_type = args.get("repo_type", "model")
|
| 325 |
+
|
| 326 |
+
pr = await _async_call(
|
| 327 |
+
self.api.get_discussion_details,
|
| 328 |
+
repo_id=repo_id,
|
| 329 |
+
discussion_num=int(pr_num),
|
| 330 |
+
repo_type=repo_type,
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
|
| 334 |
+
status_map = {
|
| 335 |
+
"draft": "Draft",
|
| 336 |
+
"open": "Open",
|
| 337 |
+
"merged": "Merged",
|
| 338 |
+
"closed": "Closed"
|
| 339 |
+
}
|
| 340 |
+
status = status_map.get(pr.status, pr.status.capitalize())
|
| 341 |
+
type_label = "Pull Request" if pr.is_pull_request else "Discussion"
|
| 342 |
+
|
| 343 |
+
lines = [
|
| 344 |
+
f"**{type_label} #{pr_num}:** {pr.title}",
|
| 345 |
+
f"**Status:** {status}",
|
| 346 |
+
f"**Author:** {pr.author}",
|
| 347 |
+
url,
|
| 348 |
+
]
|
| 349 |
+
|
| 350 |
+
if pr.is_pull_request:
|
| 351 |
+
if pr.status == "draft":
|
| 352 |
+
lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"")
|
| 353 |
+
elif pr.status == "open":
|
| 354 |
+
lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"")
|
| 355 |
+
|
| 356 |
+
return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1}
|
| 357 |
+
|
| 358 |
+
async def _merge_pr(self, args: Dict[str, Any]) -> ToolResult:
|
| 359 |
+
"""Merge a pull request."""
|
| 360 |
+
repo_id = args.get("repo_id")
|
| 361 |
+
pr_num = args.get("pr_num")
|
| 362 |
+
|
| 363 |
+
if not repo_id:
|
| 364 |
+
return self._error("repo_id is required")
|
| 365 |
+
if not pr_num:
|
| 366 |
+
return self._error("pr_num is required")
|
| 367 |
+
|
| 368 |
+
repo_type = args.get("repo_type", "model")
|
| 369 |
+
comment = args.get("comment", "")
|
| 370 |
+
|
| 371 |
+
await _async_call(
|
| 372 |
+
self.api.merge_pull_request,
|
| 373 |
+
repo_id=repo_id,
|
| 374 |
+
discussion_num=int(pr_num),
|
| 375 |
+
comment=comment,
|
| 376 |
+
repo_type=repo_type,
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
|
| 380 |
+
return {"formatted": f"**PR #{pr_num} merged**\n{url}", "totalResults": 1, "resultsShared": 1}
|
| 381 |
+
|
| 382 |
+
async def _close_pr(self, args: Dict[str, Any]) -> ToolResult:
|
| 383 |
+
"""Close a PR/discussion."""
|
| 384 |
+
repo_id = args.get("repo_id")
|
| 385 |
+
pr_num = args.get("pr_num")
|
| 386 |
+
|
| 387 |
+
if not repo_id:
|
| 388 |
+
return self._error("repo_id is required")
|
| 389 |
+
if not pr_num:
|
| 390 |
+
return self._error("pr_num is required")
|
| 391 |
+
|
| 392 |
+
repo_type = args.get("repo_type", "model")
|
| 393 |
+
comment = args.get("comment", "")
|
| 394 |
+
|
| 395 |
+
await _async_call(
|
| 396 |
+
self.api.change_discussion_status,
|
| 397 |
+
repo_id=repo_id,
|
| 398 |
+
discussion_num=int(pr_num),
|
| 399 |
+
new_status="closed",
|
| 400 |
+
comment=comment,
|
| 401 |
+
repo_type=repo_type,
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
return {"formatted": f"**Discussion #{pr_num} closed**", "totalResults": 1, "resultsShared": 1}
|
| 405 |
+
|
| 406 |
+
async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult:
|
| 407 |
+
"""Add a comment to a PR/discussion."""
|
| 408 |
+
repo_id = args.get("repo_id")
|
| 409 |
+
pr_num = args.get("pr_num")
|
| 410 |
+
comment = args.get("comment")
|
| 411 |
+
|
| 412 |
+
if not repo_id:
|
| 413 |
+
return self._error("repo_id is required")
|
| 414 |
+
if not pr_num:
|
| 415 |
+
return self._error("pr_num is required")
|
| 416 |
+
if not comment:
|
| 417 |
+
return self._error("comment is required")
|
| 418 |
+
|
| 419 |
+
repo_type = args.get("repo_type", "model")
|
| 420 |
+
|
| 421 |
+
await _async_call(
|
| 422 |
+
self.api.comment_discussion,
|
| 423 |
+
repo_id=repo_id,
|
| 424 |
+
discussion_num=int(pr_num),
|
| 425 |
+
comment=comment,
|
| 426 |
+
repo_type=repo_type,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
|
| 430 |
+
return {"formatted": f"**Comment added to #{pr_num}**\n{url}", "totalResults": 1, "resultsShared": 1}
|
| 431 |
+
|
| 432 |
+
async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult:
|
| 433 |
+
"""Change PR/discussion status (mainly to convert draft to open)."""
|
| 434 |
+
repo_id = args.get("repo_id")
|
| 435 |
+
pr_num = args.get("pr_num")
|
| 436 |
+
new_status = args.get("new_status")
|
| 437 |
+
|
| 438 |
+
if not repo_id:
|
| 439 |
+
return self._error("repo_id is required")
|
| 440 |
+
if not pr_num:
|
| 441 |
+
return self._error("pr_num is required")
|
| 442 |
+
if not new_status:
|
| 443 |
+
return self._error("new_status is required (open or closed)")
|
| 444 |
+
|
| 445 |
+
repo_type = args.get("repo_type", "model")
|
| 446 |
+
comment = args.get("comment", "")
|
| 447 |
+
|
| 448 |
+
await _async_call(
|
| 449 |
+
self.api.change_discussion_status,
|
| 450 |
+
repo_id=repo_id,
|
| 451 |
+
discussion_num=int(pr_num),
|
| 452 |
+
new_status=new_status,
|
| 453 |
+
comment=comment,
|
| 454 |
+
repo_type=repo_type,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
|
| 458 |
+
return {"formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", "totalResults": 1, "resultsShared": 1}
|
| 459 |
+
|
| 460 |
+
# =========================================================================
|
| 461 |
+
# REPO MANAGEMENT
|
| 462 |
+
# =========================================================================
|
| 463 |
+
|
| 464 |
+
async def _create_repo(self, args: Dict[str, Any]) -> ToolResult:
|
| 465 |
+
"""Create a new repository."""
|
| 466 |
+
repo_id = args.get("repo_id")
|
| 467 |
+
|
| 468 |
+
if not repo_id:
|
| 469 |
+
return self._error("repo_id is required")
|
| 470 |
+
|
| 471 |
+
repo_type = args.get("repo_type", "model")
|
| 472 |
+
private = args.get("private", True)
|
| 473 |
+
space_sdk = args.get("space_sdk")
|
| 474 |
+
|
| 475 |
+
if repo_type == "space" and not space_sdk:
|
| 476 |
+
return self._error("space_sdk required for spaces (gradio/streamlit/docker/static)")
|
| 477 |
+
|
| 478 |
+
kwargs = {
|
| 479 |
+
"repo_id": repo_id,
|
| 480 |
+
"repo_type": repo_type,
|
| 481 |
+
"private": private,
|
| 482 |
+
"exist_ok": args.get("exist_ok", False),
|
| 483 |
+
}
|
| 484 |
+
if space_sdk:
|
| 485 |
+
kwargs["space_sdk"] = space_sdk
|
| 486 |
+
|
| 487 |
+
result = await _async_call(self.api.create_repo, **kwargs)
|
| 488 |
+
|
| 489 |
+
return {
|
| 490 |
+
"formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}",
|
| 491 |
+
"totalResults": 1,
|
| 492 |
+
"resultsShared": 1,
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
async def _update_repo(self, args: Dict[str, Any]) -> ToolResult:
|
| 496 |
+
"""Update repository settings."""
|
| 497 |
+
repo_id = args.get("repo_id")
|
| 498 |
+
|
| 499 |
+
if not repo_id:
|
| 500 |
+
return self._error("repo_id is required")
|
| 501 |
+
|
| 502 |
+
repo_type = args.get("repo_type", "model")
|
| 503 |
+
private = args.get("private")
|
| 504 |
+
gated = args.get("gated")
|
| 505 |
+
|
| 506 |
+
if private is None and gated is None:
|
| 507 |
+
return self._error("Specify private (bool) or gated ('auto'/'manual'/false)")
|
| 508 |
+
|
| 509 |
+
kwargs = {"repo_id": repo_id, "repo_type": repo_type}
|
| 510 |
+
if private is not None:
|
| 511 |
+
kwargs["private"] = private
|
| 512 |
+
if gated is not None:
|
| 513 |
+
kwargs["gated"] = gated
|
| 514 |
+
|
| 515 |
+
await _async_call(self.api.update_repo_settings, **kwargs)
|
| 516 |
+
|
| 517 |
+
changes = []
|
| 518 |
+
if private is not None:
|
| 519 |
+
changes.append(f"private={private}")
|
| 520 |
+
if gated is not None:
|
| 521 |
+
changes.append(f"gated={gated}")
|
| 522 |
+
|
| 523 |
+
url = f"{_build_repo_url(repo_id, repo_type)}/settings"
|
| 524 |
+
return {"formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", "totalResults": 1, "resultsShared": 1}
|
| 525 |
+
|
| 526 |
+
def _error(self, message: str) -> ToolResult:
|
| 527 |
+
"""Return an error result."""
|
| 528 |
+
return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
# Tool specification
|
| 532 |
+
HF_REPO_GIT_TOOL_SPEC = {
|
| 533 |
+
"name": "hf_repo_git",
|
| 534 |
+
"description": (
|
| 535 |
+
"Git-like operations on HF repos: branches, tags, PRs, and repo management.\n\n"
|
| 536 |
+
"## Operations\n"
|
| 537 |
+
"**Branches:** create_branch, delete_branch, list_refs\n"
|
| 538 |
+
"**Tags:** create_tag, delete_tag\n"
|
| 539 |
+
"**PRs:** create_pr, list_prs, get_pr, merge_pr, close_pr, comment_pr, change_pr_status\n"
|
| 540 |
+
"**Repo:** create_repo, update_repo\n\n"
|
| 541 |
+
"## Use when\n"
|
| 542 |
+
"- Creating feature branches for experiments\n"
|
| 543 |
+
"- Tagging model versions (v1.0, v2.0)\n"
|
| 544 |
+
"- Opening PRs to contribute to repos you don't own\n"
|
| 545 |
+
"- Reviewing and merging PRs on your repos\n"
|
| 546 |
+
"- Creating new model/dataset/space repos\n"
|
| 547 |
+
"- Changing repo visibility (public/private) or gated access\n\n"
|
| 548 |
+
"## Examples\n"
|
| 549 |
+
'{"operation": "list_refs", "repo_id": "my-model"}\n'
|
| 550 |
+
'{"operation": "create_branch", "repo_id": "my-model", "branch": "experiment-v2"}\n'
|
| 551 |
+
'{"operation": "create_tag", "repo_id": "my-model", "tag": "v1.0", "revision": "main"}\n'
|
| 552 |
+
'{"operation": "create_pr", "repo_id": "org/model", "title": "Fix tokenizer config"}\n'
|
| 553 |
+
'{"operation": "change_pr_status", "repo_id": "my-model", "pr_num": 1, "new_status": "open"}\n'
|
| 554 |
+
'{"operation": "merge_pr", "repo_id": "my-model", "pr_num": 3}\n'
|
| 555 |
+
'{"operation": "create_repo", "repo_id": "my-new-model", "private": true}\n'
|
| 556 |
+
'{"operation": "update_repo", "repo_id": "my-model", "gated": "auto"}\n\n'
|
| 557 |
+
"## PR Workflow\n"
|
| 558 |
+
"1. create_pr → creates draft PR (empty by default)\n"
|
| 559 |
+
"2. Upload files with revision='refs/pr/N' to add commits\n"
|
| 560 |
+
"3. change_pr_status with new_status='open' to publish (convert draft to open)\n"
|
| 561 |
+
"4. merge_pr when ready\n\n"
|
| 562 |
+
"## Notes\n"
|
| 563 |
+
"- PR status: draft (default), open, merged, closed\n"
|
| 564 |
+
"- delete_branch, delete_tag, merge_pr, create_repo, update_repo require approval\n"
|
| 565 |
+
"- For spaces, create_repo needs space_sdk (gradio/streamlit/docker/static)\n"
|
| 566 |
+
"- gated options: 'auto' (instant), 'manual' (review), false (open)\n"
|
| 567 |
+
),
|
| 568 |
+
"parameters": {
|
| 569 |
+
"type": "object",
|
| 570 |
+
"properties": {
|
| 571 |
+
"operation": {
|
| 572 |
+
"type": "string",
|
| 573 |
+
"enum": [
|
| 574 |
+
"create_branch", "delete_branch",
|
| 575 |
+
"create_tag", "delete_tag", "list_refs",
|
| 576 |
+
"create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status",
|
| 577 |
+
"create_repo", "update_repo",
|
| 578 |
+
],
|
| 579 |
+
"description": "Operation to execute",
|
| 580 |
+
},
|
| 581 |
+
"repo_id": {
|
| 582 |
+
"type": "string",
|
| 583 |
+
"description": "Repository ID (e.g., 'username/repo-name')",
|
| 584 |
+
},
|
| 585 |
+
"repo_type": {
|
| 586 |
+
"type": "string",
|
| 587 |
+
"enum": ["model", "dataset", "space"],
|
| 588 |
+
"description": "Repository type (default: model)",
|
| 589 |
+
},
|
| 590 |
+
"branch": {
|
| 591 |
+
"type": "string",
|
| 592 |
+
"description": "Branch name (create_branch, delete_branch)",
|
| 593 |
+
},
|
| 594 |
+
"from_rev": {
|
| 595 |
+
"type": "string",
|
| 596 |
+
"description": "Create branch from this revision (default: main)",
|
| 597 |
+
},
|
| 598 |
+
"tag": {
|
| 599 |
+
"type": "string",
|
| 600 |
+
"description": "Tag name (create_tag, delete_tag)",
|
| 601 |
+
},
|
| 602 |
+
"revision": {
|
| 603 |
+
"type": "string",
|
| 604 |
+
"description": "Revision for tag (default: main)",
|
| 605 |
+
},
|
| 606 |
+
"tag_message": {
|
| 607 |
+
"type": "string",
|
| 608 |
+
"description": "Tag description",
|
| 609 |
+
},
|
| 610 |
+
"title": {
|
| 611 |
+
"type": "string",
|
| 612 |
+
"description": "PR title (create_pr)",
|
| 613 |
+
},
|
| 614 |
+
"description": {
|
| 615 |
+
"type": "string",
|
| 616 |
+
"description": "PR description (create_pr)",
|
| 617 |
+
},
|
| 618 |
+
"pr_num": {
|
| 619 |
+
"type": "integer",
|
| 620 |
+
"description": "PR/discussion number",
|
| 621 |
+
},
|
| 622 |
+
"comment": {
|
| 623 |
+
"type": "string",
|
| 624 |
+
"description": "Comment text",
|
| 625 |
+
},
|
| 626 |
+
"status": {
|
| 627 |
+
"type": "string",
|
| 628 |
+
"enum": ["open", "closed", "all"],
|
| 629 |
+
"description": "Filter PRs by status (list_prs)",
|
| 630 |
+
},
|
| 631 |
+
"new_status": {
|
| 632 |
+
"type": "string",
|
| 633 |
+
"enum": ["open", "closed"],
|
| 634 |
+
"description": "New status for PR/discussion (change_pr_status)",
|
| 635 |
+
},
|
| 636 |
+
"private": {
|
| 637 |
+
"type": "boolean",
|
| 638 |
+
"description": "Make repo private (create_repo, update_repo)",
|
| 639 |
+
},
|
| 640 |
+
"gated": {
|
| 641 |
+
"type": "string",
|
| 642 |
+
"enum": ["auto", "manual", "false"],
|
| 643 |
+
"description": "Gated access setting (update_repo)",
|
| 644 |
+
},
|
| 645 |
+
"space_sdk": {
|
| 646 |
+
"type": "string",
|
| 647 |
+
"enum": ["gradio", "streamlit", "docker", "static"],
|
| 648 |
+
"description": "Space SDK (required for create_repo with space)",
|
| 649 |
+
},
|
| 650 |
+
},
|
| 651 |
+
"required": ["operation"],
|
| 652 |
+
},
|
| 653 |
+
}
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
async def hf_repo_git_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
|
| 657 |
+
"""Handler for agent tool router."""
|
| 658 |
+
try:
|
| 659 |
+
tool = HfRepoGitTool()
|
| 660 |
+
result = await tool.execute(arguments)
|
| 661 |
+
return result["formatted"], not result.get("isError", False)
|
| 662 |
+
except Exception as e:
|
| 663 |
+
return f"Error: {str(e)}", False
|
agent/tools/utils_tools.py
DELETED
|
@@ -1,203 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Utils Tools - General utility operations
|
| 3 |
-
|
| 4 |
-
Provides system information like current date/time with timezone support.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
import zoneinfo
|
| 8 |
-
from datetime import datetime
|
| 9 |
-
from typing import Any, Dict, Literal
|
| 10 |
-
|
| 11 |
-
from agent.tools.types import ToolResult
|
| 12 |
-
|
| 13 |
-
# Operation names
|
| 14 |
-
OperationType = Literal["get_datetime"]
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class UtilsTool:
|
| 18 |
-
"""Tool for general utility operations."""
|
| 19 |
-
|
| 20 |
-
async def execute(self, params: Dict[str, Any]) -> ToolResult:
|
| 21 |
-
"""Execute the specified utility operation."""
|
| 22 |
-
operation = params.get("operation")
|
| 23 |
-
args = params.get("args", {})
|
| 24 |
-
|
| 25 |
-
# If no operation provided, return usage instructions
|
| 26 |
-
if not operation:
|
| 27 |
-
return self._show_help()
|
| 28 |
-
|
| 29 |
-
# Normalize operation name
|
| 30 |
-
operation = operation.lower()
|
| 31 |
-
|
| 32 |
-
# Check if help is requested
|
| 33 |
-
if args.get("help"):
|
| 34 |
-
return self._show_operation_help(operation)
|
| 35 |
-
|
| 36 |
-
try:
|
| 37 |
-
# Route to appropriate handler
|
| 38 |
-
if operation == "get_datetime":
|
| 39 |
-
return await self._get_datetime(args)
|
| 40 |
-
else:
|
| 41 |
-
return {
|
| 42 |
-
"formatted": f'Unknown operation: "{operation}"\n\n'
|
| 43 |
-
"Available operations: get_datetime\n\n"
|
| 44 |
-
"Call this tool with no operation for full usage instructions.",
|
| 45 |
-
"totalResults": 0,
|
| 46 |
-
"resultsShared": 0,
|
| 47 |
-
"isError": True,
|
| 48 |
-
}
|
| 49 |
-
|
| 50 |
-
except Exception as e:
|
| 51 |
-
return {
|
| 52 |
-
"formatted": f"Error executing {operation}: {str(e)}",
|
| 53 |
-
"totalResults": 0,
|
| 54 |
-
"resultsShared": 0,
|
| 55 |
-
"isError": True,
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
def _show_help(self) -> ToolResult:
|
| 59 |
-
"""Show usage instructions when tool is called with no arguments."""
|
| 60 |
-
usage_text = """# Utils Tool
|
| 61 |
-
|
| 62 |
-
Utility operations for system information.
|
| 63 |
-
|
| 64 |
-
## Available Commands
|
| 65 |
-
|
| 66 |
-
- **get_datetime** - Get current date and time with timezone support
|
| 67 |
-
|
| 68 |
-
## Examples
|
| 69 |
-
|
| 70 |
-
### Get current date and time (Paris timezone by default)
|
| 71 |
-
Call this tool with:
|
| 72 |
-
```json
|
| 73 |
-
{
|
| 74 |
-
"operation": "get_datetime",
|
| 75 |
-
"args": {}
|
| 76 |
-
}
|
| 77 |
-
```
|
| 78 |
-
|
| 79 |
-
### Get current date and time in a specific timezone
|
| 80 |
-
Call this tool with:
|
| 81 |
-
```json
|
| 82 |
-
{
|
| 83 |
-
"operation": "get_datetime",
|
| 84 |
-
"args": {
|
| 85 |
-
"timezone": "America/New_York"
|
| 86 |
-
}
|
| 87 |
-
}
|
| 88 |
-
```
|
| 89 |
-
|
| 90 |
-
Common timezones: Europe/Paris, America/New_York, America/Los_Angeles, Asia/Tokyo, UTC
|
| 91 |
-
|
| 92 |
-
## Tips
|
| 93 |
-
|
| 94 |
-
- **Default timezone**: Paris (Europe/Paris)
|
| 95 |
-
- **Date format**: dd-mm-yyyy
|
| 96 |
-
- **Time format**: HH:MM:SS.mmm (24-hour format with milliseconds)
|
| 97 |
-
- **Timezone names**: Use IANA timezone database names (e.g., "Europe/Paris", "UTC")
|
| 98 |
-
"""
|
| 99 |
-
return {"formatted": usage_text, "totalResults": 1, "resultsShared": 1}
|
| 100 |
-
|
| 101 |
-
def _show_operation_help(self, operation: str) -> ToolResult:
|
| 102 |
-
"""Show help for a specific operation."""
|
| 103 |
-
help_text = f"Help for operation: {operation}\n\nCall with appropriate arguments. Use the main help for examples."
|
| 104 |
-
return {"formatted": help_text, "totalResults": 1, "resultsShared": 1}
|
| 105 |
-
|
| 106 |
-
async def _get_datetime(self, args: Dict[str, Any]) -> ToolResult:
|
| 107 |
-
"""Get current date and time with timezone support."""
|
| 108 |
-
timezone_name = args.get("timezone", "Europe/Paris")
|
| 109 |
-
|
| 110 |
-
try:
|
| 111 |
-
# Get timezone object
|
| 112 |
-
tz = zoneinfo.ZoneInfo(timezone_name)
|
| 113 |
-
|
| 114 |
-
# Get current datetime in specified timezone
|
| 115 |
-
now = datetime.now(tz)
|
| 116 |
-
|
| 117 |
-
# Format date as dd-mm-yyyy
|
| 118 |
-
date_str = now.strftime("%d-%m-%Y")
|
| 119 |
-
|
| 120 |
-
# Format time as HH:MM:SS.mmm
|
| 121 |
-
time_str = now.strftime("%H:%M:%S.%f")[
|
| 122 |
-
:-3
|
| 123 |
-
] # Remove last 3 digits to keep only milliseconds
|
| 124 |
-
|
| 125 |
-
# Get timezone abbreviation/offset
|
| 126 |
-
tz_offset = now.strftime("%z")
|
| 127 |
-
tz_name = now.strftime("%Z")
|
| 128 |
-
|
| 129 |
-
response = f"""✓ Current date and time
|
| 130 |
-
|
| 131 |
-
**Date:** {date_str}
|
| 132 |
-
**Time:** {time_str}
|
| 133 |
-
**Timezone:** {timezone_name} ({tz_name}, UTC{tz_offset[:3]}:{tz_offset[3:]})
|
| 134 |
-
|
| 135 |
-
**ISO Format:** {now.isoformat()}
|
| 136 |
-
**Unix Timestamp:** {int(now.timestamp())}"""
|
| 137 |
-
|
| 138 |
-
return {"formatted": response, "totalResults": 1, "resultsShared": 1}
|
| 139 |
-
|
| 140 |
-
except zoneinfo.ZoneInfoNotFoundError:
|
| 141 |
-
return {
|
| 142 |
-
"formatted": f"Invalid timezone: {timezone_name}\n\n"
|
| 143 |
-
"Use IANA timezone database names like:\n"
|
| 144 |
-
"- Europe/Paris\n"
|
| 145 |
-
"- America/New_York\n"
|
| 146 |
-
"- Asia/Tokyo\n"
|
| 147 |
-
"- UTC\n\n"
|
| 148 |
-
"See: https://en.wikipedia.org/wiki/List_of_tz_database_time_zones",
|
| 149 |
-
"totalResults": 0,
|
| 150 |
-
"resultsShared": 0,
|
| 151 |
-
"isError": True,
|
| 152 |
-
}
|
| 153 |
-
except Exception as e:
|
| 154 |
-
return {
|
| 155 |
-
"formatted": f"Failed to get date/time: {str(e)}",
|
| 156 |
-
"totalResults": 0,
|
| 157 |
-
"resultsShared": 0,
|
| 158 |
-
"isError": True,
|
| 159 |
-
}
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
# Tool specification for agent registration
|
| 163 |
-
UTILS_TOOL_SPEC = {
|
| 164 |
-
"name": "utils",
|
| 165 |
-
"description": (
|
| 166 |
-
"System utility operations - currently provides date/time with timezone support. "
|
| 167 |
-
"**Use when:** (1) Need current date for logging/timestamps, (2) User asks 'what time is it', "
|
| 168 |
-
"(3) Need timezone-aware datetime for scheduling/coordination, (4) Creating timestamped filenames. "
|
| 169 |
-
"**Operation:** get_datetime with optional timezone parameter (default: Europe/Paris). "
|
| 170 |
-
"Returns: Date (dd-mm-yyyy), time (HH:MM:SS.mmm), timezone info, ISO format, Unix timestamp. "
|
| 171 |
-
"**Pattern:** utils get_datetime → use timestamp in filename/log → upload to hf_private_repos. "
|
| 172 |
-
"Supports IANA timezone names: 'Europe/Paris', 'America/New_York', 'Asia/Tokyo', 'UTC'."
|
| 173 |
-
),
|
| 174 |
-
"parameters": {
|
| 175 |
-
"type": "object",
|
| 176 |
-
"properties": {
|
| 177 |
-
"operation": {
|
| 178 |
-
"type": "string",
|
| 179 |
-
"enum": ["get_datetime"],
|
| 180 |
-
"description": "Operation to execute. Valid values: [get_datetime]",
|
| 181 |
-
},
|
| 182 |
-
"args": {
|
| 183 |
-
"type": "object",
|
| 184 |
-
"description": (
|
| 185 |
-
"Operation-specific arguments as a JSON object. "
|
| 186 |
-
"For get_datetime: timezone (string, optional, default: Europe/Paris). "
|
| 187 |
-
"Use IANA timezone names like 'America/New_York', 'Asia/Tokyo', 'UTC'."
|
| 188 |
-
),
|
| 189 |
-
"additionalProperties": True,
|
| 190 |
-
},
|
| 191 |
-
},
|
| 192 |
-
},
|
| 193 |
-
}
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
async def utils_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
|
| 197 |
-
"""Handler for agent tool router."""
|
| 198 |
-
try:
|
| 199 |
-
tool = UtilsTool()
|
| 200 |
-
result = await tool.execute(arguments)
|
| 201 |
-
return result["formatted"], not result.get("isError", False)
|
| 202 |
-
except Exception as e:
|
| 203 |
-
return f"Error executing Utils tool: {str(e)}", False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent/utils/__init__.py
CHANGED
|
@@ -1,7 +1,3 @@
|
|
| 1 |
"""
|
| 2 |
Utility functions and helpers
|
| 3 |
"""
|
| 4 |
-
|
| 5 |
-
from agent.utils.logging import setup_logger
|
| 6 |
-
|
| 7 |
-
__all__ = ["setup_logger"]
|
|
|
|
| 1 |
"""
|
| 2 |
Utility functions and helpers
|
| 3 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
agent/utils/logging.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Logging utilities
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import logging
|
| 6 |
-
import sys
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
from typing import Optional
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def setup_logger(
|
| 12 |
-
name: str = "hf_agent", level: int = logging.INFO, log_file: Optional[Path] = None
|
| 13 |
-
) -> logging.Logger:
|
| 14 |
-
"""Setup and configure logger"""
|
| 15 |
-
|
| 16 |
-
logger = logging.getLogger(name)
|
| 17 |
-
logger.setLevel(level)
|
| 18 |
-
|
| 19 |
-
# Remove existing handlers
|
| 20 |
-
logger.handlers = []
|
| 21 |
-
|
| 22 |
-
# Console handler
|
| 23 |
-
console_handler = logging.StreamHandler(sys.stdout)
|
| 24 |
-
console_handler.setLevel(level)
|
| 25 |
-
console_format = logging.Formatter(
|
| 26 |
-
"%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 27 |
-
datefmt="%Y-%m-%d %H:%M:%S",
|
| 28 |
-
)
|
| 29 |
-
console_handler.setFormatter(console_format)
|
| 30 |
-
logger.addHandler(console_handler)
|
| 31 |
-
|
| 32 |
-
# File handler if log_file specified
|
| 33 |
-
if log_file:
|
| 34 |
-
log_file.parent.mkdir(parents=True, exist_ok=True)
|
| 35 |
-
file_handler = logging.FileHandler(log_file)
|
| 36 |
-
file_handler.setLevel(level)
|
| 37 |
-
file_handler.setFormatter(console_format)
|
| 38 |
-
logger.addHandler(file_handler)
|
| 39 |
-
|
| 40 |
-
return logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent/utils/terminal_display.py
CHANGED
|
@@ -94,13 +94,18 @@ def format_tool_call(tool_name: str, arguments: str) -> str:
|
|
| 94 |
|
| 95 |
def format_tool_output(output: str, success: bool, truncate: bool = True) -> str:
|
| 96 |
"""Format tool output with color and optional truncation"""
|
|
|
|
| 97 |
if truncate:
|
| 98 |
output = truncate_to_lines(output, max_lines=6)
|
| 99 |
|
| 100 |
if success:
|
| 101 |
-
return
|
|
|
|
|
|
|
| 102 |
else:
|
| 103 |
-
return
|
|
|
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
def format_turn_complete() -> str:
|
|
|
|
| 94 |
|
| 95 |
def format_tool_output(output: str, success: bool, truncate: bool = True) -> str:
|
| 96 |
"""Format tool output with color and optional truncation"""
|
| 97 |
+
original_length = len(output)
|
| 98 |
if truncate:
|
| 99 |
output = truncate_to_lines(output, max_lines=6)
|
| 100 |
|
| 101 |
if success:
|
| 102 |
+
return (
|
| 103 |
+
f"{Colors.YELLOW}Tool output ({original_length} tkns): {Colors.RESET}\n{output}"
|
| 104 |
+
)
|
| 105 |
else:
|
| 106 |
+
return (
|
| 107 |
+
f"{Colors.RED}Tool output ({original_length} tokens): {Colors.RESET}\n{output}"
|
| 108 |
+
)
|
| 109 |
|
| 110 |
|
| 111 |
def format_turn_complete() -> str:
|
pyproject.toml
CHANGED
|
@@ -24,6 +24,7 @@ agent = [
|
|
| 24 |
"nbconvert>=7.16.6",
|
| 25 |
"nbformat>=5.10.4",
|
| 26 |
"datasets>=4.3.0", # For session logging to HF datasets
|
|
|
|
| 27 |
]
|
| 28 |
|
| 29 |
# Evaluation/benchmarking dependencies
|
|
|
|
| 24 |
"nbconvert>=7.16.6",
|
| 25 |
"nbformat>=5.10.4",
|
| 26 |
"datasets>=4.3.0", # For session logging to HF datasets
|
| 27 |
+
"whoosh>=2.7.4",
|
| 28 |
]
|
| 29 |
|
| 30 |
# Evaluation/benchmarking dependencies
|
test_dataset_tools.py
DELETED
|
@@ -1,79 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Test script for unified dataset inspection tool
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import asyncio
|
| 6 |
-
import sys
|
| 7 |
-
from typing import TypedDict
|
| 8 |
-
from unittest.mock import MagicMock
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
# Mock the types module before importing dataset_tools
|
| 12 |
-
class ToolResult(TypedDict, total=False):
|
| 13 |
-
formatted: str
|
| 14 |
-
totalResults: int
|
| 15 |
-
resultsShared: int
|
| 16 |
-
isError: bool
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
mock_types = MagicMock()
|
| 20 |
-
mock_types.ToolResult = ToolResult
|
| 21 |
-
sys.modules["agent.tools.types"] = mock_types
|
| 22 |
-
|
| 23 |
-
# Now import directly from the file
|
| 24 |
-
sys.path.insert(0, "/Users/akseljoonas/Documents/hf-agent/agent/tools")
|
| 25 |
-
from dataset_tools import hf_inspect_dataset_handler, inspect_dataset
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
async def test_inspect_dataset():
|
| 29 |
-
"""Test the unified inspect_dataset function"""
|
| 30 |
-
print("=" * 70)
|
| 31 |
-
print("Testing inspect_dataset()")
|
| 32 |
-
print("=" * 70)
|
| 33 |
-
|
| 34 |
-
# Test with akseljoonas/hf-agent-sessions as specified
|
| 35 |
-
print("\n→ inspect_dataset('akseljoonas/hf-agent-sessions'):")
|
| 36 |
-
result = await inspect_dataset("akseljoonas/hf-agent-sessions")
|
| 37 |
-
print(f" isError: {result['isError']}")
|
| 38 |
-
print(f" Output:\n{result['formatted']}")
|
| 39 |
-
|
| 40 |
-
print("\n" + "=" * 70)
|
| 41 |
-
|
| 42 |
-
# # Test with stanfordnlp/imdb
|
| 43 |
-
# print("\n→ inspect_dataset('stanfordnlp/imdb'):")
|
| 44 |
-
# result = await inspect_dataset("stanfordnlp/imdb")
|
| 45 |
-
# print(f" isError: {result['isError']}")
|
| 46 |
-
# print(f" Output:\n{result['formatted']}")
|
| 47 |
-
|
| 48 |
-
# print("\n" + "=" * 70)
|
| 49 |
-
|
| 50 |
-
# # Test with multi-config dataset
|
| 51 |
-
# print("\n→ inspect_dataset('nyu-mll/glue', config='mrpc'):")
|
| 52 |
-
# result = await inspect_dataset("nyu-mll/glue", config="mrpc")
|
| 53 |
-
# print(f" isError: {result['isError']}")
|
| 54 |
-
# print(f" Output:\n{result['formatted']}")
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
async def test_handler():
|
| 58 |
-
"""Test the handler (what the agent calls)"""
|
| 59 |
-
print("\n" + "=" * 70)
|
| 60 |
-
print("Testing hf_inspect_dataset_handler()")
|
| 61 |
-
print("=" * 70)
|
| 62 |
-
|
| 63 |
-
result, success = await hf_inspect_dataset_handler(
|
| 64 |
-
{
|
| 65 |
-
"dataset": "stanfordnlp/imdb",
|
| 66 |
-
"sample_rows": 2,
|
| 67 |
-
}
|
| 68 |
-
)
|
| 69 |
-
print("\n→ Handler result:")
|
| 70 |
-
print(f" success: {success}")
|
| 71 |
-
print(f" output:\n{result}")
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
if __name__ == "__main__":
|
| 75 |
-
print("\nUnified Dataset Inspection Tool Test\n")
|
| 76 |
-
asyncio.run(test_inspect_dataset())
|
| 77 |
-
# asyncio.run(test_handler())
|
| 78 |
-
print("\n" + "=" * 70)
|
| 79 |
-
print("Done!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|