Aksel Joonas Reedi committed
Commit 471bd1a · 2 parents: a8a06cc9459bd9

Merge branch 'main' into dataset_tool_improved
.gitignore CHANGED
@@ -16,4 +16,5 @@ wheels/
  /logs
  hf-agent-leaderboard/
  .cursor/
- session_logs/
+ session_logs/
+ skills/
agent/MCP_INTEGRATION.md DELETED
@@ -1,205 +0,0 @@
- # MCP Integration for HF Agent
-
- This agent now supports the Model Context Protocol (MCP), allowing it to connect to and use tools from MCP servers.
-
- ## Overview
-
- The MCP integration allows the agent to:
- - Connect to multiple MCP servers simultaneously
- - Automatically discover and use tools from connected servers
- - Execute tool calls through the MCP protocol
- - Seamlessly integrate MCP tools with the agent's existing tool system
-
- ## Architecture
-
- The integration consists of several components:
-
- 1. **MCPClient** (`agent/core/mcp_client.py`): Manages connections to MCP servers
- 2. **ToolExecutor** (`agent/core/executor.py`): Executes both MCP and local tools
- 3. **Config** (`agent/config.py`): Stores MCP server configurations
- 4. **Session** (`agent/core/session.py`): Initializes MCP connections and manages lifecycle
-
- ## Configuration
-
- To use MCP servers with your agent, add them to your configuration file:
-
- ```json
- {
-     "model_name": "anthropic/claude-sonnet-4-5-20250929",
-     "tools": [],
-     "system_prompt_path": "",
-     "mcp_servers": [
-         {
-             "name": "weather",
-             "command": "python",
-             "args": ["path/to/weather_server.py"],
-             "env": null
-         },
-         {
-             "name": "filesystem",
-             "command": "node",
-             "args": ["path/to/filesystem_server.js"],
-             "env": {
-                 "ALLOWED_PATHS": "/home/user/documents"
-             }
-         }
-     ]
- }
- ```
-
- ### Configuration Fields
-
- - `name`: Unique identifier for the MCP server
- - `command`: Command to execute the server (`python`, `node`, etc.)
- - `args`: Arguments to pass to the command (path to server script)
- - `env`: (Optional) Environment variables for the server process
-
- ## Usage
-
- ### Basic Usage
-
- ```python
- import asyncio
- from agent.config import Config, load_config
- from agent.core.agent_loop import submission_loop
-
- async def main():
-     # Load config with MCP servers
-     config = load_config("config.json")
-
-     # Create queues
-     submission_queue = asyncio.Queue()
-     event_queue = asyncio.Queue()
-
-     # Start agent loop (MCP connections initialized automatically)
-     await submission_loop(submission_queue, event_queue, config)
-
- if __name__ == "__main__":
-     asyncio.run(main())
- ```
-
- ### Programmatic Configuration
-
- ```python
- from agent.config import Config, MCPServerConfig
-
- config = Config(
-     model_name="anthropic/claude-sonnet-4-5-20250929",
-     tools=[],
-     system_prompt_path="",
-     mcp_servers=[
-         MCPServerConfig(
-             name="weather",
-             command="python",
-             args=["weather_server.py"],
-             env=None
-         )
-     ]
- )
- ```
-
- ## How It Works
-
- 1. **Initialization**: When the agent loop starts, it calls `session.initialize_mcp()`
- 2. **Connection**: The session connects to all configured MCP servers
- 3. **Tool Discovery**: Tools from all servers are discovered and added to the agent's tool list
- 4. **Tool Naming**: MCP tools are prefixed with their server name (e.g., `weather__get_forecast`)
- 5. **Execution**: When the LLM calls a tool, the ToolExecutor routes it to the appropriate MCP server
- 6. **Cleanup**: When the agent shuts down, all MCP connections are cleaned up properly
-
- ## Tool Naming Convention
-
- MCP tools are automatically prefixed with their server name to avoid conflicts:
-
- - Original tool: `get_forecast`
- - MCP tool name: `weather__get_forecast`
-
- This ensures that tools from different servers don't conflict, even if they have the same name.
-
- ## Example: Creating a Simple MCP Server
-
- Here's a minimal example of an MCP server (save as `calculator_server.py`):
-
- ```python
- import asyncio
- from mcp.server import Server, stdio_server
- from mcp.types import Tool, TextContent
-
- app = Server("calculator")
-
- @app.list_tools()
- async def list_tools() -> list[Tool]:
-     return [
-         Tool(
-             name="add",
-             description="Add two numbers",
-             inputSchema={
-                 "type": "object",
-                 "properties": {
-                     "a": {"type": "number"},
-                     "b": {"type": "number"}
-                 },
-                 "required": ["a", "b"]
-             }
-         )
-     ]
-
- @app.call_tool()
- async def call_tool(name: str, arguments: dict) -> list[TextContent]:
-     if name == "add":
-         result = arguments["a"] + arguments["b"]
-         return [TextContent(type="text", text=str(result))]
-
-     raise ValueError(f"Unknown tool: {name}")
-
- async def main():
-     async with stdio_server() as (read_stream, write_stream):
-         await app.run(read_stream, write_stream, app.create_initialization_options())
-
- if __name__ == "__main__":
-     asyncio.run(main())
- ```
-
- ## Troubleshooting
-
- ### Server Connection Issues
-
- If you see errors connecting to an MCP server:
-
- 1. Check that the server script path is correct
- 2. Ensure the command (`python`, `node`) is in your PATH
- 3. Verify the server script is executable
- 4. Check server logs for initialization errors
-
- ### Tool Not Found
-
- If the agent can't find an MCP tool:
-
- 1. Verify the server is connected (check startup logs)
- 2. Check tool naming (should be `servername__toolname`)
- 3. Ensure the server properly implements `list_tools()`
-
- ### Performance Considerations
-
- - MCP server initialization happens once at startup
- - Tool calls are asynchronous and don't block the agent
- - Multiple servers can be used simultaneously
- - Consider using local tools for high-frequency operations
-
- ## Best Practices
-
- 1. **Unique Server Names**: Give each MCP server a unique, descriptive name
- 2. **Error Handling**: MCP connection failures are logged but don't crash the agent
- 3. **Resource Cleanup**: Always let the agent shut down gracefully to cleanup connections
- 4. **Testing**: Test MCP servers independently before integrating them
- 5. **Security**: Be cautious with file system and network access in MCP servers
-
- ## Future Enhancements
-
- Potential improvements to consider:
-
- - Dynamic server addition/removal during runtime
- - Server health monitoring and auto-reconnection
- - Tool caching and performance optimization
- - Support for MCP resources and prompts
- - Rate limiting and timeout configuration
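
The `servername__toolname` prefixing described in the deleted document can be illustrated standalone. This is a minimal sketch under the stated convention; the helper names here are hypothetical, not functions from the agent's actual codebase:

```python
def prefix_tool_name(server_name: str, tool_name: str) -> str:
    # Combine server and tool name with a double underscore,
    # matching the weather__get_forecast convention.
    return f"{server_name}__{tool_name}"


def split_tool_name(prefixed: str) -> tuple[str, str]:
    # Recover (server, tool) so a call can be routed back to its server.
    # partition splits on the first "__" only, so tool names containing
    # "__" survive intact.
    server, _, tool = prefixed.partition("__")
    return server, tool


print(prefix_tool_name("weather", "get_forecast"))  # weather__get_forecast
print(split_tool_name("weather__get_forecast"))     # ('weather', 'get_forecast')
```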
agent/codex_agent_demo.py DELETED
@@ -1,470 +0,0 @@
- """
- Minimum Viable Implementation of Codex Agent Loop in Python
-
- This demonstrates the core architecture patterns from codex-rs:
- - Async submission loop (like submission_loop in codex.rs)
- - Context manager for conversation history
- - Channel-based communication (submissions in, events out)
- - Handler pattern for operations
- """
-
- import asyncio
- from dataclasses import dataclass, field
- from datetime import datetime
- from enum import Enum
- from typing import Any, Dict, List, Optional
-
- # ============================================================================
- # PROTOCOL TYPES (ResponseItem equivalents)
- # ============================================================================
-
-
- class MessageRole(Enum):
-     SYSTEM = "system"
-     USER = "user"
-     ASSISTANT = "assistant"
-
-
- @dataclass
- class Message:
-     role: MessageRole
-     content: str
-     timestamp: datetime = field(default_factory=datetime.now)
-
-
- @dataclass
- class ToolCall:
-     call_id: str
-     tool_name: str
-     arguments: Dict[str, Any]
-
-
- @dataclass
- class ToolOutput:
-     call_id: str
-     content: str
-     success: bool = True
-
-
- # ============================================================================
- # CONTEXT MANAGER (like context_manager/history.rs)
- # ============================================================================
-
-
- class ContextManager:
-     """
-     Manages conversation history with normalization and truncation.
-     Based on codex-rs/core/src/context_manager/history.rs
-     """
-
-     def __init__(self, max_history_length: int = 1000):
-         self.items: List[Any] = []  # Oldest → Newest
-         self.token_count: int = 0
-         self.max_history_length = max_history_length
-
-     def record_items(self, items: List[Any]) -> None:
-         """Record new items to history (like record_items in history.rs:41)"""
-         for item in items:
-             # Filter and process items
-             if self._is_api_message(item):
-                 processed = self._process_item(item)
-                 self.items.append(processed)
-
-     def _is_api_message(self, item: Any) -> bool:
-         """Filter out system messages (like is_api_message in history.rs:157)"""
-         if isinstance(item, Message):
-             return item.role != MessageRole.SYSTEM
-         return isinstance(item, (ToolCall, ToolOutput))
-
-     def _process_item(self, item: Any) -> Any:
-         """Process item before adding (like process_item in history.rs:119)"""
-         # Truncate long outputs
-         if isinstance(item, ToolOutput):
-             if len(item.content) > 2000:
-                 item.content = item.content[:2000] + "...[truncated]"
-         return item
-
-     def get_history_for_prompt(self) -> List[Any]:
-         """
-         Get normalized history ready for model
-         (like get_history_for_prompt in history.rs:65)
-         """
-         self._normalize_history()
-         return self.items.copy()
-
-     def _normalize_history(self) -> None:
-         """
-         Enforce invariants (like normalize_history in history.rs:102):
-         1. Every tool call has corresponding output
-         2. Every output has corresponding call
-         """
-         # Build mapping of call_id → call
-         calls = {}
-         outputs = {}
-
-         for item in self.items:
-             if isinstance(item, ToolCall):
-                 calls[item.call_id] = item
-             elif isinstance(item, ToolOutput):
-                 outputs[item.call_id] = item
-
-         # Remove orphan outputs (no matching call)
-         self.items = [
-             item
-             for item in self.items
-             if not isinstance(item, ToolOutput) or item.call_id in calls
-         ]
-
-         # Add missing outputs for calls (create synthetic outputs)
-         for call_id, call in calls.items():
-             if call_id not in outputs:
-                 self.items.append(
-                     ToolOutput(
-                         call_id=call_id, content="[No output recorded]", success=False
-                     )
-                 )
-
-     def remove_first_item(self) -> None:
-         """Remove oldest item for compaction (like remove_first_item in history.rs:71)"""
-         if self.items:
-             removed = self.items.pop(0)
-             # Also remove corresponding pair if needed
-             if isinstance(removed, ToolCall):
-                 self.items = [
-                     item
-                     for item in self.items
-                     if not (
-                         isinstance(item, ToolOutput) and item.call_id == removed.call_id
-                     )
-                 ]
-             elif isinstance(removed, ToolOutput):
-                 self.items = [
-                     item
-                     for item in self.items
-                     if not (
-                         isinstance(item, ToolCall) and item.call_id == removed.call_id
-                     )
-                 ]
-
-     def compact(self, target_size: int) -> None:
-         """Remove old items until we're under target size"""
-         while len(self.items) > target_size:
-             self.remove_first_item()
-
-
- # ============================================================================
- # OPERATIONS (like Op enum in codex.rs)
- # ============================================================================
-
-
- class OpType(Enum):
-     USER_INPUT = "user_input"
-     EXEC_APPROVAL = "exec_approval"
-     INTERRUPT = "interrupt"
-     UNDO = "undo"
-     COMPACT = "compact"
-     SHUTDOWN = "shutdown"
-
-
- @dataclass
- class Operation:
-     op_type: OpType
-     data: Optional[Dict[str, Any]] = None
-
-
- @dataclass
- class Submission:
-     id: str
-     operation: Operation
-
-
- # ============================================================================
- # EVENTS (like Event in codex-rs)
- # ============================================================================
-
-
- @dataclass
- class Event:
-     event_type: str
-     data: Optional[Dict[str, Any]] = None
-
-
- # ============================================================================
- # SESSION STATE (like Session in codex.rs)
- # ============================================================================
-
-
- class Session:
-     """
-     Maintains agent session state
-     Similar to Session in codex-rs/core/src/codex.rs
-     """
-
-     def __init__(self, event_queue: asyncio.Queue):
-         self.context_manager = ContextManager(tool_specs=[])
-         self.event_queue = event_queue
-         self.is_running = True
-         self.current_task: Optional[asyncio.Task] = None
-
-     async def send_event(self, event: Event) -> None:
-         """Send event back to client"""
-         await self.event_queue.put(event)
-
-     def interrupt(self) -> None:
-         """Interrupt current running task"""
-         if self.current_task and not self.current_task.done():
-             self.current_task.cancel()
-
-
- # ============================================================================
- # OPERATION HANDLERS (like handlers module in codex.rs:1343)
- # ============================================================================
-
-
- class Handlers:
-     """Handler functions for each operation type"""
-
-     @staticmethod
-     async def user_input(session: Session, text: str) -> None:
-         """Handle user input (like user_input_or_turn in codex.rs:1291)"""
-         # Add user message to history
-         user_msg = Message(role=MessageRole.USER, content=text)
-         session.context_manager.record_items([user_msg])
-
-         # Send event that we're processing
-         await session.send_event(
-             Event(event_type="processing", data={"message": "Processing user input"})
-         )
-
-         # Simulate agent processing
-         await asyncio.sleep(0.1)
-
-         # Generate mock assistant response
-         assistant_msg = Message(
-             role=MessageRole.ASSISTANT, content=f"I received: {text}"
-         )
-         session.context_manager.record_items([assistant_msg])
-
-         # Simulate tool call
-         tool_call = ToolCall(
-             call_id="call_123", tool_name="bash", arguments={"command": "echo 'hello'"}
-         )
-         session.context_manager.record_items([tool_call])
-
-         # Simulate tool execution
-         await asyncio.sleep(0.1)
-
-         tool_output = ToolOutput(call_id="call_123", content="hello\n", success=True)
-         session.context_manager.record_items([tool_output])
-
-         # Send completion event
-         await session.send_event(
-             Event(
-                 event_type="turn_complete",
-                 data={"history_size": len(session.context_manager.items)},
-             )
-         )
-
-     @staticmethod
-     async def interrupt(session: Session) -> None:
-         """Handle interrupt (like interrupt in codex.rs:1266)"""
-         session.interrupt()
-         await session.send_event(Event(event_type="interrupted"))
-
-     @staticmethod
-     async def compact(session: Session) -> None:
-         """Handle compact (like compact in codex.rs:1317)"""
-         old_size = len(session.context_manager.items)
-         session.context_manager.compact(target_size=10)
-         new_size = len(session.context_manager.items)
-
-         await session.send_event(
-             Event(
-                 event_type="compacted",
-                 data={"removed": old_size - new_size, "remaining": new_size},
-             )
-         )
-
-     @staticmethod
-     async def undo(session: Session) -> None:
-         """Handle undo (like undo in codex.rs:1314)"""
-         # Remove last user turn and all following items
-         # Simplified: just remove last 2 items
-         for _ in range(min(2, len(session.context_manager.items))):
-             session.context_manager.items.pop()
-
-         await session.send_event(Event(event_type="undo_complete"))
-
-     @staticmethod
-     async def shutdown(session: Session) -> bool:
-         """Handle shutdown (like shutdown in codex.rs:1329)"""
-         session.is_running = False
-         await session.send_event(Event(event_type="shutdown"))
-         return True
-
-
- # ============================================================================
- # MAIN AGENT LOOP (like submission_loop in codex.rs:1259)
- # ============================================================================
-
-
- async def submission_loop(
-     submission_queue: asyncio.Queue, event_queue: asyncio.Queue
- ) -> None:
-     """
-     Main agent loop - processes submissions and dispatches to handlers.
-     This is the core of the agent (like submission_loop in codex.rs:1259-1340)
-     """
-     session = Session(event_queue)
-
-     print("🤖 Agent loop started")
-
-     # Main processing loop
-     while session.is_running:
-         try:
-             # Wait for next submission (like rx_sub.recv() in codex.rs:1262)
-             submission = await submission_queue.get()
-
-             print(f"📨 Received: {submission.operation.op_type.value}")
-
-             # Dispatch to handler based on operation type
-             # (like match in codex.rs:1264-1337)
-             op = submission.operation
-
-             if op.op_type == OpType.USER_INPUT:
-                 text = op.data.get("text", "") if op.data else ""
-                 await Handlers.user_input(session, text)
-
-             elif op.op_type == OpType.INTERRUPT:
-                 await Handlers.interrupt(session)
-
-             elif op.op_type == OpType.COMPACT:
-                 await Handlers.compact(session)
-
-             elif op.op_type == OpType.UNDO:
-                 await Handlers.undo(session)
-
-             elif op.op_type == OpType.SHUTDOWN:
-                 if await Handlers.shutdown(session):
-                     break
-
-             else:
-                 print(f"⚠️ Unknown operation: {op.op_type}")
-
-         except asyncio.CancelledError:
-             break
-         except Exception as e:
-             print(f"❌ Error in agent loop: {e}")
-             await session.send_event(Event(event_type="error", data={"error": str(e)}))
-
-     print("🛑 Agent loop exited")
-
-
- # ============================================================================
- # CODEX INTERFACE (like Codex struct in codex.rs:154)
- # ============================================================================
-
-
- class Codex:
-     """
-     Main interface to the agent (like Codex in codex.rs:154-246)
-     Provides submit() and next_event() methods
-     """
-
-     def __init__(self):
-         self.submission_queue = asyncio.Queue()
-         self.event_queue = asyncio.Queue()
-         self.agent_task: Optional[asyncio.Task] = None
-         self.submission_counter = 0
-
-     async def spawn(self) -> None:
-         """Spawn the agent loop (like Codex::spawn in codex.rs:156)"""
-         self.agent_task = asyncio.create_task(
-             submission_loop(self.submission_queue, self.event_queue)
-         )
-
-     async def submit(self, operation: Operation) -> str:
-         """Submit operation to agent (like Codex::submit in codex.rs:218)"""
-         self.submission_counter += 1
-         submission = Submission(
-             id=f"sub_{self.submission_counter}", operation=operation
-         )
-         await self.submission_queue.put(submission)
-         return submission.id
-
-     async def next_event(self) -> Optional[Event]:
-         """Get next event from agent (like Codex::next_event in codex.rs:238)"""
-         try:
-             return await asyncio.wait_for(self.event_queue.get(), timeout=1.0)
-         except asyncio.TimeoutError:
-             return None
-
-     async def shutdown(self) -> None:
-         """Shutdown the agent"""
-         await self.submit(Operation(op_type=OpType.SHUTDOWN))
-         if self.agent_task:
-             await self.agent_task
-
-
- # ============================================================================
- # DEMO / EXAMPLE USAGE
- # ============================================================================
-
-
- async def main():
-     """Demo of the agent system"""
-     print("=" * 60)
-     print("Codex Agent Loop Demo (Python MVP)")
-     print("=" * 60)
-
-     # Create and spawn agent
-     codex = Codex()
-     await codex.spawn()
-
-     # Submit some operations
-     print("\n1️⃣ Submitting user input...")
-     await codex.submit(
-         Operation(op_type=OpType.USER_INPUT, data={"text": "Hello, agent!"})
-     )
-
-     # Receive events
-     for _ in range(3):
-         event = await codex.next_event()
-         if event:
-             print(f"  ✅ Event: {event.event_type} - {event.data}")
-
-     print("\n2️⃣ Submitting another input...")
-     await codex.submit(
-         Operation(op_type=OpType.USER_INPUT, data={"text": "What's the weather?"})
-     )
-
-     for _ in range(3):
-         event = await codex.next_event()
-         if event:
-             print(f"  ✅ Event: {event.event_type} - {event.data}")
-
-     print("\n3️⃣ Compacting history...")
-     await codex.submit(Operation(op_type=OpType.COMPACT))
-
-     event = await codex.next_event()
-     if event:
-         print(f"  ✅ Event: {event.event_type} - {event.data}")
-
-     print("\n4️⃣ Undoing last turn...")
-     await codex.submit(Operation(op_type=OpType.UNDO))
-
-     event = await codex.next_event()
-     if event:
-         print(f"  ✅ Event: {event.event_type}")
-
-     # Shutdown
-     print("\n5️⃣ Shutting down...")
-     await codex.shutdown()
-
-     print("\n" + "=" * 60)
-     print("Demo complete!")
-     print("=" * 60)
-
-
- if __name__ == "__main__":
-     asyncio.run(main())
agent/context_manager/manager.py CHANGED
@@ -8,6 +8,7 @@ from pathlib import Path
  from typing import Any

  import yaml
+ from huggingface_hub import HfApi
  from jinja2 import Template
  from litellm import Message, acompletion

@@ -24,7 +25,8 @@ class ContextManager:
          prompt_file_suffix: str = "system_prompt_v2.yaml",
      ):
          self.system_prompt = self._load_system_prompt(
-             tool_specs or [], prompt_file_suffix="system_prompt_v2.yaml"
+             tool_specs or [],
+             prompt_file_suffix="system_prompt_v2.yaml",
          )
          self.max_context = max_context
          self.compact_size = int(max_context * compact_size)
@@ -58,6 +60,7 @@ class ContextManager:
              current_date=current_date,
              current_time=current_time,
              current_timezone=current_timezone,
+             hf_user_info=HfApi().whoami().get("name"),
          )

      def add_message(self, message: Message, token_count: int = None) -> None:
agent/core/agent_loop.py CHANGED
@@ -76,7 +76,19 @@ def _needs_approval(tool_name: str, tool_args: dict, config: Config | None = Non
      # Other operations (create_repo, etc.) always require approval
      if operation in ["create_repo"]:
          return True
-
+
+     # hf_repo_files: upload (can overwrite) and delete require approval
+     if tool_name == "hf_repo_files":
+         operation = tool_args.get("operation", "")
+         if operation in ["upload", "delete"]:
+             return True
+
+     # hf_repo_git: destructive operations require approval
+     if tool_name == "hf_repo_git":
+         operation = tool_args.get("operation", "")
+         if operation in ["delete_branch", "delete_tag", "merge_pr", "create_repo", "update_repo"]:
+             return True
+
      return False

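The approval gating added in this hunk can be sketched as a standalone function. This is a simplified version for illustration only (the real `_needs_approval` also takes a `Config` and covers other tools and operations):

```python
def needs_approval(tool_name: str, tool_args: dict) -> bool:
    # Operations that can overwrite or destroy data require user approval.
    operation = tool_args.get("operation", "")

    # hf_repo_files: upload (can overwrite) and delete require approval
    if tool_name == "hf_repo_files" and operation in ["upload", "delete"]:
        return True

    # hf_repo_git: destructive git/repo operations require approval
    if tool_name == "hf_repo_git" and operation in [
        "delete_branch", "delete_tag", "merge_pr", "create_repo", "update_repo"
    ]:
        return True

    # Everything else runs without interruption.
    return False


print(needs_approval("hf_repo_files", {"operation": "upload"}))    # True
print(needs_approval("hf_repo_files", {"operation": "download"}))  # False
print(needs_approval("hf_repo_git", {"operation": "merge_pr"}))    # True
```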
agent/core/tools.py CHANGED
@@ -35,22 +35,29 @@ from agent.tools.github_read_file import (
      GITHUB_READ_FILE_TOOL_SPEC,
      github_read_file_handler,
  )
+ from agent.tools.hf_repo_files_tool import (
+     HF_REPO_FILES_TOOL_SPEC,
+     hf_repo_files_handler,
+ )
+ from agent.tools.hf_repo_git_tool import (
+     HF_REPO_GIT_TOOL_SPEC,
+     hf_repo_git_handler,
+ )
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
  from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
- from agent.tools.private_hf_repo_tools import (
-     PRIVATE_HF_REPO_TOOL_SPEC,
-     private_hf_repo_handler,
- )

- # NOTE: Utils tool disabled - date/time now loaded into system prompt at initialization
- # from agent.tools.utils_tools import UTILS_TOOL_SPEC, utils_handler
+ # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git
+ # from agent.tools.private_hf_repo_tools import (
+ #     PRIVATE_HF_REPO_TOOL_SPEC,
+ #     private_hf_repo_handler,
+ # )

  # Suppress aiohttp deprecation warning
  warnings.filterwarnings(
      "ignore", category=DeprecationWarning, module="aiohttp.connector"
  )

- NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch"]
+ NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch", "hf_whoami"]


  def convert_mcp_content_to_string(content: list) -> str:
@@ -281,20 +288,19 @@ def create_builtin_tools() -> list[ToolSpec]:
              parameters=HF_JOBS_TOOL_SPEC["parameters"],
              handler=hf_jobs_handler,
          ),
+         # HF Repo management tools
          ToolSpec(
-             name=PRIVATE_HF_REPO_TOOL_SPEC["name"],
-             description=PRIVATE_HF_REPO_TOOL_SPEC["description"],
-             parameters=PRIVATE_HF_REPO_TOOL_SPEC["parameters"],
-             handler=private_hf_repo_handler,
+             name=HF_REPO_FILES_TOOL_SPEC["name"],
+             description=HF_REPO_FILES_TOOL_SPEC["description"],
+             parameters=HF_REPO_FILES_TOOL_SPEC["parameters"],
+             handler=hf_repo_files_handler,
+         ),
+         ToolSpec(
+             name=HF_REPO_GIT_TOOL_SPEC["name"],
+             description=HF_REPO_GIT_TOOL_SPEC["description"],
+             parameters=HF_REPO_GIT_TOOL_SPEC["parameters"],
+             handler=hf_repo_git_handler,
          ),
-         # NOTE: Utils tool disabled - date/time now loaded into system prompt at initialization (less tool calls=more reliablity)
-         # ToolSpec(
-         #     name=UTILS_TOOL_SPEC["name"],
-         #     description=UTILS_TOOL_SPEC["description"],
-         #     parameters=UTILS_TOOL_SPEC["parameters"],
-         #     handler=utils_handler,
-         # ),
-         # GitHub tools
          # NOTE: Github search code tool disabled - a bit buggy
          # ToolSpec(
          #     name=GITHUB_SEARCH_CODE_TOOL_SPEC["name"],
agent/main.py CHANGED
@@ -287,6 +287,95 @@ async def event_listener(
              if len(all_lines) > 5:
                  print("...")

+         elif tool_name == "hf_repo_files":
+             # Handle repo files operations (upload, delete)
+             repo_id = arguments.get("repo_id", "")
+             repo_type = arguments.get("repo_type", "model")
+             revision = arguments.get("revision", "main")
+
+             # Build repo URL
+             if repo_type == "model":
+                 repo_url = f"https://huggingface.co/{repo_id}"
+             else:
+                 repo_url = f"https://huggingface.co/{repo_type}s/{repo_id}"
+
+             print(f"Repository: {repo_id}")
+             print(f"Type: {repo_type}")
+             print(f"Branch: {revision}")
+             print(f"URL: {repo_url}")
+
+             if operation == "upload":
+                 path = arguments.get("path", "")
+                 content = arguments.get("content", "")
+                 create_pr = arguments.get("create_pr", False)
+
+                 print(f"File: {path}")
+                 if create_pr:
+                     print("Mode: Create PR")
+
+                 if isinstance(content, str):
+                     all_lines = content.split("\n")
+                     line_count = len(all_lines)
+                     size_bytes = len(content.encode("utf-8"))
+                     size_kb = size_bytes / 1024
+
+                     print(f"Lines: {line_count}")
+                     if size_kb < 1024:
+                         print(f"Size: {size_kb:.2f} KB")
+                     else:
+                         print(f"Size: {size_kb / 1024:.2f} MB")
+
+                     # Show full content
+                     print(f"Content:\n{content}")
+
+             elif operation == "delete":
+                 patterns = arguments.get("patterns", [])
+                 if isinstance(patterns, str):
+                     patterns = [patterns]
+                 print(f"Patterns to delete: {', '.join(patterns)}")
+
+         elif tool_name == "hf_repo_git":
+             # Handle git operations (branches, tags, PRs, repo management)
+             repo_id = arguments.get("repo_id", "")
+             repo_type = arguments.get("repo_type", "model")
+
+             # Build repo URL
+             if repo_type == "model":
+                 repo_url = f"https://huggingface.co/{repo_id}"
+             else:
+                 repo_url = f"https://huggingface.co/{repo_type}s/{repo_id}"
+
+             print(f"Repository: {repo_id}")
+             print(f"Type: {repo_type}")
+             print(f"URL: {repo_url}")
+
+             if operation == "delete_branch":
+                 branch = arguments.get("branch", "")
+                 print(f"Branch to delete: {branch}")
+
+             elif operation == "delete_tag":
+                 tag = arguments.get("tag", "")
+                 print(f"Tag to delete: {tag}")
+
+             elif operation == "merge_pr":
+                 pr_num = arguments.get("pr_num", "")
+                 print(f"PR to merge: #{pr_num}")
+
+             elif operation == "create_repo":
+                 private = arguments.get("private", False)
+                 space_sdk = arguments.get("space_sdk")
+                 print(f"Private: {private}")
+                 if space_sdk:
+                     print(f"Space SDK: {space_sdk}")
+
+             elif operation == "update_repo":
+                 private = arguments.get("private")
+                 gated = arguments.get("gated")
+                 if private is not None:
+                     print(f"Private: {private}")
+                 if gated is not None:
+                     print(f"Gated: {gated}")
+
          # Get user decision for this item
          response = await prompt_session.prompt_async(
              f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): "
agent/prompts/system_prompt_v2.yaml CHANGED
@@ -2,6 +2,7 @@ system_prompt: |
2
  You are Hugging Face Agent, a skilled AI assistant for machine learning engineering with deep expertise in the Hugging Face ecosystem. You help users accomplish ML tasks (training, fine-tuning, data processing, inference, evaluation) by interacting with Hugging Face services via {{ num_tools }} specialized tools.
3
 
4
  _Current Time: **{{ current_date }} {{ current_time }} ({{ current_timezone }})**_
 
5
 
6
  # Core Mission & Behavior
7
 
@@ -330,11 +331,6 @@ system_prompt: |
330
  - Check model size, architecture, requirements
331
  - Verify dataset columns, splits, size
332
 
333
- **hf_whoami:**
334
- - Check authentication status
335
- - Verify token has correct permissions
336
- - Use before operations requiring write access
337
-
338
  ## Execution & Storage Tools
339
 
340
  **hf_jobs:**
@@ -456,8 +452,6 @@ system_prompt: |
456
  hub_model_id="username/model-name", # ← Must be set
457
  # ...
458
  )
459
-
460
- # Verify token: hf_whoami()
461
  ```
462
 
463
  ### Dataset Format Mismatch
 
2
  You are Hugging Face Agent, a skilled AI assistant for machine learning engineering with deep expertise in the Hugging Face ecosystem. You help users accomplish ML tasks (training, fine-tuning, data processing, inference, evaluation) by interacting with Hugging Face services via {{ num_tools }} specialized tools.
3
 
4
  _Current Time: **{{ current_date }} {{ current_time }} ({{ current_timezone }})**_
5
+ {% if hf_user_info %}_AUTHENTICATED ON HF AS: **{{ hf_user_info }}**_{% endif %}
6
 
7
  # Core Mission & Behavior
8
 
 
331
  - Check model size, architecture, requirements
332
  - Verify dataset columns, splits, size
333
 
 
 
 
 
 
334
  ## Execution & Storage Tools
335
 
336
  **hf_jobs:**
 
452
  hub_model_id="username/model-name", # ← Must be set
453
  # ...
454
  )
 
 
455
  ```
456
 
457
  ### Dataset Format Mismatch
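The new `hf_user_info` banner added to the template is a plain Jinja2 conditional; a quick sketch of how that single line renders (assuming the `jinja2` package, which the templating syntax implies):

```python
from jinja2 import Template

line = "{% if hf_user_info %}_AUTHENTICATED ON HF AS: **{{ hf_user_info }}**_{% endif %}"
tmpl = Template(line)

# When the session resolved a user, the banner is rendered...
banner = tmpl.render(hf_user_info="alice")
# ...and when hf_user_info is absent, the whole line collapses to an empty string,
# so unauthenticated sessions see no banner at all.
empty = tmpl.render()
```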
agent/tools/docs_tools.py CHANGED
@@ -1,289 +1,474 @@
1
  """
2
- Documentation search tools for the HF Agent
3
- Tools for exploring and fetching HuggingFace documentation and API specifications
4
  """
5
 
6
  import asyncio
 
7
  import os
8
  from typing import Any
9
 
10
  import httpx
11
  from bs4 import BeautifulSoup
 
 
12
 
13
- # Cache for OpenAPI spec to avoid repeated fetches
14
- _openapi_spec_cache: dict[str, Any] | None = None
 
15
 
 
 
 
 
16
 
17
- async def _fetch_html_page(hf_token: str, endpoint: str) -> str:
18
- """Fetch the HTML page for a given endpoint"""
19
- base_url = "https://huggingface.co/docs"
20
- url = f"{base_url}/{endpoint}"
21
- headers = {"Authorization": f"Bearer {hf_token}"}
22
 
23
  async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
24
- response = await client.get(url, headers=headers)
25
- response.raise_for_status()
 
26
 
27
- return response.text
28
 
 
 
 
29
 
30
- def _parse_sidebar_navigation(html_content: str) -> list[dict[str, str]]:
31
- """Parse the sidebar navigation and extract all links"""
32
- soup = BeautifulSoup(html_content, "html.parser")
33
- sidebar = soup.find("nav", class_=lambda x: x and "flex-auto" in x)
34
 
35
- if not sidebar:
36
- raise ValueError("Could not find navigation sidebar")
 
 
37
 
38
- links = sidebar.find_all("a", href=True)
39
- nav_data = []
 
40
 
41
- for link in links:
42
- title = link.get_text(strip=True)
43
- href = link["href"]
 
44
 
45
- # Make URL absolute
46
- page_url = f"https://huggingface.co{href}" if href.startswith("/") else href
47
- nav_data.append({"title": title, "url": page_url})
 
 
 
48
 
49
- return nav_data
 
 
50
 
51
 
52
- async def _fetch_single_glimpse(
53
- client: httpx.AsyncClient, hf_token: str, item: dict[str, str]
54
- ) -> dict[str, str]:
55
- """Fetch a glimpse (first 300 chars) for a single page"""
56
- md_url = f"{item['url']}.md"
57
- headers = {"Authorization": f"Bearer {hf_token}"}
58
 
59
  try:
60
- response = await client.get(md_url, headers=headers)
61
- response.raise_for_status()
62
-
63
- content = response.text
64
- glimpse = content[:300].strip()
65
- if len(content) > 300:
66
- glimpse += "..."
67
-
68
- return {
69
- "title": item["title"],
70
- "url": item["url"],
71
- "md_url": md_url,
72
- "glimpse": glimpse,
73
- }
74
- except Exception as e:
75
- return {
76
- "title": item["title"],
77
- "url": item["url"],
78
- "md_url": md_url,
79
- "glimpse": f"[Could not fetch glimpse: {str(e)[:50]}]",
80
- }
81
-
82
-
83
- async def _fetch_all_glimpses(
84
- hf_token: str, nav_data: list[dict[str, str]]
85
- ) -> list[dict[str, str]]:
86
- """Fetch glimpses for all pages in parallel"""
87
- async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
88
- result_items = await asyncio.gather(
89
- *[_fetch_single_glimpse(client, hf_token, item) for item in nav_data]
90
- )
 
91
 
92
- return list(result_items)
 
93
 
94
 
95
- def _format_exploration_results(
96
- endpoint: str, result_items: list[dict[str, str]]
97
- ) -> str:
98
- """Format the exploration results as a readable string"""
99
- base_url = "https://huggingface.co/docs"
100
- url = f"{base_url}/{endpoint}"
101
- result = f"Documentation structure for: {url}\n\n"
102
- result += f"Found {len(result_items)} pages:\n\n"
103
 
104
- for i, item in enumerate(result_items, 1):
105
- result += f"{i}. **{item['title']}**\n"
106
- result += f" URL: {item['url']}\n"
107
- result += f" Glimpse: {item['glimpse']}\n\n"
108
 
109
- return result
110
 
111
 
112
- async def explore_hf_docs(hf_token: str, endpoint: str) -> str:
113
- """Main function to explore documentation structure"""
114
- # Fetch HTML page
115
- html_content = await _fetch_html_page(hf_token, endpoint)
116
 
117
- # Parse navigation
118
- nav_data = _parse_sidebar_navigation(html_content)
 
119
 
120
- if not nav_data:
121
- raise ValueError(f"No navigation links found for endpoint '{endpoint}'")
 
 
122
 
123
- # Fetch all glimpses in parallel
124
- result_items = await _fetch_all_glimpses(hf_token, nav_data)
125
 
126
- # Format results
127
- result = _format_exploration_results(endpoint, result_items)
 
128
 
129
- return result
 
130
 
131
 
132
- async def explore_hf_docs_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
133
- """
134
- Explore the documentation structure for a given endpoint by parsing the sidebar navigation
135
 
136
- Args:
137
- arguments: Dictionary with 'endpoint' parameter (e.g., 'trl', 'transformers', etc.)
 
138
 
139
- Returns:
140
- Tuple of (structured_navigation_with_glimpses, success)
141
- """
142
- endpoint = arguments.get("endpoint", "")
143
 
144
- if not endpoint:
145
- return "Error: No endpoint provided", False
 
 
 
146
 
147
- # Get HF token from environment
148
  hf_token = os.environ.get("HF_TOKEN")
149
-
150
  if not hf_token:
151
  return "Error: HF_TOKEN environment variable not set", False
152
 
153
- endpoint = endpoint.lstrip("/")
 
154
 
155
  try:
156
- result = await explore_hf_docs(hf_token, endpoint)
157
- return result, True
158
-
 
 
 
159
  except httpx.HTTPStatusError as e:
160
  return (
161
- f"HTTP error: {e.response.status_code} - {e.response.text[:200]}",
162
  False,
163
  )
164
  except httpx.RequestError as e:
165
- return f"Request error: {str(e)}", False
166
- except ValueError as e:
167
- return f"Error: {str(e)}", False
168
  except Exception as e:
169
- return f"Unexpected error: {str(e)}", False
170
 
171
 
172
- async def _fetch_openapi_spec() -> dict[str, Any]:
173
- """Fetch and cache the HuggingFace OpenAPI specification"""
174
- global _openapi_spec_cache
175
 
176
- if _openapi_spec_cache is not None:
177
- return _openapi_spec_cache
178
 
179
- url = "https://huggingface.co/.well-known/openapi.json"
 
 
 
 
180
 
181
  async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
182
- response = await client.get(url)
183
- response.raise_for_status()
184
 
185
- spec = response.json()
186
- _openapi_spec_cache = spec
187
-
188
- return spec
189
 
190
 
191
  def _extract_all_tags(spec: dict[str, Any]) -> list[str]:
192
- """Extract all unique tags from the OpenAPI spec"""
193
  tags = set()
194
-
195
- # Get tags from the tags section
196
  for tag_obj in spec.get("tags", []):
197
  if "name" in tag_obj:
198
  tags.add(tag_obj["name"])
199
-
200
- # Also get tags from paths (in case some aren't in the tags section)
201
- for path, path_item in spec.get("paths", {}).items():
202
- for method, operation in path_item.items():
203
  if method in ["get", "post", "put", "delete", "patch", "head", "options"]:
204
- for tag in operation.get("tags", []):
205
  tags.add(tag)
206
-
207
- return sorted(list(tags))
208
-
209
-
210
- def _search_openapi_by_tag(spec: dict[str, Any], tag: str) -> list[dict[str, Any]]:
211
- """Search for API endpoints with a specific tag"""
212
- results = []
213
- paths = spec.get("paths", {})
214
- servers = spec.get("servers", [])
215
- base_url = (
216
- servers[0].get("url", "https://huggingface.co")
217
- if servers
218
- else "https://huggingface.co"
219
- )
220
-
221
- for path, path_item in paths.items():
222
- for method, operation in path_item.items():
223
- if method not in [
224
- "get",
225
- "post",
226
- "put",
227
- "delete",
228
- "patch",
229
- "head",
230
- "options",
231
- ]:
232
- continue
233
-
234
- operation_tags = operation.get("tags", [])
235
- if tag in operation_tags:
236
- # Extract parameters
237
- parameters = operation.get("parameters", [])
238
- request_body = operation.get("requestBody", {})
239
- responses = operation.get("responses", {})
240
-
241
- results.append(
242
- {
243
- "path": path,
244
- "method": method.upper(),
245
- "operationId": operation.get("operationId", ""),
246
- "summary": operation.get("summary", ""),
247
- "description": operation.get("description", ""),
248
- "parameters": parameters,
249
- "request_body": request_body,
250
- "responses": responses,
251
- "base_url": base_url,
252
- }
253
- )
254
-
255
- return results
256
 
257
 
258
  def _generate_curl_example(endpoint: dict[str, Any]) -> str:
259
- """Generate a curl command example for an endpoint"""
260
  method = endpoint["method"]
261
  path = endpoint["path"]
262
  base_url = endpoint["base_url"]
263
 
264
- # Build the full URL with example path parameters
265
  full_path = path
266
  for param in endpoint.get("parameters", []):
267
  if param.get("in") == "path" and param.get("required"):
268
- param_name = param["name"]
269
  example = param.get(
270
- "example", param.get("schema", {}).get("example", f"<{param_name}>")
271
  )
272
- full_path = full_path.replace(f"{{{param_name}}}", str(example))
273
 
274
  curl = f"curl -X {method} \\\n '{base_url}{full_path}'"
275
 
276
- # Add query parameters if any
277
  query_params = [p for p in endpoint.get("parameters", []) if p.get("in") == "query"]
278
  if query_params and query_params[0].get("required"):
279
  param = query_params[0]
280
  example = param.get("example", param.get("schema", {}).get("example", "value"))
281
  curl += f"?{param['name']}={example}"
282
 
283
- # Add headers
284
  curl += " \\\n -H 'Authorization: Bearer $HF_TOKEN'"
285
 
286
- # Add request body if applicable
287
  if method in ["POST", "PUT", "PATCH"] and endpoint.get("request_body"):
288
  content = endpoint["request_body"].get("content", {})
289
  if "application/json" in content:
@@ -291,8 +476,6 @@ def _generate_curl_example(endpoint: dict[str, Any]) -> str:
291
  schema = content["application/json"].get("schema", {})
292
  example = schema.get("example", "{}")
293
  if isinstance(example, dict):
294
- import json
295
-
296
  example = json.dumps(example, indent=2)
297
  curl += f" \\\n -d '{example}'"
298
 
@@ -300,72 +483,50 @@ def _generate_curl_example(endpoint: dict[str, Any]) -> str:
300
 
301
 
302
  def _format_parameters(parameters: list[dict[str, Any]]) -> str:
303
- """Format parameter information from OpenAPI spec"""
304
  if not parameters:
305
  return ""
306
 
307
- # Group parameters by type
308
  path_params = [p for p in parameters if p.get("in") == "path"]
309
  query_params = [p for p in parameters if p.get("in") == "query"]
310
  header_params = [p for p in parameters if p.get("in") == "header"]
311
 
312
  output = []
313
 
314
- if path_params:
315
- output.append("**Path Parameters:**")
316
- for param in path_params:
317
- name = param.get("name", "")
318
- required = " (required)" if param.get("required") else " (optional)"
319
- description = param.get("description", "")
320
- param_type = param.get("schema", {}).get("type", "string")
321
- example = param.get("example") or param.get("schema", {}).get("example", "")
322
-
323
- output.append(f"- `{name}` ({param_type}){required}: {description}")
324
- if example:
325
- output.append(f" Example: `{example}`")
326
-
327
- if query_params:
328
  if output:
329
  output.append("")
330
- output.append("**Query Parameters:**")
331
- for param in query_params:
332
- name = param.get("name", "")
333
- required = " (required)" if param.get("required") else " (optional)"
334
- description = param.get("description", "")
335
- param_type = param.get("schema", {}).get("type", "string")
336
- example = param.get("example") or param.get("schema", {}).get("example", "")
337
-
338
- output.append(f"- `{name}` ({param_type}){required}: {description}")
339
  if example:
340
  output.append(f" Example: `{example}`")
341
 
342
- if header_params:
343
- if output:
344
- output.append("")
345
- output.append("**Header Parameters:**")
346
- for param in header_params:
347
- name = param.get("name", "")
348
- required = " (required)" if param.get("required") else " (optional)"
349
- description = param.get("description", "")
350
-
351
- output.append(f"- `{name}`{required}: {description}")
352
-
353
  return "\n".join(output)
354
 
355
 
356
  def _format_response_info(responses: dict[str, Any]) -> str:
357
- """Format response information from OpenAPI spec"""
358
  if not responses:
359
  return "No response information available"
360
 
361
  output = []
362
- for status_code, response_obj in list(responses.items())[
363
- :3
364
- ]: # Show first 3 status codes
365
- desc = response_obj.get("description", "")
366
- output.append(f"- **{status_code}**: {desc}")
367
-
368
- content = response_obj.get("content", {})
369
  if "application/json" in content:
370
  schema = content["application/json"].get("schema", {})
371
  if "type" in schema:
@@ -375,72 +536,87 @@ def _format_response_info(responses: dict[str, Any]) -> str:
375
 
376
 
377
  def _format_openapi_results(results: list[dict[str, Any]], tag: str) -> str:
378
- """Format OpenAPI search results as markdown with curl examples"""
379
  if not results:
380
  return f"No API endpoints found with tag '{tag}'"
381
 
382
- output = f"# API Endpoints for tag: `{tag}`\n\n"
383
- output += f"Found {len(results)} endpoint(s)\n\n"
384
- output += "---\n\n"
385
 
386
- for i, endpoint in enumerate(results, 1):
387
- output += f"## {i}. {endpoint['method']} {endpoint['path']}\n\n"
388
 
389
- if endpoint["summary"]:
390
- output += f"**Summary:** {endpoint['summary']}\n\n"
391
 
392
- if endpoint["description"]:
393
- desc = endpoint["description"][:300]
394
- if len(endpoint["description"]) > 300:
395
  desc += "..."
396
- output += f"**Description:** {desc}\n\n"
397
 
398
- # Parameters
399
- params_info = _format_parameters(endpoint.get("parameters", []))
400
  if params_info:
401
- output += params_info + "\n\n"
402
-
403
- # Curl example
404
- output += "**Usage:**\n```bash\n"
405
- output += _generate_curl_example(endpoint)
406
- output += "\n```\n\n"
407
 
408
- # Response info
409
- output += "**Returns:**\n"
410
- output += _format_response_info(endpoint["responses"])
411
- output += "\n\n"
412
 
413
- output += "---\n\n"
 
 
414
 
415
- return output
416
 
417
 
418
  async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
419
- """
420
- Search the HuggingFace OpenAPI specification by tag
421
-
422
- Args:
423
- arguments: Dictionary with 'tag' parameter
424
-
425
- Returns:
426
- Tuple of (search_results, success)
427
- """
428
  tag = arguments.get("tag", "")
429
-
430
  if not tag:
431
  return "Error: No tag provided", False
432
 
433
  try:
434
- # Fetch OpenAPI spec (cached after first fetch)
435
  spec = await _fetch_openapi_spec()
 
 
436
 
437
- # Search for endpoints with this tag
438
- results = _search_openapi_by_tag(spec, tag)
 
439
 
440
- # Format results
441
- formatted = _format_openapi_results(results, tag)
 
442
 
443
- return formatted, True
444
 
445
  except httpx.HTTPStatusError as e:
446
  return f"HTTP error fetching OpenAPI spec: {e.response.status_code}", False
@@ -450,66 +626,86 @@ async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
450
  return f"Error searching OpenAPI spec: {str(e)}", False
451
 
452
 
453
- async def hf_docs_fetch_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
454
- """
455
- Fetch full documentation content from a specific HF docs page
456
-
457
- Args:
458
- arguments: Dictionary with 'url' parameter (full URL to the doc page)
459
-
460
- Returns:
461
- Tuple of (full_markdown_content, success)
462
- """
463
- url = arguments.get("url", "")
464
-
465
- if not url:
466
- return "Error: No URL provided", False
467
-
468
- # Get HF token from environment
469
- hf_token = os.environ.get("HF_TOKEN")
470
-
471
- if not hf_token:
472
- return (
473
- "Error: HF_TOKEN environment variable not set",
474
- False,
475
- )
476
-
477
- # Add .md extension if not already present
478
- if not url.endswith(".md"):
479
- url = f"{url}.md"
480
-
481
- try:
482
- # Make request with auth
483
- headers = {"Authorization": f"Bearer {hf_token}"}
484
-
485
- async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
486
- response = await client.get(url, headers=headers)
487
- response.raise_for_status()
488
-
489
- content = response.text
490
-
491
- # Return the markdown content directly
492
- result = f"Documentation from: {url}\n\n{content}"
493
-
494
- return result, True
495
 
496
- except httpx.HTTPStatusError as e:
497
- return (
498
- f"HTTP error fetching {url}: {e.response.status_code} - {e.response.text[:200]}",
499
- False,
500
- )
501
- except httpx.RequestError as e:
502
- return f"Request error fetching {url}: {str(e)}", False
503
- except Exception as e:
504
- return f"Error fetching documentation: {str(e)}", False
 
 
505
 
506
 
507
- # Tool specifications for documentation search
 
508
 
509
  EXPLORE_HF_DOCS_TOOL_SPEC = {
510
  "name": "explore_hf_docs",
511
  "description": (
512
- "Explore Hugging Face documentation structure and discover available pages with 300-character previews. "
513
  "⚠️ MANDATORY: ALWAYS use this BEFORE implementing any ML task (training, fine-tuning, data processing, inference). "
514
  "Your training data may be outdated - current documentation is the source of truth. "
515
  "**Use when:** (1) Starting any implementation task, (2) User asks 'how to' questions, "
@@ -519,77 +715,22 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
519
  "Returns: Sidebar navigation with titles, URLs, and glimpses of all pages in the selected documentation. "
520
  "**Then:** Use fetch_hf_docs with specific URLs from results to get full content. "
521
  "**Critical for reliability:** Never implement based on internal knowledge without checking current docs first - APIs change frequently."
 
522
  ),
523
  "parameters": {
524
  "type": "object",
525
  "properties": {
526
  "endpoint": {
527
  "type": "string",
528
- "enum": [
529
- "hub",
530
- "transformers",
531
- "diffusers",
532
- "datasets",
533
- "gradio",
534
- "trackio",
535
- "smolagents",
536
- "huggingface_hub",
537
- "huggingface.js",
538
- "transformers.js",
539
- "inference-providers",
540
- "inference-endpoints",
541
- "peft",
542
- "accelerate",
543
- "optimum",
544
- "optimum-habana",
545
- "optimum-neuron",
546
- "optimum-intel",
547
- "optimum-executorch",
548
- "optimum-tpu",
549
- "tokenizers",
550
- "llm-course",
551
- "robotics-course",
552
- "mcp-course",
553
- "smol-course",
554
- "agents-course",
555
- "deep-rl-course",
556
- "computer-vision-course",
557
- "evaluate",
558
- "tasks",
559
- "dataset-viewer",
560
- "trl",
561
- "simulate",
562
- "sagemaker",
563
- "timm",
564
- "safetensors",
565
- "tgi",
566
- "setfit",
567
- "audio-course",
568
- "lerobot",
569
- "autotrain",
570
- "tei",
571
- "bitsandbytes",
572
- "cookbook",
573
- "sentence_transformers",
574
- "ml-games-course",
575
- "diffusion-course",
576
- "ml-for-3d-course",
577
- "chat-ui",
578
- "leaderboards",
579
- "lighteval",
580
- "argilla",
581
- "distilabel",
582
- "microsoft-azure",
583
- "kernels",
584
- "google-cloud",
585
- ],
586
  "description": (
587
  "The documentation endpoint to explore. Each endpoint corresponds to a major section of the Hugging Face documentation:\n\n"
 
588
  "• hub — Find answers to questions about models/datasets/spaces, auth, versioning, metadata.\n"
589
  "• transformers — Core model library: architectures, configs, tokenizers, training & inference APIs.\n"
590
  "• diffusers — Diffusion pipelines, schedulers, fine-tuning, training, and deployment patterns.\n"
591
  "• datasets — Dataset loading, streaming, processing, Arrow format, Hub integration.\n"
592
- "• gradio — UI components and demos for interacting with ML models.\n"
593
  "• trackio — Experiment tracking, metrics logging, and run comparison.\n"
594
  "• smolagents — Lightweight agent abstractions and tool-using patterns.\n"
595
  "• huggingface_hub — Python client for Hub operations (auth, upload/download, repo management).\n"
@@ -599,20 +740,8 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
599
  "• inference-endpoints — Managed, scalable model deployments on HF infrastructure.\n"
600
  "• peft — Parameter-efficient fine-tuning methods (LoRA, adapters, etc.).\n"
601
  "• accelerate — Hardware-agnostic, distributed and mixed-precision training orchestration.\n"
602
- "• optimum — Hardware-aware optimization and model export tooling.\n"
603
- "• optimum-habana — Training and inference on Habana Gaudi accelerators.\n"
604
- "• optimum-neuron — Optimization workflows for AWS Inferentia/Trainium.\n"
605
- "• optimum-intel — Intel CPU/GPU optimizations (OpenVINO, IPEX).\n"
606
- "• optimum-executorch — Exporting models to ExecuTorch for edge/mobile.\n"
607
- "• optimum-tpu — TPU-specific training and optimization paths.\n"
608
  "• tokenizers — Fast tokenizer internals, training, and low-level APIs.\n"
609
- "• llm-course — End-to-end LLM concepts, training, and deployment.\n"
610
- "• robotics-course — Learning-based robotics foundations.\n"
611
- "• mcp-course — Model Context Protocol concepts and usage.\n"
612
- "• smol-course — Small-model and efficiency-focused workflows.\n"
613
- "• agents-course — Tool-using, planning, and multi-step agent design.\n"
614
- "• deep-rl-course — Deep reinforcement learning foundations.\n"
615
- "• computer-vision-course — Vision models, datasets, and pipelines.\n"
616
  "• evaluate — Metrics, evaluation workflows, and training-loop integration.\n"
617
  "• tasks — Canonical task definitions and model categorization.\n"
618
  "• dataset-viewer — Dataset preview, streaming views, and viewer internals.\n"
@@ -623,16 +752,11 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
623
  "• safetensors — Safe, fast tensor serialization format.\n"
624
  "• tgi — High-throughput text generation server for LLMs.\n"
625
  "• setfit — Few-shot text classification via sentence embeddings.\n"
626
- "• audio-course — Speech and audio models, datasets, and tasks.\n"
627
  "• lerobot — Robotics datasets, policies, and learning workflows.\n"
628
  "• autotrain — No/low-code model training on Hugging Face.\n"
629
  "• tei — Optimized inference server for embedding workloads.\n"
630
  "• bitsandbytes — Quantization and memory-efficient optimizers.\n"
631
- "• cookbook — Practical, task-oriented recipes across the ecosystem.\n"
632
  "• sentence_transformers — Embedding models, training recipes, similarity/search workflows.\n"
633
- "• ml-games-course — Game-based ML and reinforcement learning experiments.\n"
634
- "• diffusion-course — Diffusion model theory and hands-on practice.\n"
635
- "• ml-for-3d-course — 3D representations, models, and learning techniques.\n"
636
  "• chat-ui — Reference chat interfaces for LLM deployment.\n"
637
  "• leaderboards — Evaluation leaderboards and submission mechanics.\n"
638
  "• lighteval — Lightweight, reproducible LLM evaluation framework.\n"
@@ -643,6 +767,19 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
643
  "• google-cloud — GCP deployment and serving workflows.\n"
644
  ),
645
  },
 
646
  },
647
  "required": ["endpoint"],
648
  },
@@ -677,40 +814,3 @@ HF_DOCS_FETCH_TOOL_SPEC = {
677
  "required": ["url"],
678
  },
679
  }
680
-
681
-
682
- async def _get_api_search_tool_spec() -> dict[str, Any]:
683
- """
684
- Dynamically generate the OpenAPI tool spec with tag enum populated at runtime
685
- This must be called async to fetch the OpenAPI spec and extract tags
686
- """
687
- spec = await _fetch_openapi_spec()
688
- tags = _extract_all_tags(spec)
689
-
690
- return {
691
- "name": "search_hf_api_endpoints",
692
- "description": (
693
- "Search HuggingFace OpenAPI specification by tag to find API endpoints with curl examples. "
694
- "**Use when:** (1) Need to interact with HF Hub API directly, (2) Building scripts for repo operations, "
695
- "(3) Need authentication patterns, (4) Understanding API parameters and responses, "
696
- "(5) Need curl examples for HTTP requests. "
697
- "Returns: Endpoint paths, methods, parameters, curl examples with authentication, and response schemas. "
698
- "**Pattern:** search_hf_api_endpoints (find endpoint) → use curl pattern in implementation. "
699
- "Tags group related operations: repos, models, datasets, inference, spaces, etc. "
700
- "**Note:** Each result includes curl example with $HF_TOKEN placeholder for authentication. "
701
- "**For tool building:** This provides the API foundation for creating Hub interaction scripts."
702
- ),
703
- "parameters": {
704
- "type": "object",
705
- "properties": {
706
- "tag": {
707
- "type": "string",
708
- "enum": tags,
709
- "description": (
710
- "The API tag to search for. Each tag groups related API endpoints. "
711
- ),
712
- },
713
- },
714
- "required": ["tag"],
715
- },
716
- }
 
1
  """
2
+ Documentation search tools for exploring HuggingFace and Gradio documentation.
 
3
  """
4
 
5
  import asyncio
6
+ import json
7
  import os
8
  from typing import Any
9
 
10
  import httpx
11
  from bs4 import BeautifulSoup
12
+ from whoosh.analysis import StemmingAnalyzer
13
+ from whoosh.fields import ID, TEXT, Schema
14
+ from whoosh.filedb.filestore import RamStorage
15
+ from whoosh.qparser import MultifieldParser, OrGroup
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Configuration
19
+ # ---------------------------------------------------------------------------
20
+
21
+ DEFAULT_MAX_RESULTS = 20
22
+ MAX_RESULTS_CAP = 50
23
+
24
+ GRADIO_LLMS_TXT_URL = "https://gradio.app/llms.txt"
25
+ GRADIO_SEARCH_URL = "https://playground-worker.pages.dev/api/prompt"
26
+
27
+ COMPOSITE_ENDPOINTS: dict[str, list[str]] = {
28
+ "optimum": [
29
+ "optimum",
30
+ "optimum-habana",
31
+ "optimum-neuron",
32
+ "optimum-intel",
33
+ "optimum-executorch",
34
+ "optimum-tpu",
35
+ ],
36
+ "courses": [
37
+ "llm-course",
38
+ "robotics-course",
39
+ "mcp-course",
40
+ "smol-course",
41
+ "agents-course",
42
+ "deep-rl-course",
43
+ "computer-vision-course",
44
+ "audio-course",
45
+ "ml-games-course",
46
+ "diffusion-course",
47
+ "ml-for-3d-course",
48
+ "cookbook",
49
+ ],
50
+ }
51
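The composite entries above fan one requested endpoint out into several documentation sections; the lookup reduces to a sketch like this (the real code inlines it in `_get_docs`):

```python
COMPOSITE_ENDPOINTS = {
    "optimum": ["optimum", "optimum-habana", "optimum-neuron"],
}

def expand_endpoint(endpoint: str) -> list[str]:
    # Composite endpoints map to several sub-sections; anything else maps to itself.
    return COMPOSITE_ENDPOINTS.get(endpoint, [endpoint])
```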
 
52
+ # ---------------------------------------------------------------------------
53
+ # Caches
54
+ # ---------------------------------------------------------------------------
55
 
56
+ _docs_cache: dict[str, list[dict[str, str]]] = {}
57
+ _index_cache: dict[str, tuple[Any, MultifieldParser]] = {}
58
+ _cache_lock = asyncio.Lock()
59
+ _openapi_cache: dict[str, Any] | None = None
60
 
61
+ # ---------------------------------------------------------------------------
62
+ # Gradio Documentation
63
+ # ---------------------------------------------------------------------------
 
 
64
 
65
+
66
+ async def _fetch_gradio_docs(query: str | None = None) -> str:
67
+ """
68
+ Fetch Gradio documentation.
69
+ Without query: Get full documentation from llms.txt
70
+ With query: Run embedding search on guides/demos for relevant content
71
+ """
72
  async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
73
+ if not query:
74
+ resp = await client.get(GRADIO_LLMS_TXT_URL)
75
+ resp.raise_for_status()
76
+ return resp.text
77
+
78
+ resp = await client.post(
79
+ GRADIO_SEARCH_URL,
80
+ headers={
81
+ "Content-Type": "application/json",
82
+ "Origin": "https://gradio-docs-mcp.up.railway.app",
83
+ },
84
+ json={
85
+ "prompt_to_embed": query,
86
+ "SYSTEM_PROMPT": "$INSERT_GUIDES_DOCS_DEMOS",
87
+ "FALLBACK_PROMPT": "No results found",
88
+ },
89
+ )
90
+ resp.raise_for_status()
91
+ return resp.json().get("SYS_PROMPT", "No results found")
92
 
 
93
 
94
+ # ---------------------------------------------------------------------------
95
+ # HF Documentation - Fetching
96
+ # ---------------------------------------------------------------------------
97
 
 
 
 
 
98
 
99
+ async def _fetch_endpoint_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]:
100
+ """Fetch all docs for an endpoint by parsing sidebar and fetching each page."""
101
+ url = f"https://huggingface.co/docs/{endpoint}"
102
+ headers = {"Authorization": f"Bearer {hf_token}"}
103
 
104
+ async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
105
+ resp = await client.get(url, headers=headers)
106
+ resp.raise_for_status()
107
+
108
+ soup = BeautifulSoup(resp.text, "html.parser")
109
+ sidebar = soup.find("nav", class_=lambda x: x and "flex-auto" in x)
110
+ if not sidebar:
111
+ raise ValueError(f"Could not find navigation sidebar for '{endpoint}'")
112
+
113
+ nav_items = []
114
+ for link in sidebar.find_all("a", href=True):
115
+ href = link["href"]
116
+ page_url = f"https://huggingface.co{href}" if href.startswith("/") else href
117
+ nav_items.append({"title": link.get_text(strip=True), "url": page_url})
118
+
119
+ if not nav_items:
120
+ raise ValueError(f"No navigation links found for '{endpoint}'")
121
+
122
+ async def fetch_page(item: dict[str, str]) -> dict[str, str]:
123
+ md_url = f"{item['url']}.md"
124
+ try:
125
+ r = await client.get(md_url, headers=headers)
126
+ r.raise_for_status()
127
+ content = r.text.strip()
128
+ glimpse = content[:200] + "..." if len(content) > 200 else content
129
+ except Exception as e:
130
+ content, glimpse = "", f"[Could not fetch: {str(e)[:50]}]"
131
+ return {
132
+ "title": item["title"],
133
+ "url": item["url"],
134
+ "md_url": md_url,
135
+ "glimpse": glimpse,
136
+ "content": content,
137
+ "section": endpoint,
138
+ }
139
+
140
+ return list(await asyncio.gather(*[fetch_page(item) for item in nav_items]))
141
+
142
+
143
+ async def _get_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]:
144
+ """Get docs for endpoint with caching. Expands composite endpoints."""
145
+ async with _cache_lock:
146
+ if endpoint in _docs_cache:
147
+ return _docs_cache[endpoint]
148
+
149
+ sub_endpoints = COMPOSITE_ENDPOINTS.get(endpoint, [endpoint])
150
+ all_docs: list[dict[str, str]] = []
151
+
152
+ for sub in sub_endpoints:
153
+ async with _cache_lock:
154
+ if sub in _docs_cache:
155
+ all_docs.extend(_docs_cache[sub])
156
+ continue
157
 
158
+ docs = await _fetch_endpoint_docs(hf_token, sub)
159
+ async with _cache_lock:
160
+ _docs_cache[sub] = docs
161
+ all_docs.extend(docs)
162
+
163
+ async with _cache_lock:
164
+ _docs_cache[endpoint] = all_docs
165
+ return all_docs
166
+
167
+
168
+ # ---------------------------------------------------------------------------
169
+ # HF Documentation - Search
170
+ # ---------------------------------------------------------------------------
171
+
172
+
173
+ async def _build_search_index(
174
+ endpoint: str, docs: list[dict[str, str]]
175
+ ) -> tuple[Any, MultifieldParser]:
176
+ """Build or retrieve cached Whoosh search index."""
177
+ async with _cache_lock:
178
+ if endpoint in _index_cache:
179
+ return _index_cache[endpoint]
180
+
181
+ analyzer = StemmingAnalyzer()
182
+ schema = Schema(
183
+ title=TEXT(stored=True, analyzer=analyzer),
184
+ url=ID(stored=True, unique=True),
185
+ md_url=ID(stored=True),
186
+ section=ID(stored=True),
187
+ glimpse=TEXT(stored=True, analyzer=analyzer),
188
+ content=TEXT(stored=False, analyzer=analyzer),
189
+ )
190
+ storage = RamStorage()
191
+ index = storage.create_index(schema)
192
+ writer = index.writer()
193
+ for doc in docs:
194
+ writer.add_document(
195
+ title=doc.get("title", ""),
196
+ url=doc.get("url", ""),
197
+ md_url=doc.get("md_url", ""),
198
+ section=doc.get("section", endpoint),
199
+ glimpse=doc.get("glimpse", ""),
200
+ content=doc.get("content", ""),
201
+ )
202
+ writer.commit()
203
 
204
+ parser = MultifieldParser(
205
+ ["title", "content"],
206
+ schema=schema,
207
+ fieldboosts={"title": 2.0, "content": 1.0},
208
+ group=OrGroup,
209
+ )
210
 
211
+ async with _cache_lock:
212
+ _index_cache[endpoint] = (index, parser)
213
+ return index, parser
214
 
215
 
216
+ async def _search_docs(
217
+ endpoint: str, docs: list[dict[str, str]], query: str, limit: int
218
+ ) -> tuple[list[dict[str, Any]], str | None]:
219
+ """Search docs using Whoosh. Returns (results, fallback_message)."""
220
+ index, parser = await _build_search_index(endpoint, docs)
 
221
 
222
  try:
223
+ query_obj = parser.parse(query)
224
+ except Exception:
225
+ return [], "Query contained unsupported syntax; showing default ordering."
226
+
227
+ with index.searcher() as searcher:
228
+ results = searcher.search(query_obj, limit=limit)
229
+ matches = [
230
+ {
231
+ "title": hit["title"],
232
+ "url": hit["url"],
233
+ "md_url": hit.get("md_url", ""),
234
+ "section": hit.get("section", endpoint),
235
+ "glimpse": hit["glimpse"],
236
+ "score": round(hit.score, 2),
237
+ }
238
+ for hit in results
239
+ ]
240
+
241
+ if not matches:
242
+ return [], "No strong matches found; showing default ordering."
243
+ return matches, None
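The field boosts above mean a query-term hit in a page title counts for twice as much as the same hit in the body. A toy stand-in for that ranking behavior (illustrative only; Whoosh's real BM25F scoring is more involved):

```python
def score(doc: dict, query: str) -> float:
    """Toy term-frequency scorer with field boosts, mimicking title > content weighting."""
    boosts = {"title": 2.0, "content": 1.0}
    terms = query.lower().split()
    total = 0.0
    for field, boost in boosts.items():
        words = doc.get(field, "").lower().split()
        total += boost * sum(words.count(t) for t in terms)
    return total

docs = [
    {"title": "Loading datasets", "content": "stream and load data"},
    {"title": "Tokenizers", "content": "datasets are tokenized here"},
]
ranked = sorted(docs, key=lambda d: score(d, "datasets"), reverse=True)
print(ranked[0]["title"])  # Loading datasets
```

A title match (boost 2.0) outranks a body match (boost 1.0), which is why the first page wins here.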
244
+
245
+
246
+ # ---------------------------------------------------------------------------
247
+ # HF Documentation - Formatting
248
+ # ---------------------------------------------------------------------------
249
+
250
+
251
+ def _format_results(
252
+ endpoint: str,
253
+ items: list[dict[str, Any]],
254
+ total: int,
255
+ query: str | None = None,
256
+ note: str | None = None,
257
+ ) -> str:
258
+ """Format search results as readable text."""
259
+ base_url = f"https://huggingface.co/docs/{endpoint}"
260
+ out = f"Documentation structure for: {base_url}\n\n"
261
 
262
+ if query:
263
+ out += f"Query: '{query}' → showing {len(items)} result(s) out of {total} pages"
264
+ if note:
265
+ out += f" ({note})"
266
+ out += "\n\n"
267
+ else:
268
+ out += f"Found {len(items)} page(s) (total available: {total}).\n"
269
+ if note:
270
+ out += f"({note})\n"
271
+ out += "\n"
272
 
273
+ for i, item in enumerate(items, 1):
274
+ out += f"{i}. **{item['title']}**\n"
275
+ out += f" URL: {item['url']}\n"
276
+ out += f" Section: {item.get('section', endpoint)}\n"
277
+ if query and "score" in item:
278
+ out += f" Relevance score: {item['score']:.2f}\n"
279
+ out += f" Glimpse: {item['glimpse']}\n\n"
280
+
281
+ return out
282
283
 
284
+ # ---------------------------------------------------------------------------
285
+ # Handlers
286
+ # ---------------------------------------------------------------------------
 
287
 
 
288
 
289
+ async def explore_hf_docs_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
290
+ """Explore documentation structure with optional search query."""
291
+ endpoint = arguments.get("endpoint", "").lstrip("/")
292
+ query = arguments.get("query")
293
+ max_results = arguments.get("max_results")
294
 
295
+ if not endpoint:
296
+ return "Error: No endpoint provided", False
 
 
297
 
298
+ # Gradio uses its own API
299
+ if endpoint.lower() == "gradio":
300
+ try:
301
+ clean_query = (
302
+ query.strip() if isinstance(query, str) and query.strip() else None
303
+ )
304
+ content = await _fetch_gradio_docs(clean_query)
305
+ header = "# Gradio Documentation\n\n"
306
+ if clean_query:
307
+ header += f"Query: '{clean_query}'\n\n"
308
+ header += "Source: https://gradio.app/docs\n\n---\n\n"
309
+ return header + content, True
310
+ except httpx.HTTPStatusError as e:
311
+ return f"HTTP error fetching Gradio docs: {e.response.status_code}", False
312
+ except httpx.RequestError as e:
313
+ return f"Request error fetching Gradio docs: {str(e)}", False
314
+ except Exception as e:
315
+ return f"Error fetching Gradio docs: {str(e)}", False
316
+
317
+ # HF docs
318
+ hf_token = os.environ.get("HF_TOKEN")
319
+ if not hf_token:
320
+ return "Error: HF_TOKEN environment variable not set", False
321
 
322
+ try:
323
+ max_results_int = int(max_results) if max_results is not None else None
324
+ except (TypeError, ValueError):
325
+ return "Error: max_results must be an integer", False
326
 
327
+ if max_results_int is not None and max_results_int <= 0:
328
+ return "Error: max_results must be greater than zero", False
329
 
330
+ try:
331
+ docs = await _get_docs(hf_token, endpoint)
332
+ total = len(docs)
333
+
334
+ # Determine limit
335
+ if max_results_int is None:
336
+ limit = DEFAULT_MAX_RESULTS
337
+ limit_note = f"Showing top {DEFAULT_MAX_RESULTS} results (set max_results to adjust)."
338
+ elif max_results_int > MAX_RESULTS_CAP:
339
+ limit = MAX_RESULTS_CAP
340
+ limit_note = f"Requested {max_results_int} but showing top {MAX_RESULTS_CAP} (maximum)."
341
+ else:
342
+ limit = max_results_int
343
+ limit_note = None
344
+
345
+ # Search or paginate
346
+ clean_query = (
347
+ query.strip() if isinstance(query, str) and query.strip() else None
348
+ )
349
+ fallback_msg = None
350
 
351
+ if clean_query:
352
+ results, fallback_msg = await _search_docs(
353
+ endpoint, docs, clean_query, limit
354
+ )
355
+ if not results:
356
+ results = docs[:limit]
357
+ else:
358
+ results = docs[:limit]
359
 
360
+ # Combine notes
361
+ notes = []
362
+ if fallback_msg:
363
+ notes.append(fallback_msg)
364
+ if limit_note:
365
+ notes.append(limit_note)
366
+ note = "; ".join(notes) if notes else None
367
 
368
+ return _format_results(endpoint, results, total, clean_query, note), True
 
 
369
 
370
+ except httpx.HTTPStatusError as e:
371
+ return f"HTTP error: {e.response.status_code} - {e.response.text[:200]}", False
372
+ except httpx.RequestError as e:
373
+ return f"Request error: {str(e)}", False
374
+ except ValueError as e:
375
+ return f"Error: {str(e)}", False
376
+ except Exception as e:
377
+ return f"Unexpected error: {str(e)}", False
378
 
 
 
 
 
379
 
380
+ async def hf_docs_fetch_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
381
+ """Fetch full markdown content of a documentation page."""
382
+ url = arguments.get("url", "")
383
+ if not url:
384
+ return "Error: No URL provided", False
385
 
 
386
  hf_token = os.environ.get("HF_TOKEN")
 
387
  if not hf_token:
388
  return "Error: HF_TOKEN environment variable not set", False
389
 
390
+ if not url.endswith(".md"):
391
+ url = f"{url}.md"
392
 
393
  try:
394
+ async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
395
+ resp = await client.get(
396
+ url, headers={"Authorization": f"Bearer {hf_token}"}
397
+ )
398
+ resp.raise_for_status()
399
+ return f"Documentation from: {url}\n\n{resp.text}", True
400
  except httpx.HTTPStatusError as e:
401
  return (
402
+ f"HTTP error fetching {url}: {e.response.status_code} - {e.response.text[:200]}",
403
  False,
404
  )
405
  except httpx.RequestError as e:
406
+ return f"Request error fetching {url}: {str(e)}", False
 
 
407
  except Exception as e:
408
+ return f"Error fetching documentation: {str(e)}", False
409
 
410
 
411
+ # ---------------------------------------------------------------------------
412
+ # OpenAPI Search
413
+ # ---------------------------------------------------------------------------
414
 
 
 
415
 
416
+ async def _fetch_openapi_spec() -> dict[str, Any]:
417
+ """Fetch and cache HuggingFace OpenAPI specification."""
418
+ global _openapi_cache
419
+ if _openapi_cache is not None:
420
+ return _openapi_cache
421
 
422
  async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
423
+ resp = await client.get("https://huggingface.co/.well-known/openapi.json")
424
+ resp.raise_for_status()
425
 
426
+ _openapi_cache = resp.json()
427
+ return _openapi_cache
 
 
428
 
429
 
430
  def _extract_all_tags(spec: dict[str, Any]) -> list[str]:
431
+ """Extract all unique tags from OpenAPI spec."""
432
  tags = set()
 
 
433
  for tag_obj in spec.get("tags", []):
434
  if "name" in tag_obj:
435
  tags.add(tag_obj["name"])
436
+ for path_item in spec.get("paths", {}).values():
437
+ for method, op in path_item.items():
 
 
438
  if method in ["get", "post", "put", "delete", "patch", "head", "options"]:
439
+ for tag in op.get("tags", []):
440
  tags.add(tag)
441
+ return sorted(tags)
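For illustration, here is the same extraction logic run against a toy spec (the data is made up; real specs come from `/.well-known/openapi.json`). The function body is copied so the snippet is self-contained:

```python
def extract_all_tags(spec: dict) -> list[str]:
    """Collect unique tags from the top-level `tags` list and from every operation."""
    tags = set()
    for tag_obj in spec.get("tags", []):
        if "name" in tag_obj:
            tags.add(tag_obj["name"])
    for path_item in spec.get("paths", {}).values():
        for method, op in path_item.items():
            if method in ["get", "post", "put", "delete", "patch", "head", "options"]:
                for tag in op.get("tags", []):
                    tags.add(tag)
    return sorted(tags)

toy_spec = {
    "tags": [{"name": "models"}],
    "paths": {"/api/datasets": {"get": {"tags": ["datasets"]}}},
}
print(extract_all_tags(toy_spec))  # ['datasets', 'models']
```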
442
 
443
 
444
  def _generate_curl_example(endpoint: dict[str, Any]) -> str:
445
+ """Generate curl command example for an endpoint."""
446
  method = endpoint["method"]
447
  path = endpoint["path"]
448
  base_url = endpoint["base_url"]
449
 
450
+ # Build URL with path parameters
451
  full_path = path
452
  for param in endpoint.get("parameters", []):
453
  if param.get("in") == "path" and param.get("required"):
454
+ name = param["name"]
455
  example = param.get(
456
+ "example", param.get("schema", {}).get("example", f"<{name}>")
457
  )
458
+ full_path = full_path.replace(f"{{{name}}}", str(example))
459
 
460
  curl = f"curl -X {method} \\\n '{base_url}{full_path}'"
461
 
462
+ # Add query parameters
463
  query_params = [p for p in endpoint.get("parameters", []) if p.get("in") == "query"]
464
  if query_params and query_params[0].get("required"):
465
  param = query_params[0]
466
  example = param.get("example", param.get("schema", {}).get("example", "value"))
467
  curl += f"?{param['name']}={example}"
468
 
 
469
  curl += " \\\n -H 'Authorization: Bearer $HF_TOKEN'"
470
 
471
+ # Add request body
472
  if method in ["POST", "PUT", "PATCH"] and endpoint.get("request_body"):
473
  content = endpoint["request_body"].get("content", {})
474
  if "application/json" in content:
 
476
  schema = content["application/json"].get("schema", {})
477
  example = schema.get("example", "{}")
478
  if isinstance(example, dict):
 
 
479
  example = json.dumps(example, indent=2)
480
  curl += f" \\\n -d '{example}'"
481
 
 
483
 
484
 
485
  def _format_parameters(parameters: list[dict[str, Any]]) -> str:
486
+ """Format parameter information from OpenAPI spec."""
487
  if not parameters:
488
  return ""
489
 
 
490
  path_params = [p for p in parameters if p.get("in") == "path"]
491
  query_params = [p for p in parameters if p.get("in") == "query"]
492
  header_params = [p for p in parameters if p.get("in") == "header"]
493
 
494
  output = []
495
 
496
+ for label, params in [
497
+ ("Path Parameters", path_params),
498
+ ("Query Parameters", query_params),
499
+ ("Header Parameters", header_params),
500
+ ]:
501
+ if not params:
502
+ continue
 
 
 
 
 
 
 
503
  if output:
504
  output.append("")
505
+ output.append(f"**{label}:**")
506
+ for p in params:
507
+ name = p.get("name", "")
508
+ required = " (required)" if p.get("required") else " (optional)"
509
+ desc = p.get("description", "")
510
+ ptype = p.get("schema", {}).get("type", "string")
511
+ example = p.get("example") or p.get("schema", {}).get("example", "")
512
+
513
+ output.append(f"- `{name}` ({ptype}){required}: {desc}")
514
  if example:
515
  output.append(f" Example: `{example}`")
516
517
  return "\n".join(output)
518
 
519
 
520
  def _format_response_info(responses: dict[str, Any]) -> str:
521
+ """Format response information from OpenAPI spec."""
522
  if not responses:
523
  return "No response information available"
524
 
525
  output = []
526
+ for status, resp_obj in list(responses.items())[:3]:
527
+ desc = resp_obj.get("description", "")
528
+ output.append(f"- **{status}**: {desc}")
529
+ content = resp_obj.get("content", {})
 
 
 
530
  if "application/json" in content:
531
  schema = content["application/json"].get("schema", {})
532
  if "type" in schema:
 
536
 
537
 
538
  def _format_openapi_results(results: list[dict[str, Any]], tag: str) -> str:
539
+ """Format OpenAPI search results with curl examples."""
540
  if not results:
541
  return f"No API endpoints found with tag '{tag}'"
542
 
543
+ out = f"# API Endpoints for tag: `{tag}`\n\n"
544
+ out += f"Found {len(results)} endpoint(s)\n\n---\n\n"
 
545
 
546
+ for i, ep in enumerate(results, 1):
547
+ out += f"## {i}. {ep['method']} {ep['path']}\n\n"
548
 
549
+ if ep["summary"]:
550
+ out += f"**Summary:** {ep['summary']}\n\n"
551
 
552
+ if ep["description"]:
553
+ desc = ep["description"][:300]
554
+ if len(ep["description"]) > 300:
555
  desc += "..."
556
+ out += f"**Description:** {desc}\n\n"
557
 
558
+ params_info = _format_parameters(ep.get("parameters", []))
 
559
  if params_info:
560
+ out += params_info + "\n\n"
 
 
 
 
 
561
 
562
+ out += "**Usage:**\n```bash\n"
563
+ out += _generate_curl_example(ep)
564
+ out += "\n```\n\n"
 
565
 
566
+ out += "**Returns:**\n"
567
+ out += _format_response_info(ep["responses"])
568
+ out += "\n\n---\n\n"
569
 
570
+ return out
571
 
572
 
573
  async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
574
+ """Search HuggingFace OpenAPI specification by tag."""
575
  tag = arguments.get("tag", "")
 
576
  if not tag:
577
  return "Error: No tag provided", False
578
 
579
  try:
 
580
  spec = await _fetch_openapi_spec()
581
+ paths = spec.get("paths", {})
582
+ servers = spec.get("servers", [])
583
+ base_url = (
584
+ servers[0].get("url", "https://huggingface.co")
585
+ if servers
586
+ else "https://huggingface.co"
587
+ )
588
 
589
+ results = []
590
+ for path, path_item in paths.items():
591
+ for method, op in path_item.items():
592
+ if method not in [
593
+ "get",
594
+ "post",
595
+ "put",
596
+ "delete",
597
+ "patch",
598
+ "head",
599
+ "options",
600
+ ]:
601
+ continue
602
+ if tag not in op.get("tags", []):
603
+ continue
604
 
605
+ results.append(
606
+ {
607
+ "path": path,
608
+ "method": method.upper(),
609
+ "operationId": op.get("operationId", ""),
610
+ "summary": op.get("summary", ""),
611
+ "description": op.get("description", ""),
612
+ "parameters": op.get("parameters", []),
613
+ "request_body": op.get("requestBody", {}),
614
+ "responses": op.get("responses", {}),
615
+ "base_url": base_url,
616
+ }
617
+ )
618
 
619
+ return _format_openapi_results(results, tag), True
620
 
621
  except httpx.HTTPStatusError as e:
622
  return f"HTTP error fetching OpenAPI spec: {e.response.status_code}", False
 
626
  return f"Error searching OpenAPI spec: {str(e)}", False
627
 
628
 
629
+ async def _get_api_search_tool_spec() -> dict[str, Any]:
630
+ """Generate OpenAPI tool spec with tags populated at runtime."""
631
+ spec = await _fetch_openapi_spec()
632
+ tags = _extract_all_tags(spec)
633
 
634
+ return {
635
+ "name": "search_hf_api_endpoints",
636
+ "description": (
637
+ "Search HuggingFace OpenAPI specification by tag to find API endpoints with curl examples. "
638
+ "**Use when:** (1) Need to interact with HF Hub API directly, (2) Building scripts for repo operations, "
639
+ "(3) Need authentication patterns, (4) Understanding API parameters and responses, "
640
+ "(5) Need curl examples for HTTP requests. "
641
+ "Returns: Endpoint paths, methods, parameters, curl examples with authentication, and response schemas. "
642
+ "Tags group related operations: repos, models, datasets, inference, spaces, etc."
643
+ ),
644
+ "parameters": {
645
+ "type": "object",
646
+ "properties": {
647
+ "tag": {
648
+ "type": "string",
649
+ "enum": tags,
650
+ "description": "The API tag to search for. Each tag groups related API endpoints.",
651
+ },
652
+ },
653
+ "required": ["tag"],
654
+ },
655
+ }
656
 
657
 
658
+ # ---------------------------------------------------------------------------
659
+ # Tool Specifications
660
+ # ---------------------------------------------------------------------------
661
+
662
+ DOC_ENDPOINTS = [
663
+ "hub",
664
+ "transformers",
665
+ "diffusers",
666
+ "datasets",
667
+ "gradio",
668
+ "trackio",
669
+ "smolagents",
670
+ "huggingface_hub",
671
+ "huggingface.js",
672
+ "transformers.js",
673
+ "inference-providers",
674
+ "inference-endpoints",
675
+ "peft",
676
+ "accelerate",
677
+ "optimum",
678
+ "tokenizers",
679
+ "courses",
680
+ "evaluate",
681
+ "tasks",
682
+ "dataset-viewer",
683
+ "trl",
684
+ "simulate",
685
+ "sagemaker",
686
+ "timm",
687
+ "safetensors",
688
+ "tgi",
689
+ "setfit",
690
+ "lerobot",
691
+ "autotrain",
692
+ "tei",
693
+ "bitsandbytes",
694
+ "sentence_transformers",
695
+ "chat-ui",
696
+ "leaderboards",
697
+ "lighteval",
698
+ "argilla",
699
+ "distilabel",
700
+ "microsoft-azure",
701
+ "kernels",
702
+ "google-cloud",
703
+ ]
704
 
705
  EXPLORE_HF_DOCS_TOOL_SPEC = {
706
  "name": "explore_hf_docs",
707
  "description": (
708
+ "Explore Hugging Face documentation structure and discover available pages with 200-character previews. "
709
  "⚠️ MANDATORY: ALWAYS use this BEFORE implementing any ML task (training, fine-tuning, data processing, inference). "
710
  "Your training data may be outdated - current documentation is the source of truth. "
711
  "**Use when:** (1) Starting any implementation task, (2) User asks 'how to' questions, "
 
715
  "Returns: Sidebar navigation with titles, URLs, and glimpses of all pages in the selected documentation. "
716
  "**Then:** Use fetch_hf_docs with specific URLs from results to get full content. "
717
  "**Critical for reliability:** Never implement based on internal knowledge without checking current docs first - APIs change frequently."
718
+ " By default returns the top 20 results; set max_results (max 50) to adjust."
719
  ),
720
  "parameters": {
721
  "type": "object",
722
  "properties": {
723
  "endpoint": {
724
  "type": "string",
725
+ "enum": DOC_ENDPOINTS,
726
  "description": (
727
  "The documentation endpoint to explore. Each endpoint corresponds to a major section of the Hugging Face documentation:\n\n"
728
+ "• courses — All Hugging Face courses (LLM, robotics, MCP, smol (llm training), agents, deep RL, computer vision, games, diffusion, 3D, audio) and the cookbook recipes. Probably the best place for examples.\n"
729
  "• hub — Find answers to questions about models/datasets/spaces, auth, versioning, metadata.\n"
730
  "• transformers — Core model library: architectures, configs, tokenizers, training & inference APIs.\n"
731
  "• diffusers — Diffusion pipelines, schedulers, fine-tuning, training, and deployment patterns.\n"
732
  "• datasets — Dataset loading, streaming, processing, Arrow format, Hub integration.\n"
733
+ "• gradio — UI components and demos for ML models. Uses Gradio's native API: without query returns full docs (llms.txt), with query uses embedding search for precise results.\n"
734
  "• trackio — Experiment tracking, metrics logging, and run comparison.\n"
735
  "• smolagents — Lightweight agent abstractions and tool-using patterns.\n"
736
  "• huggingface_hub — Python client for Hub operations (auth, upload/download, repo management).\n"
 
740
  "• inference-endpoints — Managed, scalable model deployments on HF infrastructure.\n"
741
  "• peft — Parameter-efficient fine-tuning methods (LoRA, adapters, etc.).\n"
742
  "• accelerate — Hardware-agnostic, distributed and mixed-precision training orchestration.\n"
743
+ "• optimum — Hardware-aware optimization and model export tooling, including Habana, Neuron, Intel, ExecuTorch, and TPU variants.\n"
 
 
 
 
 
744
  "• tokenizers — Fast tokenizer internals, training, and low-level APIs.\n"
 
 
 
 
 
 
 
745
  "• evaluate — Metrics, evaluation workflows, and training-loop integration.\n"
746
  "• tasks — Canonical task definitions and model categorization.\n"
747
  "• dataset-viewer — Dataset preview, streaming views, and viewer internals.\n"
 
752
  "• safetensors — Safe, fast tensor serialization format.\n"
753
  "• tgi — High-throughput text generation server for LLMs.\n"
754
  "• setfit — Few-shot text classification via sentence embeddings.\n"
 
755
  "• lerobot — Robotics datasets, policies, and learning workflows.\n"
756
  "• autotrain — No/low-code model training on Hugging Face.\n"
757
  "• tei — Optimized inference server for embedding workloads.\n"
758
  "• bitsandbytes — Quantization and memory-efficient optimizers.\n"
 
759
  "• sentence_transformers — Embedding models, training recipes, similarity/search workflows.\n"
 
 
 
760
  "• chat-ui — Reference chat interfaces for LLM deployment.\n"
761
  "• leaderboards — Evaluation leaderboards and submission mechanics.\n"
762
  "• lighteval — Lightweight, reproducible LLM evaluation framework.\n"
 
767
  "• google-cloud — GCP deployment and serving workflows.\n"
768
  ),
769
  },
770
+ "query": {
771
+ "type": "string",
772
+ "description": (
773
+ "Optional keyword query to rank and filter documentation pages. "
774
+ "For Gradio, use concise queries like 'how to use the image component' or 'audio component demo'."
775
+ ),
776
+ },
777
+ "max_results": {
778
+ "type": "integer",
779
+ "description": "Max results (default 20, max 50). Ignored for Gradio.",
780
+ "minimum": 1,
781
+ "maximum": 50,
782
+ },
783
  },
784
  "required": ["endpoint"],
785
  },
 
814
  "required": ["url"],
815
  },
816
  }
 
agent/tools/hf_repo_files_tool.py ADDED
@@ -0,0 +1,322 @@
1
+ """
2
+ HF Repo Files Tool - File operations on Hugging Face repositories
3
+
4
+ Operations: list, read, upload, delete
5
+ """
6
+
7
+ import asyncio
8
+ from typing import Any, Dict, Literal, Optional
9
+
10
+ from huggingface_hub import HfApi, hf_hub_download
11
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
12
+
13
+ from agent.tools.types import ToolResult
14
+
15
+ OperationType = Literal["list", "read", "upload", "delete"]
16
+
17
+
18
+ async def _async_call(func, *args, **kwargs):
19
+ """Wrap synchronous HfApi calls for async context."""
20
+ return await asyncio.to_thread(func, *args, **kwargs)
21
+
22
+
23
+ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
24
+ """Build the Hub URL for a repository."""
25
+ if repo_type == "model":
26
+ return f"https://huggingface.co/{repo_id}"
27
+ return f"https://huggingface.co/{repo_type}s/{repo_id}"
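Model repos live at the bare repo path, while datasets and spaces get a pluralized type prefix. A self-contained copy with example outputs:

```python
def build_repo_url(repo_id: str, repo_type: str = "model") -> str:
    # Models sit at the root; other repo types use a pluralized prefix.
    if repo_type == "model":
        return f"https://huggingface.co/{repo_id}"
    return f"https://huggingface.co/{repo_type}s/{repo_id}"

print(build_repo_url("gpt2"))                        # https://huggingface.co/gpt2
print(build_repo_url("squad", repo_type="dataset"))  # https://huggingface.co/datasets/squad
print(build_repo_url("my-demo", repo_type="space"))  # https://huggingface.co/spaces/my-demo
```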
28
+
29
+
30
+ def _format_size(size_bytes: int) -> str:
31
+ """Format file size in human-readable form."""
32
+ for unit in ["B", "KB", "MB", "GB", "TB"]:
33
+ if size_bytes < 1024:
34
+ return f"{size_bytes:.1f}{unit}"
35
+ size_bytes /= 1024
36
+ return f"{size_bytes:.1f}PB"
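A quick sanity check of the formatter's behavior (re-defined here so the snippet runs on its own):

```python
def format_size(size_bytes: float) -> str:
    # Divide by 1024 until the value fits the current unit; fall through to PB.
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if size_bytes < 1024:
            return f"{size_bytes:.1f}{unit}"
        size_bytes /= 1024
    return f"{size_bytes:.1f}PB"

print(format_size(512))          # 512.0B
print(format_size(2048))         # 2.0KB
print(format_size(5 * 1024**3))  # 5.0GB
```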
37
+
38
+
39
+ class HfRepoFilesTool:
40
+ """Tool for file operations on HF repos."""
41
+
42
+ def __init__(self, hf_token: Optional[str] = None):
43
+ self.api = HfApi(token=hf_token)
44
+
45
+ async def execute(self, args: Dict[str, Any]) -> ToolResult:
46
+ """Execute the specified operation."""
47
+ operation = args.get("operation")
48
+
49
+ if not operation:
50
+ return self._help()
51
+
52
+ try:
53
+ handlers = {
54
+ "list": self._list,
55
+ "read": self._read,
56
+ "upload": self._upload,
57
+ "delete": self._delete,
58
+ }
59
+
60
+ handler = handlers.get(operation)
61
+ if handler:
62
+ return await handler(args)
63
+ else:
64
+ return self._error(f"Unknown operation: {operation}. Valid: list, read, upload, delete")
65
+
66
+ except RepositoryNotFoundError:
67
+ return self._error(f"Repository not found: {args.get('repo_id')}")
68
+ except EntryNotFoundError:
69
+ return self._error(f"File not found: {args.get('path')}")
70
+ except Exception as e:
71
+ return self._error(f"Error: {str(e)}")
72
+
73
+ def _help(self) -> ToolResult:
74
+ """Show usage instructions."""
75
+ return {
76
+ "formatted": """**hf_repo_files** - File operations on HF repos
77
+
78
+ **Operations:**
79
+ - `list` - List files: `{"operation": "list", "repo_id": "gpt2"}`
80
+ - `read` - Read file: `{"operation": "read", "repo_id": "gpt2", "path": "config.json"}`
81
+ - `upload` - Upload: `{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "..."}`
82
+ - `delete` - Delete: `{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp"]}`
83
+
84
+ **Common params:** repo_id (required), repo_type (model/dataset/space), revision (default: main)""",
85
+ "totalResults": 1,
86
+ "resultsShared": 1,
87
+ }
88
+
89
+ async def _list(self, args: Dict[str, Any]) -> ToolResult:
90
+ """List files in a repository."""
91
+ repo_id = args.get("repo_id")
92
+ if not repo_id:
93
+ return self._error("repo_id is required")
94
+
95
+ repo_type = args.get("repo_type", "model")
96
+ revision = args.get("revision", "main")
97
+ path = args.get("path", "")
98
+
99
+ items = list(await _async_call(
100
+ self.api.list_repo_tree,
101
+ repo_id=repo_id,
102
+ repo_type=repo_type,
103
+ revision=revision,
104
+ path_in_repo=path,
105
+ recursive=True,
106
+ ))
107
+
108
+ if not items:
109
+ return {"formatted": f"No files in {repo_id}", "totalResults": 0, "resultsShared": 0}
110
+
111
+ lines = []
112
+ total_size = 0
113
+ for item in sorted(items, key=lambda x: x.path):
114
+ if hasattr(item, "size") and item.size:
115
+ total_size += item.size
116
+ lines.append(f"{item.path} ({_format_size(item.size)})")
117
+ else:
118
+ lines.append(f"{item.path}/")
119
+
120
+ url = _build_repo_url(repo_id, repo_type)
121
+ response = f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" + "\n".join(lines)
122
+
123
+ return {"formatted": response, "totalResults": len(items), "resultsShared": len(items)}
124
+
125
+ async def _read(self, args: Dict[str, Any]) -> ToolResult:
126
+ """Read file content from a repository."""
127
+ repo_id = args.get("repo_id")
128
+ path = args.get("path")
129
+
130
+ if not repo_id:
131
+ return self._error("repo_id is required")
132
+ if not path:
133
+ return self._error("path is required")
134
+
135
+ repo_type = args.get("repo_type", "model")
136
+ revision = args.get("revision", "main")
137
+ max_chars = args.get("max_chars", 50000)
138
+
139
+ file_path = await _async_call(
140
+ hf_hub_download,
141
+ repo_id=repo_id,
142
+ filename=path,
143
+ repo_type=repo_type,
144
+ revision=revision,
145
+ token=self.api.token,
146
+ )
147
+
148
+ try:
149
+ with open(file_path, "r", encoding="utf-8") as f:
150
+ content = f.read()
151
+
152
+ truncated = len(content) > max_chars
153
+ if truncated:
154
+ content = content[:max_chars]
+
+ url = f"{_build_repo_url(repo_id, repo_type)}/blob/{revision}/{path}"
+ response = f"**{path}**{' (truncated)' if truncated else ''}\n{url}\n\n```\n{content}\n```"
+
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
+
+ except UnicodeDecodeError:
+ import os
+ size = os.path.getsize(file_path)
+ return {"formatted": f"Binary file ({_format_size(size)})", "totalResults": 1, "resultsShared": 1}
+
+ async def _upload(self, args: Dict[str, Any]) -> ToolResult:
+ """Upload content to a repository."""
+ repo_id = args.get("repo_id")
+ path = args.get("path")
+ content = args.get("content")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+ if not path:
+ return self._error("path is required")
+ if content is None:
+ return self._error("content is required")
+
+ repo_type = args.get("repo_type", "model")
+ revision = args.get("revision", "main")
+ create_pr = args.get("create_pr", False)
+ commit_message = args.get("commit_message", f"Upload {path}")
+
+ file_bytes = content.encode("utf-8") if isinstance(content, str) else content
+
+ result = await _async_call(
+ self.api.upload_file,
+ path_or_fileobj=file_bytes,
+ path_in_repo=path,
+ repo_id=repo_id,
+ repo_type=repo_type,
+ revision=revision,
+ commit_message=commit_message,
+ create_pr=create_pr,
+ )
+
+ url = _build_repo_url(repo_id, repo_type)
+ if create_pr and hasattr(result, "pr_url"):
+ response = f"**Uploaded as PR**\n{result.pr_url}"
+ else:
+ response = f"**Uploaded:** {path}\n{url}/blob/{revision}/{path}"
+
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
+
+ async def _delete(self, args: Dict[str, Any]) -> ToolResult:
+ """Delete files from a repository."""
+ repo_id = args.get("repo_id")
+ patterns = args.get("patterns")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+ if not patterns:
+ return self._error("patterns is required (list of paths/wildcards)")
+
+ if isinstance(patterns, str):
+ patterns = [patterns]
+
+ repo_type = args.get("repo_type", "model")
+ revision = args.get("revision", "main")
+ create_pr = args.get("create_pr", False)
+ commit_message = args.get("commit_message", f"Delete {', '.join(patterns)}")
+
+ await _async_call(
+ self.api.delete_files,
+ repo_id=repo_id,
+ delete_patterns=patterns,
+ repo_type=repo_type,
+ revision=revision,
+ commit_message=commit_message,
+ create_pr=create_pr,
+ )
+
+ response = f"**Deleted:** {', '.join(patterns)} from {repo_id}"
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
+
+ def _error(self, message: str) -> ToolResult:
+ """Return an error result."""
+ return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
+
+
+ # Tool specification
+ HF_REPO_FILES_TOOL_SPEC = {
+ "name": "hf_repo_files",
+ "description": (
+ "Read and write files in HF repos (models/datasets/spaces).\n\n"
+ "## Operations\n"
+ "- **list**: List files with sizes and structure\n"
+ "- **read**: Read file content (text files only)\n"
+ "- **upload**: Upload content to repo (can create PR)\n"
+ "- **delete**: Delete files/folders (supports wildcards like *.tmp)\n\n"
+ "## Use when\n"
+ "- Need to see what files exist in a repo\n"
+ "- Want to read config.json, README.md, or other text files\n"
+ "- Uploading training scripts, configs, or results to a repo\n"
+ "- Cleaning up temporary files from a repo\n\n"
+ "## Examples\n"
+ '{"operation": "list", "repo_id": "meta-llama/Llama-2-7b"}\n'
+ '{"operation": "read", "repo_id": "gpt2", "path": "config.json"}\n'
+ '{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "# My Model"}\n'
+ '{"operation": "upload", "repo_id": "org/model", "path": "fix.py", "content": "...", "create_pr": true}\n'
+ '{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp", "logs/"]}\n\n'
+ "## Notes\n"
+ "- For binary files (safetensors, bin), use list to see them but can't read content\n"
+ "- upload/delete require approval (can overwrite/destroy data)\n"
+ "- Use create_pr=true to propose changes instead of direct commit\n"
+ ),
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "operation": {
+ "type": "string",
+ "enum": ["list", "read", "upload", "delete"],
+ "description": "Operation: list, read, upload, delete",
+ },
+ "repo_id": {
+ "type": "string",
+ "description": "Repository ID (e.g., 'username/repo-name')",
+ },
+ "repo_type": {
+ "type": "string",
+ "enum": ["model", "dataset", "space"],
+ "description": "Repository type (default: model)",
+ },
+ "revision": {
+ "type": "string",
+ "description": "Branch/tag/commit (default: main)",
+ },
+ "path": {
+ "type": "string",
+ "description": "File path for read/upload",
+ },
+ "content": {
+ "type": "string",
+ "description": "File content for upload",
+ },
+ "patterns": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": "Patterns to delete (e.g., ['*.tmp', 'logs/'])",
+ },
+ "create_pr": {
+ "type": "boolean",
+ "description": "Create PR instead of direct commit",
+ },
+ "commit_message": {
+ "type": "string",
+ "description": "Custom commit message",
+ },
+ },
+ "required": ["operation"],
+ },
+ }
+
+
+ async def hf_repo_files_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
+ """Handler for agent tool router."""
+ try:
+ tool = HfRepoFilesTool()
+ result = await tool.execute(arguments)
+ return result["formatted"], not result.get("isError", False)
+ except Exception as e:
+ return f"Error: {str(e)}", False
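
Both handlers added in this commit follow the same contract: the tool's `execute` returns a ToolResult-style dict, and a thin adapter flattens it to the `(text, ok)` tuple the agent's tool router expects. A minimal, self-contained sketch of that adapter pattern (the stub tool and its messages are illustrative, not the real agent classes, which call `HfApi`):

```python
import asyncio
from typing import Any, Dict


class StubTool:
    """Illustrative stand-in for a tool's execute(); no network calls."""

    async def execute(self, args: Dict[str, Any]) -> Dict[str, Any]:
        if not args.get("repo_id"):
            # Error results carry isError=True and zero result counts
            return {"formatted": "repo_id is required", "totalResults": 0,
                    "resultsShared": 0, "isError": True}
        return {"formatted": f"**Listed:** {args['repo_id']}",
                "totalResults": 1, "resultsShared": 1}


async def stub_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
    """Flatten a ToolResult dict to the router's (text, ok) tuple."""
    try:
        result = await StubTool().execute(arguments)
        return result["formatted"], not result.get("isError", False)
    except Exception as e:
        return f"Error: {str(e)}", False


text, ok = asyncio.run(stub_handler({"repo_id": "gpt2"}))
print(text, ok)  # → **Listed:** gpt2 True
```

The try/except in the adapter means a crash inside the tool degrades to an `("Error: ...", False)` tuple instead of propagating into the router.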
agent/tools/hf_repo_git_tool.py ADDED
@@ -0,0 +1,663 @@
+ """
+ HF Repo Git Tool - Git-like operations on Hugging Face repositories
+
+ Operations: branches, tags, PRs, repo management
+ """
+
+ import asyncio
+ from typing import Any, Dict, Literal, Optional
+
+ from huggingface_hub import HfApi
+ from huggingface_hub.utils import RepositoryNotFoundError
+
+ from agent.tools.types import ToolResult
+
+ OperationType = Literal[
+ "create_branch", "delete_branch",
+ "create_tag", "delete_tag",
+ "list_refs",
+ "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status",
+ "create_repo", "update_repo",
+ ]
+
+
+ async def _async_call(func, *args, **kwargs):
+ """Wrap synchronous HfApi calls for async context."""
+ return await asyncio.to_thread(func, *args, **kwargs)
+
+
+ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
+ """Build the Hub URL for a repository."""
+ if repo_type == "model":
+ return f"https://huggingface.co/{repo_id}"
+ return f"https://huggingface.co/{repo_type}s/{repo_id}"
+
+
+ class HfRepoGitTool:
+ """Tool for git-like operations on HF repos."""
+
+ def __init__(self, hf_token: Optional[str] = None):
+ self.api = HfApi(token=hf_token)
+
+ async def execute(self, args: Dict[str, Any]) -> ToolResult:
+ """Execute the specified operation."""
+ operation = args.get("operation")
+
+ if not operation:
+ return self._help()
+
+ try:
+ handlers = {
+ "create_branch": self._create_branch,
+ "delete_branch": self._delete_branch,
+ "create_tag": self._create_tag,
+ "delete_tag": self._delete_tag,
+ "list_refs": self._list_refs,
+ "create_pr": self._create_pr,
+ "list_prs": self._list_prs,
+ "get_pr": self._get_pr,
+ "merge_pr": self._merge_pr,
+ "close_pr": self._close_pr,
+ "comment_pr": self._comment_pr,
+ "change_pr_status": self._change_pr_status,
+ "create_repo": self._create_repo,
+ "update_repo": self._update_repo,
+ }
+
+ handler = handlers.get(operation)
+ if handler:
+ return await handler(args)
+ else:
+ ops = ", ".join(handlers.keys())
+ return self._error(f"Unknown operation: {operation}. Valid: {ops}")
+
+ except RepositoryNotFoundError:
+ return self._error(f"Repository not found: {args.get('repo_id')}")
+ except Exception as e:
+ return self._error(f"Error: {str(e)}")
+
+ def _help(self) -> ToolResult:
+ """Show usage instructions."""
+ return {
+ "formatted": """**hf_repo_git** - Git-like operations on HF repos
+
+ **Branch/Tag:**
+ - `create_branch`: `{"operation": "create_branch", "repo_id": "...", "branch": "dev"}`
+ - `delete_branch`: `{"operation": "delete_branch", "repo_id": "...", "branch": "dev"}`
+ - `create_tag`: `{"operation": "create_tag", "repo_id": "...", "tag": "v1.0"}`
+ - `delete_tag`: `{"operation": "delete_tag", "repo_id": "...", "tag": "v1.0"}`
+ - `list_refs`: `{"operation": "list_refs", "repo_id": "..."}`
+
+ **PRs:**
+ - `create_pr`: `{"operation": "create_pr", "repo_id": "...", "title": "..."}` (creates draft PR)
+ - `list_prs`: `{"operation": "list_prs", "repo_id": "..."}` (shows status: draft/open/merged/closed)
+ - `get_pr`: `{"operation": "get_pr", "repo_id": "...", "pr_num": 1}` (shows status)
+ - `change_pr_status`: `{"operation": "change_pr_status", "repo_id": "...", "pr_num": 1, "new_status": "open"}` (change draft to open)
+ - `merge_pr`: `{"operation": "merge_pr", "repo_id": "...", "pr_num": 1}`
+ - `close_pr`: `{"operation": "close_pr", "repo_id": "...", "pr_num": 1}`
+ - `comment_pr`: `{"operation": "comment_pr", "repo_id": "...", "pr_num": 1, "comment": "..."}`
+
+ **Repo:**
+ - `create_repo`: `{"operation": "create_repo", "repo_id": "my-model", "private": true}`
+ - `update_repo`: `{"operation": "update_repo", "repo_id": "...", "private": false}`""",
+ "totalResults": 1,
+ "resultsShared": 1,
+ }
+
+ # =========================================================================
+ # BRANCH OPERATIONS
+ # =========================================================================
+
+ async def _create_branch(self, args: Dict[str, Any]) -> ToolResult:
+ """Create a new branch."""
+ repo_id = args.get("repo_id")
+ branch = args.get("branch")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+ if not branch:
+ return self._error("branch is required")
+
+ repo_type = args.get("repo_type", "model")
+ from_rev = args.get("from_rev", "main")
+
+ await _async_call(
+ self.api.create_branch,
+ repo_id=repo_id,
+ branch=branch,
+ revision=from_rev,
+ repo_type=repo_type,
+ exist_ok=args.get("exist_ok", False),
+ )
+
+ url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}"
+ return {"formatted": f"**Branch created:** {branch}\n{url}", "totalResults": 1, "resultsShared": 1}
+
+ async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult:
+ """Delete a branch."""
+ repo_id = args.get("repo_id")
+ branch = args.get("branch")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+ if not branch:
+ return self._error("branch is required")
+
+ repo_type = args.get("repo_type", "model")
+
+ await _async_call(
+ self.api.delete_branch,
+ repo_id=repo_id,
+ branch=branch,
+ repo_type=repo_type,
+ )
+
+ return {"formatted": f"**Branch deleted:** {branch}", "totalResults": 1, "resultsShared": 1}
+
+ # =========================================================================
+ # TAG OPERATIONS
+ # =========================================================================
+
+ async def _create_tag(self, args: Dict[str, Any]) -> ToolResult:
+ """Create a tag."""
+ repo_id = args.get("repo_id")
+ tag = args.get("tag")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+ if not tag:
+ return self._error("tag is required")
+
+ repo_type = args.get("repo_type", "model")
+ revision = args.get("revision", "main")
+ tag_message = args.get("tag_message", "")
+
+ await _async_call(
+ self.api.create_tag,
+ repo_id=repo_id,
+ tag=tag,
+ revision=revision,
+ tag_message=tag_message,
+ repo_type=repo_type,
+ exist_ok=args.get("exist_ok", False),
+ )
+
+ url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}"
+ return {"formatted": f"**Tag created:** {tag}\n{url}", "totalResults": 1, "resultsShared": 1}
+
+ async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult:
+ """Delete a tag."""
+ repo_id = args.get("repo_id")
+ tag = args.get("tag")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+ if not tag:
+ return self._error("tag is required")
+
+ repo_type = args.get("repo_type", "model")
+
+ await _async_call(
+ self.api.delete_tag,
+ repo_id=repo_id,
+ tag=tag,
+ repo_type=repo_type,
+ )
+
+ return {"formatted": f"**Tag deleted:** {tag}", "totalResults": 1, "resultsShared": 1}
+
+ # =========================================================================
+ # LIST REFS
+ # =========================================================================
+
+ async def _list_refs(self, args: Dict[str, Any]) -> ToolResult:
+ """List branches and tags."""
+ repo_id = args.get("repo_id")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+
+ repo_type = args.get("repo_type", "model")
+
+ refs = await _async_call(
+ self.api.list_repo_refs,
+ repo_id=repo_id,
+ repo_type=repo_type,
+ )
+
+ branches = [b.name for b in refs.branches] if refs.branches else []
+ tags = [t.name for t in refs.tags] if hasattr(refs, 'tags') and refs.tags else []
+
+ url = _build_repo_url(repo_id, repo_type)
+ lines = [f"**{repo_id}**", url, ""]
+
+ if branches:
+ lines.append(f"**Branches ({len(branches)}):** " + ", ".join(branches))
+ else:
+ lines.append("**Branches:** none")
+
+ if tags:
+ lines.append(f"**Tags ({len(tags)}):** " + ", ".join(tags))
+ else:
+ lines.append("**Tags:** none")
+
+ return {"formatted": "\n".join(lines), "totalResults": len(branches) + len(tags), "resultsShared": len(branches) + len(tags)}
+
+ # =========================================================================
+ # PR OPERATIONS
+ # =========================================================================
+
+ async def _create_pr(self, args: Dict[str, Any]) -> ToolResult:
+ """Create a pull request."""
+ repo_id = args.get("repo_id")
+ title = args.get("title")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+ if not title:
+ return self._error("title is required")
+
+ repo_type = args.get("repo_type", "model")
+ description = args.get("description", "")
+
+ result = await _async_call(
+ self.api.create_pull_request,
+ repo_id=repo_id,
+ title=title,
+ description=description,
+ repo_type=repo_type,
+ )
+
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}"
+ return {
+ "formatted": f"**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision=\"refs/pr/{result.num}\"",
+ "totalResults": 1,
+ "resultsShared": 1,
+ }
+
+ async def _list_prs(self, args: Dict[str, Any]) -> ToolResult:
+ """List PRs and discussions."""
+ repo_id = args.get("repo_id")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+
+ repo_type = args.get("repo_type", "model")
+ status = args.get("status", "all")  # open, closed, all
+
+ discussions = list(self.api.get_repo_discussions(
+ repo_id=repo_id,
+ repo_type=repo_type,
+ discussion_status=status if status != "all" else None,
+ ))
+
+ if not discussions:
+ return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0}
+
+ url = _build_repo_url(repo_id, repo_type)
+ lines = [f"**{repo_id}** - {len(discussions)} discussions", f"{url}/discussions", ""]
+
+ for d in discussions[:20]:
+ if d.status == "draft":
+ status_label = "[DRAFT]"
+ elif d.status == "open":
+ status_label = "[OPEN]"
+ elif d.status == "merged":
+ status_label = "[MERGED]"
+ else:
+ status_label = "[CLOSED]"
+ type_label = "PR" if d.is_pull_request else "D"
+ lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}")
+
+ return {"formatted": "\n".join(lines), "totalResults": len(discussions), "resultsShared": min(20, len(discussions))}
+
+ async def _get_pr(self, args: Dict[str, Any]) -> ToolResult:
+ """Get PR details."""
+ repo_id = args.get("repo_id")
+ pr_num = args.get("pr_num")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+ if not pr_num:
+ return self._error("pr_num is required")
+
+ repo_type = args.get("repo_type", "model")
+
+ pr = await _async_call(
+ self.api.get_discussion_details,
+ repo_id=repo_id,
+ discussion_num=int(pr_num),
+ repo_type=repo_type,
+ )
+
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
+ status_map = {
+ "draft": "Draft",
+ "open": "Open",
+ "merged": "Merged",
+ "closed": "Closed"
+ }
+ status = status_map.get(pr.status, pr.status.capitalize())
+ type_label = "Pull Request" if pr.is_pull_request else "Discussion"
+
+ lines = [
+ f"**{type_label} #{pr_num}:** {pr.title}",
+ f"**Status:** {status}",
+ f"**Author:** {pr.author}",
+ url,
+ ]
+
+ if pr.is_pull_request:
+ if pr.status == "draft":
+ lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"")
+ elif pr.status == "open":
+ lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"")
+
+ return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1}
+
+ async def _merge_pr(self, args: Dict[str, Any]) -> ToolResult:
+ """Merge a pull request."""
+ repo_id = args.get("repo_id")
+ pr_num = args.get("pr_num")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+ if not pr_num:
+ return self._error("pr_num is required")
+
+ repo_type = args.get("repo_type", "model")
+ comment = args.get("comment", "")
+
+ await _async_call(
+ self.api.merge_pull_request,
+ repo_id=repo_id,
+ discussion_num=int(pr_num),
+ comment=comment,
+ repo_type=repo_type,
+ )
+
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
+ return {"formatted": f"**PR #{pr_num} merged**\n{url}", "totalResults": 1, "resultsShared": 1}
+
+ async def _close_pr(self, args: Dict[str, Any]) -> ToolResult:
+ """Close a PR/discussion."""
+ repo_id = args.get("repo_id")
+ pr_num = args.get("pr_num")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+ if not pr_num:
+ return self._error("pr_num is required")
+
+ repo_type = args.get("repo_type", "model")
+ comment = args.get("comment", "")
+
+ await _async_call(
+ self.api.change_discussion_status,
+ repo_id=repo_id,
+ discussion_num=int(pr_num),
+ new_status="closed",
+ comment=comment,
+ repo_type=repo_type,
+ )
+
+ return {"formatted": f"**Discussion #{pr_num} closed**", "totalResults": 1, "resultsShared": 1}
+
+ async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult:
+ """Add a comment to a PR/discussion."""
+ repo_id = args.get("repo_id")
+ pr_num = args.get("pr_num")
+ comment = args.get("comment")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+ if not pr_num:
+ return self._error("pr_num is required")
+ if not comment:
+ return self._error("comment is required")
+
+ repo_type = args.get("repo_type", "model")
+
+ await _async_call(
+ self.api.comment_discussion,
+ repo_id=repo_id,
+ discussion_num=int(pr_num),
+ comment=comment,
+ repo_type=repo_type,
+ )
+
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
+ return {"formatted": f"**Comment added to #{pr_num}**\n{url}", "totalResults": 1, "resultsShared": 1}
+
+ async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult:
+ """Change PR/discussion status (mainly to convert draft to open)."""
+ repo_id = args.get("repo_id")
+ pr_num = args.get("pr_num")
+ new_status = args.get("new_status")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+ if not pr_num:
+ return self._error("pr_num is required")
+ if not new_status:
+ return self._error("new_status is required (open or closed)")
+
+ repo_type = args.get("repo_type", "model")
+ comment = args.get("comment", "")
+
+ await _async_call(
+ self.api.change_discussion_status,
+ repo_id=repo_id,
+ discussion_num=int(pr_num),
+ new_status=new_status,
+ comment=comment,
+ repo_type=repo_type,
+ )
+
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
+ return {"formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", "totalResults": 1, "resultsShared": 1}
+
+ # =========================================================================
+ # REPO MANAGEMENT
+ # =========================================================================
+
+ async def _create_repo(self, args: Dict[str, Any]) -> ToolResult:
+ """Create a new repository."""
+ repo_id = args.get("repo_id")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+
+ repo_type = args.get("repo_type", "model")
+ private = args.get("private", True)
+ space_sdk = args.get("space_sdk")
+
+ if repo_type == "space" and not space_sdk:
+ return self._error("space_sdk required for spaces (gradio/streamlit/docker/static)")
+
+ kwargs = {
+ "repo_id": repo_id,
+ "repo_type": repo_type,
+ "private": private,
+ "exist_ok": args.get("exist_ok", False),
+ }
+ if space_sdk:
+ kwargs["space_sdk"] = space_sdk
+
+ result = await _async_call(self.api.create_repo, **kwargs)
+
+ return {
+ "formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}",
+ "totalResults": 1,
+ "resultsShared": 1,
+ }
+
+ async def _update_repo(self, args: Dict[str, Any]) -> ToolResult:
+ """Update repository settings."""
+ repo_id = args.get("repo_id")
+
+ if not repo_id:
+ return self._error("repo_id is required")
+
+ repo_type = args.get("repo_type", "model")
+ private = args.get("private")
+ gated = args.get("gated")
+
+ if private is None and gated is None:
+ return self._error("Specify private (bool) or gated ('auto'/'manual'/false)")
+
+ kwargs = {"repo_id": repo_id, "repo_type": repo_type}
+ if private is not None:
+ kwargs["private"] = private
+ if gated is not None:
+ kwargs["gated"] = gated
+
+ await _async_call(self.api.update_repo_settings, **kwargs)
+
+ changes = []
+ if private is not None:
+ changes.append(f"private={private}")
+ if gated is not None:
+ changes.append(f"gated={gated}")
+
+ url = f"{_build_repo_url(repo_id, repo_type)}/settings"
+ return {"formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", "totalResults": 1, "resultsShared": 1}
+
+ def _error(self, message: str) -> ToolResult:
+ """Return an error result."""
+ return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
+
+
+ # Tool specification
+ HF_REPO_GIT_TOOL_SPEC = {
+ "name": "hf_repo_git",
+ "description": (
+ "Git-like operations on HF repos: branches, tags, PRs, and repo management.\n\n"
+ "## Operations\n"
+ "**Branches:** create_branch, delete_branch, list_refs\n"
+ "**Tags:** create_tag, delete_tag\n"
+ "**PRs:** create_pr, list_prs, get_pr, merge_pr, close_pr, comment_pr, change_pr_status\n"
+ "**Repo:** create_repo, update_repo\n\n"
+ "## Use when\n"
+ "- Creating feature branches for experiments\n"
+ "- Tagging model versions (v1.0, v2.0)\n"
+ "- Opening PRs to contribute to repos you don't own\n"
+ "- Reviewing and merging PRs on your repos\n"
+ "- Creating new model/dataset/space repos\n"
+ "- Changing repo visibility (public/private) or gated access\n\n"
+ "## Examples\n"
+ '{"operation": "list_refs", "repo_id": "my-model"}\n'
+ '{"operation": "create_branch", "repo_id": "my-model", "branch": "experiment-v2"}\n'
+ '{"operation": "create_tag", "repo_id": "my-model", "tag": "v1.0", "revision": "main"}\n'
+ '{"operation": "create_pr", "repo_id": "org/model", "title": "Fix tokenizer config"}\n'
+ '{"operation": "change_pr_status", "repo_id": "my-model", "pr_num": 1, "new_status": "open"}\n'
+ '{"operation": "merge_pr", "repo_id": "my-model", "pr_num": 3}\n'
+ '{"operation": "create_repo", "repo_id": "my-new-model", "private": true}\n'
+ '{"operation": "update_repo", "repo_id": "my-model", "gated": "auto"}\n\n'
+ "## PR Workflow\n"
+ "1. create_pr → creates draft PR (empty by default)\n"
+ "2. Upload files with revision='refs/pr/N' to add commits\n"
+ "3. change_pr_status with new_status='open' to publish (convert draft to open)\n"
+ "4. merge_pr when ready\n\n"
+ "## Notes\n"
+ "- PR status: draft (default), open, merged, closed\n"
+ "- delete_branch, delete_tag, merge_pr, create_repo, update_repo require approval\n"
+ "- For spaces, create_repo needs space_sdk (gradio/streamlit/docker/static)\n"
+ "- gated options: 'auto' (instant), 'manual' (review), false (open)\n"
+ ),
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "operation": {
+ "type": "string",
+ "enum": [
+ "create_branch", "delete_branch",
+ "create_tag", "delete_tag", "list_refs",
+ "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status",
+ "create_repo", "update_repo",
+ ],
+ "description": "Operation to execute",
+ },
+ "repo_id": {
+ "type": "string",
+ "description": "Repository ID (e.g., 'username/repo-name')",
+ },
+ "repo_type": {
+ "type": "string",
+ "enum": ["model", "dataset", "space"],
+ "description": "Repository type (default: model)",
+ },
+ "branch": {
+ "type": "string",
+ "description": "Branch name (create_branch, delete_branch)",
+ },
+ "from_rev": {
+ "type": "string",
+ "description": "Create branch from this revision (default: main)",
+ },
+ "tag": {
+ "type": "string",
+ "description": "Tag name (create_tag, delete_tag)",
+ },
+ "revision": {
+ "type": "string",
+ "description": "Revision for tag (default: main)",
+ },
+ "tag_message": {
+ "type": "string",
+ "description": "Tag description",
+ },
+ "title": {
+ "type": "string",
+ "description": "PR title (create_pr)",
+ },
+ "description": {
+ "type": "string",
+ "description": "PR description (create_pr)",
+ },
+ "pr_num": {
+ "type": "integer",
+ "description": "PR/discussion number",
+ },
+ "comment": {
+ "type": "string",
+ "description": "Comment text",
+ },
+ "status": {
+ "type": "string",
+ "enum": ["open", "closed", "all"],
+ "description": "Filter PRs by status (list_prs)",
+ },
+ "new_status": {
+ "type": "string",
+ "enum": ["open", "closed"],
+ "description": "New status for PR/discussion (change_pr_status)",
+ },
+ "private": {
+ "type": "boolean",
+ "description": "Make repo private (create_repo, update_repo)",
+ },
+ "gated": {
+ "type": "string",
+ "enum": ["auto", "manual", "false"],
+ "description": "Gated access setting (update_repo)",
+ },
+ "space_sdk": {
+ "type": "string",
+ "enum": ["gradio", "streamlit", "docker", "static"],
+ "description": "Space SDK (required for create_repo with space)",
+ },
+ },
+ "required": ["operation"],
+ },
+ }
+
+
+ async def hf_repo_git_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
+ """Handler for agent tool router."""
+ try:
+ tool = HfRepoGitTool()
+ result = await tool.execute(arguments)
+ return result["formatted"], not result.get("isError", False)
+ except Exception as e:
+ return f"Error: {str(e)}", False
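
The `_build_repo_url` helper in the file above encodes the Hub's routing convention: models are served at the bare repo path, while other repo types get a pluralized type prefix (`datasets/`, `spaces/`). A quick standalone check of that convention (the function is restated here for illustration; it mirrors the diff's helper):

```python
def build_repo_url(repo_id: str, repo_type: str = "model") -> str:
    # Models live at the root; other repo types are prefixed with the
    # pluralized type segment, e.g. datasets/<id>, spaces/<id>.
    if repo_type == "model":
        return f"https://huggingface.co/{repo_id}"
    return f"https://huggingface.co/{repo_type}s/{repo_id}"


print(build_repo_url("gpt2"))              # → https://huggingface.co/gpt2
print(build_repo_url("squad", "dataset"))  # → https://huggingface.co/datasets/squad
print(build_repo_url("user/demo", "space"))  # → https://huggingface.co/spaces/user/demo
```

Branch, tag, and discussion URLs in the tool are then built by appending `/tree/<ref>` or `/discussions/<num>` to this base.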
agent/tools/utils_tools.py DELETED
@@ -1,203 +0,0 @@
- """
- Utils Tools - General utility operations
-
- Provides system information like current date/time with timezone support.
- """
-
- import zoneinfo
- from datetime import datetime
- from typing import Any, Dict, Literal
-
- from agent.tools.types import ToolResult
-
- # Operation names
- OperationType = Literal["get_datetime"]
-
-
- class UtilsTool:
- """Tool for general utility operations."""
-
- async def execute(self, params: Dict[str, Any]) -> ToolResult:
- """Execute the specified utility operation."""
- operation = params.get("operation")
- args = params.get("args", {})
-
- # If no operation provided, return usage instructions
- if not operation:
- return self._show_help()
-
- # Normalize operation name
- operation = operation.lower()
-
- # Check if help is requested
- if args.get("help"):
- return self._show_operation_help(operation)
-
- try:
- # Route to appropriate handler
- if operation == "get_datetime":
- return await self._get_datetime(args)
- else:
- return {
- "formatted": f'Unknown operation: "{operation}"\n\n'
- "Available operations: get_datetime\n\n"
- "Call this tool with no operation for full usage instructions.",
- "totalResults": 0,
- "resultsShared": 0,
- "isError": True,
- }
-
- except Exception as e:
- return {
- "formatted": f"Error executing {operation}: {str(e)}",
- "totalResults": 0,
- "resultsShared": 0,
- "isError": True,
- }
-
- def _show_help(self) -> ToolResult:
- """Show usage instructions when tool is called with no arguments."""
- usage_text = """# Utils Tool
-
- Utility operations for system information.
-
- ## Available Commands
-
- - **get_datetime** - Get current date and time with timezone support
-
- ## Examples
-
- ### Get current date and time (Paris timezone by default)
- Call this tool with:
- ```json
- {
- "operation": "get_datetime",
- "args": {}
- }
- ```
-
- ### Get current date and time in a specific timezone
- Call this tool with:
- ```json
- {
- "operation": "get_datetime",
- "args": {
- "timezone": "America/New_York"
- }
- }
- ```
-
- Common timezones: Europe/Paris, America/New_York, America/Los_Angeles, Asia/Tokyo, UTC
-
- ## Tips
-
- - **Default timezone**: Paris (Europe/Paris)
- - **Date format**: dd-mm-yyyy
- - **Time format**: HH:MM:SS.mmm (24-hour format with milliseconds)
- - **Timezone names**: Use IANA timezone database names (e.g., "Europe/Paris", "UTC")
- """
- return {"formatted": usage_text, "totalResults": 1, "resultsShared": 1}
-
- def _show_operation_help(self, operation: str) -> ToolResult:
- """Show help for a specific operation."""
- help_text = f"Help for operation: {operation}\n\nCall with appropriate arguments. Use the main help for examples."
- return {"formatted": help_text, "totalResults": 1, "resultsShared": 1}
-
- async def _get_datetime(self, args: Dict[str, Any]) -> ToolResult:
- """Get current date and time with timezone support."""
- timezone_name = args.get("timezone", "Europe/Paris")
-
- try:
- # Get timezone object
- tz = zoneinfo.ZoneInfo(timezone_name)
-
- # Get current datetime in specified timezone
- now = datetime.now(tz)
-
- # Format date as dd-mm-yyyy
- date_str = now.strftime("%d-%m-%Y")
-
- # Format time as HH:MM:SS.mmm
- time_str = now.strftime("%H:%M:%S.%f")[
- :-3
- ] # Remove last 3 digits to keep only milliseconds
-
- # Get timezone abbreviation/offset
- tz_offset = now.strftime("%z")
- tz_name = now.strftime("%Z")
-
- response = f"""✓ Current date and time
-
- **Date:** {date_str}
- **Time:** {time_str}
- **Timezone:** {timezone_name} ({tz_name}, UTC{tz_offset[:3]}:{tz_offset[3:]})
-
- **ISO Format:** {now.isoformat()}
- **Unix Timestamp:** {int(now.timestamp())}"""
-
- return {"formatted": response, "totalResults": 1, "resultsShared": 1}
-
- except zoneinfo.ZoneInfoNotFoundError:
- return {
- "formatted": f"Invalid timezone: {timezone_name}\n\n"
- "Use IANA timezone database names like:\n"
- "- Europe/Paris\n"
- "- America/New_York\n"
- "- Asia/Tokyo\n"
- "- UTC\n\n"
- "See: https://en.wikipedia.org/wiki/List_of_tz_database_time_zones",
- "totalResults": 0,
- "resultsShared": 0,
- "isError": True,
- }
- except Exception as e:
- return {
- "formatted": f"Failed to get date/time: {str(e)}",
- "totalResults": 0,
- "resultsShared": 0,
- "isError": True,
- }
-
-
- # Tool specification for agent registration
- UTILS_TOOL_SPEC = {
- "name": "utils",
- "description": (
- "System utility operations - currently provides date/time with timezone support. "
- "**Use when:** (1) Need current date for logging/timestamps, (2) User asks 'what time is it', "
- "(3) Need timezone-aware datetime for scheduling/coordination, (4) Creating timestamped filenames. "
169
- "**Operation:** get_datetime with optional timezone parameter (default: Europe/Paris). "
170
- "Returns: Date (dd-mm-yyyy), time (HH:MM:SS.mmm), timezone info, ISO format, Unix timestamp. "
171
- "**Pattern:** utils get_datetime → use timestamp in filename/log → upload to hf_private_repos. "
172
- "Supports IANA timezone names: 'Europe/Paris', 'America/New_York', 'Asia/Tokyo', 'UTC'."
173
- ),
174
- "parameters": {
175
- "type": "object",
176
- "properties": {
177
- "operation": {
178
- "type": "string",
179
- "enum": ["get_datetime"],
180
- "description": "Operation to execute. Valid values: [get_datetime]",
181
- },
182
- "args": {
183
- "type": "object",
184
- "description": (
185
- "Operation-specific arguments as a JSON object. "
186
- "For get_datetime: timezone (string, optional, default: Europe/Paris). "
187
- "Use IANA timezone names like 'America/New_York', 'Asia/Tokyo', 'UTC'."
188
- ),
189
- "additionalProperties": True,
190
- },
191
- },
192
- },
193
- }
194
-
195
-
196
- async def utils_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
197
- """Handler for agent tool router."""
198
- try:
199
- tool = UtilsTool()
200
- result = await tool.execute(arguments)
201
- return result["formatted"], not result.get("isError", False)
202
- except Exception as e:
203
- return f"Error executing Utils tool: {str(e)}", False
 
agent/utils/__init__.py CHANGED
@@ -1,7 +1,3 @@
 """
 Utility functions and helpers
 """
-
-from agent.utils.logging import setup_logger
-
-__all__ = ["setup_logger"]

agent/utils/logging.py DELETED
@@ -1,40 +0,0 @@
-"""
-Logging utilities
-"""
-
-import logging
-import sys
-from pathlib import Path
-from typing import Optional
-
-
-def setup_logger(
-    name: str = "hf_agent", level: int = logging.INFO, log_file: Optional[Path] = None
-) -> logging.Logger:
-    """Setup and configure logger"""
-
-    logger = logging.getLogger(name)
-    logger.setLevel(level)
-
-    # Remove existing handlers
-    logger.handlers = []
-
-    # Console handler
-    console_handler = logging.StreamHandler(sys.stdout)
-    console_handler.setLevel(level)
-    console_format = logging.Formatter(
-        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-        datefmt="%Y-%m-%d %H:%M:%S",
-    )
-    console_handler.setFormatter(console_format)
-    logger.addHandler(console_handler)
-
-    # File handler if log_file specified
-    if log_file:
-        log_file.parent.mkdir(parents=True, exist_ok=True)
-        file_handler = logging.FileHandler(log_file)
-        file_handler.setLevel(level)
-        file_handler.setFormatter(console_format)
-        logger.addHandler(file_handler)
-
-    return logger
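With `agent/utils/logging.py` removed, call sites that imported `setup_logger` need an alternative. One option, sketched here with stdlib `logging.basicConfig` and the same format string the deleted helper used (how the repo actually replaces it is not shown in this diff):

```python
import logging
import sys

# Root-logger configuration roughly equivalent to the removed console handler.
# Note: unlike the deleted helper, basicConfig is a no-op if the root logger
# already has handlers, so this belongs at process startup.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger("hf_agent")
log.info("logging configured")
```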
 
agent/utils/terminal_display.py CHANGED
@@ -94,13 +94,18 @@ def format_tool_call(tool_name: str, arguments: str) -> str:
 
 def format_tool_output(output: str, success: bool, truncate: bool = True) -> str:
     """Format tool output with color and optional truncation"""
+    original_length = len(output)
     if truncate:
         output = truncate_to_lines(output, max_lines=6)
 
     if success:
-        return f"{Colors.YELLOW}Tool output:{Colors.RESET}\n{output}"
+        return (
+            f"{Colors.YELLOW}Tool output ({original_length} tkns): {Colors.RESET}\n{output}"
+        )
     else:
-        return f"{Colors.RED}Tool output:{Colors.RESET}\n{output}"
+        return (
+            f"{Colors.RED}Tool output ({original_length} tokens): {Colors.RESET}\n{output}"
+        )
 
 
 def format_turn_complete() -> str:
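One caveat on the change above: `len(output)` counts characters, not tokens, so the new "(N tkns)" / "(N tokens)" labels (themselves inconsistent between the two branches) really report character counts. If an approximate token count is wanted without a tokenizer dependency, a common heuristic for English text is roughly 4 characters per token; a hypothetical helper, not part of this commit:

```python
def approx_tokens(text: str, chars_per_token: float = 4.0) -> int:
    """Rough token estimate from character count; a heuristic, not a tokenizer."""
    if not text:
        return 0
    # Never report 0 tokens for non-empty text
    return max(1, round(len(text) / chars_per_token))


print(approx_tokens("x" * 400))  # → 100
```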
pyproject.toml CHANGED
@@ -24,6 +24,7 @@ agent = [
     "nbconvert>=7.16.6",
     "nbformat>=5.10.4",
     "datasets>=4.3.0", # For session logging to HF datasets
+    "whoosh>=2.7.4",
 ]
 
 # Evaluation/benchmarking dependencies
test_dataset_tools.py DELETED
@@ -1,79 +0,0 @@
-"""
-Test script for unified dataset inspection tool
-"""
-
-import asyncio
-import sys
-from typing import TypedDict
-from unittest.mock import MagicMock
-
-
-# Mock the types module before importing dataset_tools
-class ToolResult(TypedDict, total=False):
-    formatted: str
-    totalResults: int
-    resultsShared: int
-    isError: bool
-
-
-mock_types = MagicMock()
-mock_types.ToolResult = ToolResult
-sys.modules["agent.tools.types"] = mock_types
-
-# Now import directly from the file
-sys.path.insert(0, "/Users/akseljoonas/Documents/hf-agent/agent/tools")
-from dataset_tools import hf_inspect_dataset_handler, inspect_dataset
-
-
-async def test_inspect_dataset():
-    """Test the unified inspect_dataset function"""
-    print("=" * 70)
-    print("Testing inspect_dataset()")
-    print("=" * 70)
-
-    # Test with akseljoonas/hf-agent-sessions as specified
-    print("\n→ inspect_dataset('akseljoonas/hf-agent-sessions'):")
-    result = await inspect_dataset("akseljoonas/hf-agent-sessions")
-    print(f" isError: {result['isError']}")
-    print(f" Output:\n{result['formatted']}")
-
-    print("\n" + "=" * 70)
-
-    # # Test with stanfordnlp/imdb
-    # print("\n→ inspect_dataset('stanfordnlp/imdb'):")
-    # result = await inspect_dataset("stanfordnlp/imdb")
-    # print(f" isError: {result['isError']}")
-    # print(f" Output:\n{result['formatted']}")
-
-    # print("\n" + "=" * 70)
-
-    # # Test with multi-config dataset
-    # print("\n→ inspect_dataset('nyu-mll/glue', config='mrpc'):")
-    # result = await inspect_dataset("nyu-mll/glue", config="mrpc")
-    # print(f" isError: {result['isError']}")
-    # print(f" Output:\n{result['formatted']}")
-
-
-async def test_handler():
-    """Test the handler (what the agent calls)"""
-    print("\n" + "=" * 70)
-    print("Testing hf_inspect_dataset_handler()")
-    print("=" * 70)
-
-    result, success = await hf_inspect_dataset_handler(
-        {
-            "dataset": "stanfordnlp/imdb",
-            "sample_rows": 2,
-        }
-    )
-    print("\n→ Handler result:")
-    print(f" success: {success}")
-    print(f" output:\n{result}")
-
-
-if __name__ == "__main__":
-    print("\nUnified Dataset Inspection Tool Test\n")
-    asyncio.run(test_inspect_dataset())
-    # asyncio.run(test_handler())
-    print("\n" + "=" * 70)
-    print("Done!")