Henri Bonamy commited on
Commit
fbad7f5
·
2 Parent(s): 146ac6f8406ff4

Merge pull request #19 from huggingface/feat/hf-repo-tools

Browse files
.gitignore CHANGED
@@ -16,4 +16,5 @@ wheels/
16
  /logs
17
  hf-agent-leaderboard/
18
  .cursor/
19
- session_logs/skills/
 
 
16
  /logs
17
  hf-agent-leaderboard/
18
  .cursor/
19
+ session_logs/
20
+ skills/
agent/core/agent_loop.py CHANGED
@@ -76,7 +76,19 @@ def _needs_approval(tool_name: str, tool_args: dict, config: Config | None = Non
76
  # Other operations (create_repo, etc.) always require approval
77
  if operation in ["create_repo"]:
78
  return True
79
-
 
 
 
 
 
 
 
 
 
 
 
 
80
  return False
81
 
82
 
 
76
  # Other operations (create_repo, etc.) always require approval
77
  if operation in ["create_repo"]:
78
  return True
79
+
80
+ # hf_repo_files: upload (can overwrite) and delete require approval
81
+ if tool_name == "hf_repo_files":
82
+ operation = tool_args.get("operation", "")
83
+ if operation in ["upload", "delete"]:
84
+ return True
85
+
86
+ # hf_repo_git: destructive operations require approval
87
+ if tool_name == "hf_repo_git":
88
+ operation = tool_args.get("operation", "")
89
+ if operation in ["delete_branch", "delete_tag", "merge_pr", "create_repo", "update_repo"]:
90
+ return True
91
+
92
  return False
93
 
94
 
agent/core/tools.py CHANGED
@@ -37,13 +37,20 @@ from agent.tools.github_read_file import (
37
  )
38
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
39
  from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
40
- from agent.tools.private_hf_repo_tools import (
41
- PRIVATE_HF_REPO_TOOL_SPEC,
42
- private_hf_repo_handler,
 
 
 
 
43
  )
44
 
45
- # NOTE: Utils tool disabled - date/time now loaded into system prompt at initialization
46
- # from agent.tools.utils_tools import UTILS_TOOL_SPEC, utils_handler
 
 
 
47
 
48
  # Suppress aiohttp deprecation warning
49
  warnings.filterwarnings(
@@ -281,20 +288,21 @@ def create_builtin_tools() -> list[ToolSpec]:
281
  parameters=HF_JOBS_TOOL_SPEC["parameters"],
282
  handler=hf_jobs_handler,
283
  ),
 
284
  ToolSpec(
285
- name=PRIVATE_HF_REPO_TOOL_SPEC["name"],
286
- description=PRIVATE_HF_REPO_TOOL_SPEC["description"],
287
- parameters=PRIVATE_HF_REPO_TOOL_SPEC["parameters"],
288
- handler=private_hf_repo_handler,
289
  ),
290
- # NOTE: Utils tool disabled - date/time now loaded into system prompt at initialization (less tool calls=more reliablity)
291
- # ToolSpec(
292
- # name=UTILS_TOOL_SPEC["name"],
293
- # description=UTILS_TOOL_SPEC["description"],
294
- # parameters=UTILS_TOOL_SPEC["parameters"],
295
- # handler=utils_handler,
296
- # ),
297
- # GitHub tools
298
  # NOTE: Github search code tool disabled - a bit buggy
299
  # ToolSpec(
300
  # name=GITHUB_SEARCH_CODE_TOOL_SPEC["name"],
 
37
  )
38
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
39
  from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
40
+ from agent.tools.hf_repo_files_tool import (
41
+ HF_REPO_FILES_TOOL_SPEC,
42
+ hf_repo_files_handler,
43
+ )
44
+ from agent.tools.hf_repo_git_tool import (
45
+ HF_REPO_GIT_TOOL_SPEC,
46
+ hf_repo_git_handler,
47
  )
48
 
49
+ # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git
50
+ # from agent.tools.private_hf_repo_tools import (
51
+ # PRIVATE_HF_REPO_TOOL_SPEC,
52
+ # private_hf_repo_handler,
53
+ # )
54
 
55
  # Suppress aiohttp deprecation warning
56
  warnings.filterwarnings(
 
288
  parameters=HF_JOBS_TOOL_SPEC["parameters"],
289
  handler=hf_jobs_handler,
290
  ),
291
+ # HF Repo management tools
292
  ToolSpec(
293
+ name=HF_REPO_FILES_TOOL_SPEC["name"],
294
+ description=HF_REPO_FILES_TOOL_SPEC["description"],
295
+ parameters=HF_REPO_FILES_TOOL_SPEC["parameters"],
296
+ handler=hf_repo_files_handler,
297
  ),
298
+ ToolSpec(
299
+ name=HF_REPO_GIT_TOOL_SPEC["name"],
300
+ description=HF_REPO_GIT_TOOL_SPEC["description"],
301
+ parameters=HF_REPO_GIT_TOOL_SPEC["parameters"],
302
+ handler=hf_repo_git_handler,
303
+ ),
304
+
305
+
306
  # NOTE: Github search code tool disabled - a bit buggy
307
  # ToolSpec(
308
  # name=GITHUB_SEARCH_CODE_TOOL_SPEC["name"],
agent/main.py CHANGED
@@ -287,6 +287,95 @@ async def event_listener(
287
  if len(all_lines) > 5:
288
  print("...")
289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  # Get user decision for this item
291
  response = await prompt_session.prompt_async(
292
  f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): "
 
287
  if len(all_lines) > 5:
288
  print("...")
289
 
290
+ elif tool_name == "hf_repo_files":
291
+ # Handle repo files operations (upload, delete)
292
+ repo_id = arguments.get("repo_id", "")
293
+ repo_type = arguments.get("repo_type", "model")
294
+ revision = arguments.get("revision", "main")
295
+
296
+ # Build repo URL
297
+ if repo_type == "model":
298
+ repo_url = f"https://huggingface.co/{repo_id}"
299
+ else:
300
+ repo_url = f"https://huggingface.co/{repo_type}s/{repo_id}"
301
+
302
+ print(f"Repository: {repo_id}")
303
+ print(f"Type: {repo_type}")
304
+ print(f"Branch: {revision}")
305
+ print(f"URL: {repo_url}")
306
+
307
+ if operation == "upload":
308
+ path = arguments.get("path", "")
309
+ content = arguments.get("content", "")
310
+ create_pr = arguments.get("create_pr", False)
311
+
312
+ print(f"File: {path}")
313
+ if create_pr:
314
+ print("Mode: Create PR")
315
+
316
+ if isinstance(content, str):
317
+ all_lines = content.split("\n")
318
+ line_count = len(all_lines)
319
+ size_bytes = len(content.encode("utf-8"))
320
+ size_kb = size_bytes / 1024
321
+
322
+ print(f"Lines: {line_count}")
323
+ if size_kb < 1024:
324
+ print(f"Size: {size_kb:.2f} KB")
325
+ else:
326
+ print(f"Size: {size_kb / 1024:.2f} MB")
327
+
328
+ # Show full content
329
+ print(f"Content:\n{content}")
330
+
331
+ elif operation == "delete":
332
+ patterns = arguments.get("patterns", [])
333
+ if isinstance(patterns, str):
334
+ patterns = [patterns]
335
+ print(f"Patterns to delete: {', '.join(patterns)}")
336
+
337
+ elif tool_name == "hf_repo_git":
338
+ # Handle git operations (branches, tags, PRs, repo management)
339
+ repo_id = arguments.get("repo_id", "")
340
+ repo_type = arguments.get("repo_type", "model")
341
+
342
+ # Build repo URL
343
+ if repo_type == "model":
344
+ repo_url = f"https://huggingface.co/{repo_id}"
345
+ else:
346
+ repo_url = f"https://huggingface.co/{repo_type}s/{repo_id}"
347
+
348
+ print(f"Repository: {repo_id}")
349
+ print(f"Type: {repo_type}")
350
+ print(f"URL: {repo_url}")
351
+
352
+ if operation == "delete_branch":
353
+ branch = arguments.get("branch", "")
354
+ print(f"Branch to delete: {branch}")
355
+
356
+ elif operation == "delete_tag":
357
+ tag = arguments.get("tag", "")
358
+ print(f"Tag to delete: {tag}")
359
+
360
+ elif operation == "merge_pr":
361
+ pr_num = arguments.get("pr_num", "")
362
+ print(f"PR to merge: #{pr_num}")
363
+
364
+ elif operation == "create_repo":
365
+ private = arguments.get("private", False)
366
+ space_sdk = arguments.get("space_sdk")
367
+ print(f"Private: {private}")
368
+ if space_sdk:
369
+ print(f"Space SDK: {space_sdk}")
370
+
371
+ elif operation == "update_repo":
372
+ private = arguments.get("private")
373
+ gated = arguments.get("gated")
374
+ if private is not None:
375
+ print(f"Private: {private}")
376
+ if gated is not None:
377
+ print(f"Gated: {gated}")
378
+
379
  # Get user decision for this item
380
  response = await prompt_session.prompt_async(
381
  f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): "
agent/tools/hf_repo_files_tool.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HF Repo Files Tool - File operations on Hugging Face repositories
3
+
4
+ Operations: list, read, upload, delete
5
+ """
6
+
7
+ import asyncio
8
+ from typing import Any, Dict, Literal, Optional
9
+
10
+ from huggingface_hub import HfApi, hf_hub_download
11
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
12
+
13
+ from agent.tools.types import ToolResult
14
+
15
+ OperationType = Literal["list", "read", "upload", "delete"]
16
+
17
+
18
+ async def _async_call(func, *args, **kwargs):
19
+ """Wrap synchronous HfApi calls for async context."""
20
+ return await asyncio.to_thread(func, *args, **kwargs)
21
+
22
+
23
+ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
24
+ """Build the Hub URL for a repository."""
25
+ if repo_type == "model":
26
+ return f"https://huggingface.co/{repo_id}"
27
+ return f"https://huggingface.co/{repo_type}s/{repo_id}"
28
+
29
+
30
+ def _format_size(size_bytes: int) -> str:
31
+ """Format file size in human-readable form."""
32
+ for unit in ["B", "KB", "MB", "GB", "TB"]:
33
+ if size_bytes < 1024:
34
+ return f"{size_bytes:.1f}{unit}"
35
+ size_bytes /= 1024
36
+ return f"{size_bytes:.1f}PB"
37
+
38
+
39
+ class HfRepoFilesTool:
40
+ """Tool for file operations on HF repos."""
41
+
42
+ def __init__(self, hf_token: Optional[str] = None):
43
+ self.api = HfApi(token=hf_token)
44
+
45
+ async def execute(self, args: Dict[str, Any]) -> ToolResult:
46
+ """Execute the specified operation."""
47
+ operation = args.get("operation")
48
+
49
+ if not operation:
50
+ return self._help()
51
+
52
+ try:
53
+ handlers = {
54
+ "list": self._list,
55
+ "read": self._read,
56
+ "upload": self._upload,
57
+ "delete": self._delete,
58
+ }
59
+
60
+ handler = handlers.get(operation)
61
+ if handler:
62
+ return await handler(args)
63
+ else:
64
+ return self._error(f"Unknown operation: {operation}. Valid: list, read, upload, delete")
65
+
66
+ except RepositoryNotFoundError:
67
+ return self._error(f"Repository not found: {args.get('repo_id')}")
68
+ except EntryNotFoundError:
69
+ return self._error(f"File not found: {args.get('path')}")
70
+ except Exception as e:
71
+ return self._error(f"Error: {str(e)}")
72
+
73
+ def _help(self) -> ToolResult:
74
+ """Show usage instructions."""
75
+ return {
76
+ "formatted": """**hf_repo_files** - File operations on HF repos
77
+
78
+ **Operations:**
79
+ - `list` - List files: `{"operation": "list", "repo_id": "gpt2"}`
80
+ - `read` - Read file: `{"operation": "read", "repo_id": "gpt2", "path": "config.json"}`
81
+ - `upload` - Upload: `{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "..."}`
82
+ - `delete` - Delete: `{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp"]}`
83
+
84
+ **Common params:** repo_id (required), repo_type (model/dataset/space), revision (default: main)""",
85
+ "totalResults": 1,
86
+ "resultsShared": 1,
87
+ }
88
+
89
+ async def _list(self, args: Dict[str, Any]) -> ToolResult:
90
+ """List files in a repository."""
91
+ repo_id = args.get("repo_id")
92
+ if not repo_id:
93
+ return self._error("repo_id is required")
94
+
95
+ repo_type = args.get("repo_type", "model")
96
+ revision = args.get("revision", "main")
97
+ path = args.get("path", "")
98
+
99
+ items = list(await _async_call(
100
+ self.api.list_repo_tree,
101
+ repo_id=repo_id,
102
+ repo_type=repo_type,
103
+ revision=revision,
104
+ path_in_repo=path,
105
+ recursive=True,
106
+ ))
107
+
108
+ if not items:
109
+ return {"formatted": f"No files in {repo_id}", "totalResults": 0, "resultsShared": 0}
110
+
111
+ lines = []
112
+ total_size = 0
113
+ for item in sorted(items, key=lambda x: x.path):
114
+ if hasattr(item, "size") and item.size:
115
+ total_size += item.size
116
+ lines.append(f"{item.path} ({_format_size(item.size)})")
117
+ else:
118
+ lines.append(f"{item.path}/")
119
+
120
+ url = _build_repo_url(repo_id, repo_type)
121
+ response = f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" + "\n".join(lines)
122
+
123
+ return {"formatted": response, "totalResults": len(items), "resultsShared": len(items)}
124
+
125
+ async def _read(self, args: Dict[str, Any]) -> ToolResult:
126
+ """Read file content from a repository."""
127
+ repo_id = args.get("repo_id")
128
+ path = args.get("path")
129
+
130
+ if not repo_id:
131
+ return self._error("repo_id is required")
132
+ if not path:
133
+ return self._error("path is required")
134
+
135
+ repo_type = args.get("repo_type", "model")
136
+ revision = args.get("revision", "main")
137
+ max_chars = args.get("max_chars", 50000)
138
+
139
+ file_path = await _async_call(
140
+ hf_hub_download,
141
+ repo_id=repo_id,
142
+ filename=path,
143
+ repo_type=repo_type,
144
+ revision=revision,
145
+ token=self.api.token,
146
+ )
147
+
148
+ try:
149
+ with open(file_path, "r", encoding="utf-8") as f:
150
+ content = f.read()
151
+
152
+ truncated = len(content) > max_chars
153
+ if truncated:
154
+ content = content[:max_chars]
155
+
156
+ url = f"{_build_repo_url(repo_id, repo_type)}/blob/{revision}/{path}"
157
+ response = f"**{path}**{' (truncated)' if truncated else ''}\n{url}\n\n```\n{content}\n```"
158
+
159
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
160
+
161
+ except UnicodeDecodeError:
162
+ import os
163
+ size = os.path.getsize(file_path)
164
+ return {"formatted": f"Binary file ({_format_size(size)})", "totalResults": 1, "resultsShared": 1}
165
+
166
+ async def _upload(self, args: Dict[str, Any]) -> ToolResult:
167
+ """Upload content to a repository."""
168
+ repo_id = args.get("repo_id")
169
+ path = args.get("path")
170
+ content = args.get("content")
171
+
172
+ if not repo_id:
173
+ return self._error("repo_id is required")
174
+ if not path:
175
+ return self._error("path is required")
176
+ if content is None:
177
+ return self._error("content is required")
178
+
179
+ repo_type = args.get("repo_type", "model")
180
+ revision = args.get("revision", "main")
181
+ create_pr = args.get("create_pr", False)
182
+ commit_message = args.get("commit_message", f"Upload {path}")
183
+
184
+ file_bytes = content.encode("utf-8") if isinstance(content, str) else content
185
+
186
+ result = await _async_call(
187
+ self.api.upload_file,
188
+ path_or_fileobj=file_bytes,
189
+ path_in_repo=path,
190
+ repo_id=repo_id,
191
+ repo_type=repo_type,
192
+ revision=revision,
193
+ commit_message=commit_message,
194
+ create_pr=create_pr,
195
+ )
196
+
197
+ url = _build_repo_url(repo_id, repo_type)
198
+ if create_pr and hasattr(result, "pr_url"):
199
+ response = f"**Uploaded as PR**\n{result.pr_url}"
200
+ else:
201
+ response = f"**Uploaded:** {path}\n{url}/blob/{revision}/{path}"
202
+
203
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
204
+
205
+ async def _delete(self, args: Dict[str, Any]) -> ToolResult:
206
+ """Delete files from a repository."""
207
+ repo_id = args.get("repo_id")
208
+ patterns = args.get("patterns")
209
+
210
+ if not repo_id:
211
+ return self._error("repo_id is required")
212
+ if not patterns:
213
+ return self._error("patterns is required (list of paths/wildcards)")
214
+
215
+ if isinstance(patterns, str):
216
+ patterns = [patterns]
217
+
218
+ repo_type = args.get("repo_type", "model")
219
+ revision = args.get("revision", "main")
220
+ create_pr = args.get("create_pr", False)
221
+ commit_message = args.get("commit_message", f"Delete {', '.join(patterns)}")
222
+
223
+ await _async_call(
224
+ self.api.delete_files,
225
+ repo_id=repo_id,
226
+ delete_patterns=patterns,
227
+ repo_type=repo_type,
228
+ revision=revision,
229
+ commit_message=commit_message,
230
+ create_pr=create_pr,
231
+ )
232
+
233
+ response = f"**Deleted:** {', '.join(patterns)} from {repo_id}"
234
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
235
+
236
+ def _error(self, message: str) -> ToolResult:
237
+ """Return an error result."""
238
+ return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
239
+
240
+
241
+ # Tool specification
242
+ HF_REPO_FILES_TOOL_SPEC = {
243
+ "name": "hf_repo_files",
244
+ "description": (
245
+ "Read and write files in HF repos (models/datasets/spaces).\n\n"
246
+ "## Operations\n"
247
+ "- **list**: List files with sizes and structure\n"
248
+ "- **read**: Read file content (text files only)\n"
249
+ "- **upload**: Upload content to repo (can create PR)\n"
250
+ "- **delete**: Delete files/folders (supports wildcards like *.tmp)\n\n"
251
+ "## Use when\n"
252
+ "- Need to see what files exist in a repo\n"
253
+ "- Want to read config.json, README.md, or other text files\n"
254
+ "- Uploading training scripts, configs, or results to a repo\n"
255
+ "- Cleaning up temporary files from a repo\n\n"
256
+ "## Examples\n"
257
+ '{"operation": "list", "repo_id": "meta-llama/Llama-2-7b"}\n'
258
+ '{"operation": "read", "repo_id": "gpt2", "path": "config.json"}\n'
259
+ '{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "# My Model"}\n'
260
+ '{"operation": "upload", "repo_id": "org/model", "path": "fix.py", "content": "...", "create_pr": true}\n'
261
+ '{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp", "logs/"]}\n\n'
262
+ "## Notes\n"
263
+ "- For binary files (safetensors, bin), use list to see them but can't read content\n"
264
+ "- upload/delete require approval (can overwrite/destroy data)\n"
265
+ "- Use create_pr=true to propose changes instead of direct commit\n"
266
+ ),
267
+ "parameters": {
268
+ "type": "object",
269
+ "properties": {
270
+ "operation": {
271
+ "type": "string",
272
+ "enum": ["list", "read", "upload", "delete"],
273
+ "description": "Operation: list, read, upload, delete",
274
+ },
275
+ "repo_id": {
276
+ "type": "string",
277
+ "description": "Repository ID (e.g., 'username/repo-name')",
278
+ },
279
+ "repo_type": {
280
+ "type": "string",
281
+ "enum": ["model", "dataset", "space"],
282
+ "description": "Repository type (default: model)",
283
+ },
284
+ "revision": {
285
+ "type": "string",
286
+ "description": "Branch/tag/commit (default: main)",
287
+ },
288
+ "path": {
289
+ "type": "string",
290
+ "description": "File path for read/upload",
291
+ },
292
+ "content": {
293
+ "type": "string",
294
+ "description": "File content for upload",
295
+ },
296
+ "patterns": {
297
+ "type": "array",
298
+ "items": {"type": "string"},
299
+ "description": "Patterns to delete (e.g., ['*.tmp', 'logs/'])",
300
+ },
301
+ "create_pr": {
302
+ "type": "boolean",
303
+ "description": "Create PR instead of direct commit",
304
+ },
305
+ "commit_message": {
306
+ "type": "string",
307
+ "description": "Custom commit message",
308
+ },
309
+ },
310
+ "required": ["operation"],
311
+ },
312
+ }
313
+
314
+
315
+ async def hf_repo_files_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
316
+ """Handler for agent tool router."""
317
+ try:
318
+ tool = HfRepoFilesTool()
319
+ result = await tool.execute(arguments)
320
+ return result["formatted"], not result.get("isError", False)
321
+ except Exception as e:
322
+ return f"Error: {str(e)}", False
agent/tools/hf_repo_git_tool.py ADDED
@@ -0,0 +1,663 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HF Repo Git Tool - Git-like operations on Hugging Face repositories
3
+
4
+ Operations: branches, tags, PRs, repo management
5
+ """
6
+
7
+ import asyncio
8
+ from typing import Any, Dict, Literal, Optional
9
+
10
+ from huggingface_hub import HfApi
11
+ from huggingface_hub.utils import RepositoryNotFoundError
12
+
13
+ from agent.tools.types import ToolResult
14
+
15
+ OperationType = Literal[
16
+ "create_branch", "delete_branch",
17
+ "create_tag", "delete_tag",
18
+ "list_refs",
19
+ "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status",
20
+ "create_repo", "update_repo",
21
+ ]
22
+
23
+
24
+ async def _async_call(func, *args, **kwargs):
25
+ """Wrap synchronous HfApi calls for async context."""
26
+ return await asyncio.to_thread(func, *args, **kwargs)
27
+
28
+
29
+ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
30
+ """Build the Hub URL for a repository."""
31
+ if repo_type == "model":
32
+ return f"https://huggingface.co/{repo_id}"
33
+ return f"https://huggingface.co/{repo_type}s/{repo_id}"
34
+
35
+
36
+ class HfRepoGitTool:
37
+ """Tool for git-like operations on HF repos."""
38
+
39
+ def __init__(self, hf_token: Optional[str] = None):
40
+ self.api = HfApi(token=hf_token)
41
+
42
+ async def execute(self, args: Dict[str, Any]) -> ToolResult:
43
+ """Execute the specified operation."""
44
+ operation = args.get("operation")
45
+
46
+ if not operation:
47
+ return self._help()
48
+
49
+ try:
50
+ handlers = {
51
+ "create_branch": self._create_branch,
52
+ "delete_branch": self._delete_branch,
53
+ "create_tag": self._create_tag,
54
+ "delete_tag": self._delete_tag,
55
+ "list_refs": self._list_refs,
56
+ "create_pr": self._create_pr,
57
+ "list_prs": self._list_prs,
58
+ "get_pr": self._get_pr,
59
+ "merge_pr": self._merge_pr,
60
+ "close_pr": self._close_pr,
61
+ "comment_pr": self._comment_pr,
62
+ "change_pr_status": self._change_pr_status,
63
+ "create_repo": self._create_repo,
64
+ "update_repo": self._update_repo,
65
+ }
66
+
67
+ handler = handlers.get(operation)
68
+ if handler:
69
+ return await handler(args)
70
+ else:
71
+ ops = ", ".join(handlers.keys())
72
+ return self._error(f"Unknown operation: {operation}. Valid: {ops}")
73
+
74
+ except RepositoryNotFoundError:
75
+ return self._error(f"Repository not found: {args.get('repo_id')}")
76
+ except Exception as e:
77
+ return self._error(f"Error: {str(e)}")
78
+
79
+ def _help(self) -> ToolResult:
80
+ """Show usage instructions."""
81
+ return {
82
+ "formatted": """**hf_repo_git** - Git-like operations on HF repos
83
+
84
+ **Branch/Tag:**
85
+ - `create_branch`: `{"operation": "create_branch", "repo_id": "...", "branch": "dev"}`
86
+ - `delete_branch`: `{"operation": "delete_branch", "repo_id": "...", "branch": "dev"}`
87
+ - `create_tag`: `{"operation": "create_tag", "repo_id": "...", "tag": "v1.0"}`
88
+ - `delete_tag`: `{"operation": "delete_tag", "repo_id": "...", "tag": "v1.0"}`
89
+ - `list_refs`: `{"operation": "list_refs", "repo_id": "..."}`
90
+
91
+ **PRs:**
92
+ - `create_pr`: `{"operation": "create_pr", "repo_id": "...", "title": "..."}` (creates draft PR)
93
+ - `list_prs`: `{"operation": "list_prs", "repo_id": "..."}` (shows status: draft/open/merged/closed)
94
+ - `get_pr`: `{"operation": "get_pr", "repo_id": "...", "pr_num": 1}` (shows status)
95
+ - `change_pr_status`: `{"operation": "change_pr_status", "repo_id": "...", "pr_num": 1, "new_status": "open"}` (change draft to open)
96
+ - `merge_pr`: `{"operation": "merge_pr", "repo_id": "...", "pr_num": 1}`
97
+ - `close_pr`: `{"operation": "close_pr", "repo_id": "...", "pr_num": 1}`
98
+ - `comment_pr`: `{"operation": "comment_pr", "repo_id": "...", "pr_num": 1, "comment": "..."}`
99
+
100
+ **Repo:**
101
+ - `create_repo`: `{"operation": "create_repo", "repo_id": "my-model", "private": true}`
102
+ - `update_repo`: `{"operation": "update_repo", "repo_id": "...", "private": false}`""",
103
+ "totalResults": 1,
104
+ "resultsShared": 1,
105
+ }
106
+
107
+ # =========================================================================
108
+ # BRANCH OPERATIONS
109
+ # =========================================================================
110
+
111
+ async def _create_branch(self, args: Dict[str, Any]) -> ToolResult:
112
+ """Create a new branch."""
113
+ repo_id = args.get("repo_id")
114
+ branch = args.get("branch")
115
+
116
+ if not repo_id:
117
+ return self._error("repo_id is required")
118
+ if not branch:
119
+ return self._error("branch is required")
120
+
121
+ repo_type = args.get("repo_type", "model")
122
+ from_rev = args.get("from_rev", "main")
123
+
124
+ await _async_call(
125
+ self.api.create_branch,
126
+ repo_id=repo_id,
127
+ branch=branch,
128
+ revision=from_rev,
129
+ repo_type=repo_type,
130
+ exist_ok=args.get("exist_ok", False),
131
+ )
132
+
133
+ url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}"
134
+ return {"formatted": f"**Branch created:** {branch}\n{url}", "totalResults": 1, "resultsShared": 1}
135
+
136
+ async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult:
137
+ """Delete a branch."""
138
+ repo_id = args.get("repo_id")
139
+ branch = args.get("branch")
140
+
141
+ if not repo_id:
142
+ return self._error("repo_id is required")
143
+ if not branch:
144
+ return self._error("branch is required")
145
+
146
+ repo_type = args.get("repo_type", "model")
147
+
148
+ await _async_call(
149
+ self.api.delete_branch,
150
+ repo_id=repo_id,
151
+ branch=branch,
152
+ repo_type=repo_type,
153
+ )
154
+
155
+ return {"formatted": f"**Branch deleted:** {branch}", "totalResults": 1, "resultsShared": 1}
156
+
157
+ # =========================================================================
158
+ # TAG OPERATIONS
159
+ # =========================================================================
160
+
161
+ async def _create_tag(self, args: Dict[str, Any]) -> ToolResult:
162
+ """Create a tag."""
163
+ repo_id = args.get("repo_id")
164
+ tag = args.get("tag")
165
+
166
+ if not repo_id:
167
+ return self._error("repo_id is required")
168
+ if not tag:
169
+ return self._error("tag is required")
170
+
171
+ repo_type = args.get("repo_type", "model")
172
+ revision = args.get("revision", "main")
173
+ tag_message = args.get("tag_message", "")
174
+
175
+ await _async_call(
176
+ self.api.create_tag,
177
+ repo_id=repo_id,
178
+ tag=tag,
179
+ revision=revision,
180
+ tag_message=tag_message,
181
+ repo_type=repo_type,
182
+ exist_ok=args.get("exist_ok", False),
183
+ )
184
+
185
+ url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}"
186
+ return {"formatted": f"**Tag created:** {tag}\n{url}", "totalResults": 1, "resultsShared": 1}
187
+
188
+ async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult:
189
+ """Delete a tag."""
190
+ repo_id = args.get("repo_id")
191
+ tag = args.get("tag")
192
+
193
+ if not repo_id:
194
+ return self._error("repo_id is required")
195
+ if not tag:
196
+ return self._error("tag is required")
197
+
198
+ repo_type = args.get("repo_type", "model")
199
+
200
+ await _async_call(
201
+ self.api.delete_tag,
202
+ repo_id=repo_id,
203
+ tag=tag,
204
+ repo_type=repo_type,
205
+ )
206
+
207
+ return {"formatted": f"**Tag deleted:** {tag}", "totalResults": 1, "resultsShared": 1}
208
+
209
+ # =========================================================================
210
+ # LIST REFS
211
+ # =========================================================================
212
+
213
+ async def _list_refs(self, args: Dict[str, Any]) -> ToolResult:
214
+ """List branches and tags."""
215
+ repo_id = args.get("repo_id")
216
+
217
+ if not repo_id:
218
+ return self._error("repo_id is required")
219
+
220
+ repo_type = args.get("repo_type", "model")
221
+
222
+ refs = await _async_call(
223
+ self.api.list_repo_refs,
224
+ repo_id=repo_id,
225
+ repo_type=repo_type,
226
+ )
227
+
228
+ branches = [b.name for b in refs.branches] if refs.branches else []
229
+ tags = [t.name for t in refs.tags] if hasattr(refs, 'tags') and refs.tags else []
230
+
231
+ url = _build_repo_url(repo_id, repo_type)
232
+ lines = [f"**{repo_id}**", url, ""]
233
+
234
+ if branches:
235
+ lines.append(f"**Branches ({len(branches)}):** " + ", ".join(branches))
236
+ else:
237
+ lines.append("**Branches:** none")
238
+
239
+ if tags:
240
+ lines.append(f"**Tags ({len(tags)}):** " + ", ".join(tags))
241
+ else:
242
+ lines.append("**Tags:** none")
243
+
244
+ return {"formatted": "\n".join(lines), "totalResults": len(branches) + len(tags), "resultsShared": len(branches) + len(tags)}
245
+
246
+ # =========================================================================
247
+ # PR OPERATIONS
248
+ # =========================================================================
249
+
250
+ async def _create_pr(self, args: Dict[str, Any]) -> ToolResult:
251
+ """Create a pull request."""
252
+ repo_id = args.get("repo_id")
253
+ title = args.get("title")
254
+
255
+ if not repo_id:
256
+ return self._error("repo_id is required")
257
+ if not title:
258
+ return self._error("title is required")
259
+
260
+ repo_type = args.get("repo_type", "model")
261
+ description = args.get("description", "")
262
+
263
+ result = await _async_call(
264
+ self.api.create_pull_request,
265
+ repo_id=repo_id,
266
+ title=title,
267
+ description=description,
268
+ repo_type=repo_type,
269
+ )
270
+
271
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}"
272
+ return {
273
+ "formatted": f"**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision=\"refs/pr/{result.num}\"",
274
+ "totalResults": 1,
275
+ "resultsShared": 1,
276
+ }
277
+
278
+ async def _list_prs(self, args: Dict[str, Any]) -> ToolResult:
279
+ """List PRs and discussions."""
280
+ repo_id = args.get("repo_id")
281
+
282
+ if not repo_id:
283
+ return self._error("repo_id is required")
284
+
285
+ repo_type = args.get("repo_type", "model")
286
+ status = args.get("status", "all") # open, closed, all
287
+
288
+ discussions = list(self.api.get_repo_discussions(
289
+ repo_id=repo_id,
290
+ repo_type=repo_type,
291
+ discussion_status=status if status != "all" else None,
292
+ ))
293
+
294
+ if not discussions:
295
+ return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0}
296
+
297
+ url = _build_repo_url(repo_id, repo_type)
298
+ lines = [f"**{repo_id}** - {len(discussions)} discussions", f"{url}/discussions", ""]
299
+
300
+ for d in discussions[:20]:
301
+ if d.status == "draft":
302
+ status_label = "[DRAFT]"
303
+ elif d.status == "open":
304
+ status_label = "[OPEN]"
305
+ elif d.status == "merged":
306
+ status_label = "[MERGED]"
307
+ else:
308
+ status_label = "[CLOSED]"
309
+ type_label = "PR" if d.is_pull_request else "D"
310
+ lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}")
311
+
312
+ return {"formatted": "\n".join(lines), "totalResults": len(discussions), "resultsShared": min(20, len(discussions))}
313
+
314
+ async def _get_pr(self, args: Dict[str, Any]) -> ToolResult:
315
+ """Get PR details."""
316
+ repo_id = args.get("repo_id")
317
+ pr_num = args.get("pr_num")
318
+
319
+ if not repo_id:
320
+ return self._error("repo_id is required")
321
+ if not pr_num:
322
+ return self._error("pr_num is required")
323
+
324
+ repo_type = args.get("repo_type", "model")
325
+
326
+ pr = await _async_call(
327
+ self.api.get_discussion_details,
328
+ repo_id=repo_id,
329
+ discussion_num=int(pr_num),
330
+ repo_type=repo_type,
331
+ )
332
+
333
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
334
+ status_map = {
335
+ "draft": "Draft",
336
+ "open": "Open",
337
+ "merged": "Merged",
338
+ "closed": "Closed"
339
+ }
340
+ status = status_map.get(pr.status, pr.status.capitalize())
341
+ type_label = "Pull Request" if pr.is_pull_request else "Discussion"
342
+
343
+ lines = [
344
+ f"**{type_label} #{pr_num}:** {pr.title}",
345
+ f"**Status:** {status}",
346
+ f"**Author:** {pr.author}",
347
+ url,
348
+ ]
349
+
350
+ if pr.is_pull_request:
351
+ if pr.status == "draft":
352
+ lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"")
353
+ elif pr.status == "open":
354
+ lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"")
355
+
356
+ return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1}
357
+
358
+ async def _merge_pr(self, args: Dict[str, Any]) -> ToolResult:
359
+ """Merge a pull request."""
360
+ repo_id = args.get("repo_id")
361
+ pr_num = args.get("pr_num")
362
+
363
+ if not repo_id:
364
+ return self._error("repo_id is required")
365
+ if not pr_num:
366
+ return self._error("pr_num is required")
367
+
368
+ repo_type = args.get("repo_type", "model")
369
+ comment = args.get("comment", "")
370
+
371
+ await _async_call(
372
+ self.api.merge_pull_request,
373
+ repo_id=repo_id,
374
+ discussion_num=int(pr_num),
375
+ comment=comment,
376
+ repo_type=repo_type,
377
+ )
378
+
379
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
380
+ return {"formatted": f"**PR #{pr_num} merged**\n{url}", "totalResults": 1, "resultsShared": 1}
381
+
382
+ async def _close_pr(self, args: Dict[str, Any]) -> ToolResult:
383
+ """Close a PR/discussion."""
384
+ repo_id = args.get("repo_id")
385
+ pr_num = args.get("pr_num")
386
+
387
+ if not repo_id:
388
+ return self._error("repo_id is required")
389
+ if not pr_num:
390
+ return self._error("pr_num is required")
391
+
392
+ repo_type = args.get("repo_type", "model")
393
+ comment = args.get("comment", "")
394
+
395
+ await _async_call(
396
+ self.api.change_discussion_status,
397
+ repo_id=repo_id,
398
+ discussion_num=int(pr_num),
399
+ new_status="closed",
400
+ comment=comment,
401
+ repo_type=repo_type,
402
+ )
403
+
404
+ return {"formatted": f"**Discussion #{pr_num} closed**", "totalResults": 1, "resultsShared": 1}
405
+
406
+ async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult:
407
+ """Add a comment to a PR/discussion."""
408
+ repo_id = args.get("repo_id")
409
+ pr_num = args.get("pr_num")
410
+ comment = args.get("comment")
411
+
412
+ if not repo_id:
413
+ return self._error("repo_id is required")
414
+ if not pr_num:
415
+ return self._error("pr_num is required")
416
+ if not comment:
417
+ return self._error("comment is required")
418
+
419
+ repo_type = args.get("repo_type", "model")
420
+
421
+ await _async_call(
422
+ self.api.comment_discussion,
423
+ repo_id=repo_id,
424
+ discussion_num=int(pr_num),
425
+ comment=comment,
426
+ repo_type=repo_type,
427
+ )
428
+
429
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
430
+ return {"formatted": f"**Comment added to #{pr_num}**\n{url}", "totalResults": 1, "resultsShared": 1}
431
+
432
+ async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult:
433
+ """Change PR/discussion status (mainly to convert draft to open)."""
434
+ repo_id = args.get("repo_id")
435
+ pr_num = args.get("pr_num")
436
+ new_status = args.get("new_status")
437
+
438
+ if not repo_id:
439
+ return self._error("repo_id is required")
440
+ if not pr_num:
441
+ return self._error("pr_num is required")
442
+ if not new_status:
443
+ return self._error("new_status is required (open or closed)")
444
+
445
+ repo_type = args.get("repo_type", "model")
446
+ comment = args.get("comment", "")
447
+
448
+ await _async_call(
449
+ self.api.change_discussion_status,
450
+ repo_id=repo_id,
451
+ discussion_num=int(pr_num),
452
+ new_status=new_status,
453
+ comment=comment,
454
+ repo_type=repo_type,
455
+ )
456
+
457
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
458
+ return {"formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", "totalResults": 1, "resultsShared": 1}
459
+
460
+ # =========================================================================
461
+ # REPO MANAGEMENT
462
+ # =========================================================================
463
+
464
+ async def _create_repo(self, args: Dict[str, Any]) -> ToolResult:
465
+ """Create a new repository."""
466
+ repo_id = args.get("repo_id")
467
+
468
+ if not repo_id:
469
+ return self._error("repo_id is required")
470
+
471
+ repo_type = args.get("repo_type", "model")
472
+ private = args.get("private", True)
473
+ space_sdk = args.get("space_sdk")
474
+
475
+ if repo_type == "space" and not space_sdk:
476
+ return self._error("space_sdk required for spaces (gradio/streamlit/docker/static)")
477
+
478
+ kwargs = {
479
+ "repo_id": repo_id,
480
+ "repo_type": repo_type,
481
+ "private": private,
482
+ "exist_ok": args.get("exist_ok", False),
483
+ }
484
+ if space_sdk:
485
+ kwargs["space_sdk"] = space_sdk
486
+
487
+ result = await _async_call(self.api.create_repo, **kwargs)
488
+
489
+ return {
490
+ "formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}",
491
+ "totalResults": 1,
492
+ "resultsShared": 1,
493
+ }
494
+
495
+ async def _update_repo(self, args: Dict[str, Any]) -> ToolResult:
496
+ """Update repository settings."""
497
+ repo_id = args.get("repo_id")
498
+
499
+ if not repo_id:
500
+ return self._error("repo_id is required")
501
+
502
+ repo_type = args.get("repo_type", "model")
503
+ private = args.get("private")
504
+ gated = args.get("gated")
505
+
506
+ if private is None and gated is None:
507
+ return self._error("Specify private (bool) or gated ('auto'/'manual'/false)")
508
+
509
+ kwargs = {"repo_id": repo_id, "repo_type": repo_type}
510
+ if private is not None:
511
+ kwargs["private"] = private
512
+ if gated is not None:
513
+ kwargs["gated"] = gated
514
+
515
+ await _async_call(self.api.update_repo_settings, **kwargs)
516
+
517
+ changes = []
518
+ if private is not None:
519
+ changes.append(f"private={private}")
520
+ if gated is not None:
521
+ changes.append(f"gated={gated}")
522
+
523
+ url = f"{_build_repo_url(repo_id, repo_type)}/settings"
524
+ return {"formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", "totalResults": 1, "resultsShared": 1}
525
+
526
+ def _error(self, message: str) -> ToolResult:
527
+ """Return an error result."""
528
+ return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
529
+
530
+
531
+ # Tool specification
532
+ HF_REPO_GIT_TOOL_SPEC = {
533
+ "name": "hf_repo_git",
534
+ "description": (
535
+ "Git-like operations on HF repos: branches, tags, PRs, and repo management.\n\n"
536
+ "## Operations\n"
537
+ "**Branches:** create_branch, delete_branch, list_refs\n"
538
+ "**Tags:** create_tag, delete_tag\n"
539
+ "**PRs:** create_pr, list_prs, get_pr, merge_pr, close_pr, comment_pr, change_pr_status\n"
540
+ "**Repo:** create_repo, update_repo\n\n"
541
+ "## Use when\n"
542
+ "- Creating feature branches for experiments\n"
543
+ "- Tagging model versions (v1.0, v2.0)\n"
544
+ "- Opening PRs to contribute to repos you don't own\n"
545
+ "- Reviewing and merging PRs on your repos\n"
546
+ "- Creating new model/dataset/space repos\n"
547
+ "- Changing repo visibility (public/private) or gated access\n\n"
548
+ "## Examples\n"
549
+ '{"operation": "list_refs", "repo_id": "my-model"}\n'
550
+ '{"operation": "create_branch", "repo_id": "my-model", "branch": "experiment-v2"}\n'
551
+ '{"operation": "create_tag", "repo_id": "my-model", "tag": "v1.0", "revision": "main"}\n'
552
+ '{"operation": "create_pr", "repo_id": "org/model", "title": "Fix tokenizer config"}\n'
553
+ '{"operation": "change_pr_status", "repo_id": "my-model", "pr_num": 1, "new_status": "open"}\n'
554
+ '{"operation": "merge_pr", "repo_id": "my-model", "pr_num": 3}\n'
555
+ '{"operation": "create_repo", "repo_id": "my-new-model", "private": true}\n'
556
+ '{"operation": "update_repo", "repo_id": "my-model", "gated": "auto"}\n\n'
557
+ "## PR Workflow\n"
558
+ "1. create_pr → creates draft PR (empty by default)\n"
559
+ "2. Upload files with revision='refs/pr/N' to add commits\n"
560
+ "3. change_pr_status with new_status='open' to publish (convert draft to open)\n"
561
+ "4. merge_pr when ready\n\n"
562
+ "## Notes\n"
563
+ "- PR status: draft (default), open, merged, closed\n"
564
+ "- delete_branch, delete_tag, merge_pr, create_repo, update_repo require approval\n"
565
+ "- For spaces, create_repo needs space_sdk (gradio/streamlit/docker/static)\n"
566
+ "- gated options: 'auto' (instant), 'manual' (review), false (open)\n"
567
+ ),
568
+ "parameters": {
569
+ "type": "object",
570
+ "properties": {
571
+ "operation": {
572
+ "type": "string",
573
+ "enum": [
574
+ "create_branch", "delete_branch",
575
+ "create_tag", "delete_tag", "list_refs",
576
+ "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status",
577
+ "create_repo", "update_repo",
578
+ ],
579
+ "description": "Operation to execute",
580
+ },
581
+ "repo_id": {
582
+ "type": "string",
583
+ "description": "Repository ID (e.g., 'username/repo-name')",
584
+ },
585
+ "repo_type": {
586
+ "type": "string",
587
+ "enum": ["model", "dataset", "space"],
588
+ "description": "Repository type (default: model)",
589
+ },
590
+ "branch": {
591
+ "type": "string",
592
+ "description": "Branch name (create_branch, delete_branch)",
593
+ },
594
+ "from_rev": {
595
+ "type": "string",
596
+ "description": "Create branch from this revision (default: main)",
597
+ },
598
+ "tag": {
599
+ "type": "string",
600
+ "description": "Tag name (create_tag, delete_tag)",
601
+ },
602
+ "revision": {
603
+ "type": "string",
604
+ "description": "Revision for tag (default: main)",
605
+ },
606
+ "tag_message": {
607
+ "type": "string",
608
+ "description": "Tag description",
609
+ },
610
+ "title": {
611
+ "type": "string",
612
+ "description": "PR title (create_pr)",
613
+ },
614
+ "description": {
615
+ "type": "string",
616
+ "description": "PR description (create_pr)",
617
+ },
618
+ "pr_num": {
619
+ "type": "integer",
620
+ "description": "PR/discussion number",
621
+ },
622
+ "comment": {
623
+ "type": "string",
624
+ "description": "Comment text",
625
+ },
626
+ "status": {
627
+ "type": "string",
628
+ "enum": ["open", "closed", "all"],
629
+ "description": "Filter PRs by status (list_prs)",
630
+ },
631
+ "new_status": {
632
+ "type": "string",
633
+ "enum": ["open", "closed"],
634
+ "description": "New status for PR/discussion (change_pr_status)",
635
+ },
636
+ "private": {
637
+ "type": "boolean",
638
+ "description": "Make repo private (create_repo, update_repo)",
639
+ },
640
+ "gated": {
641
+ "type": "string",
642
+ "enum": ["auto", "manual", "false"],
643
+ "description": "Gated access setting (update_repo)",
644
+ },
645
+ "space_sdk": {
646
+ "type": "string",
647
+ "enum": ["gradio", "streamlit", "docker", "static"],
648
+ "description": "Space SDK (required for create_repo with space)",
649
+ },
650
+ },
651
+ "required": ["operation"],
652
+ },
653
+ }
654
+
655
+
656
+ async def hf_repo_git_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
657
+ """Handler for agent tool router."""
658
+ try:
659
+ tool = HfRepoGitTool()
660
+ result = await tool.execute(arguments)
661
+ return result["formatted"], not result.get("isError", False)
662
+ except Exception as e:
663
+ return f"Error: {str(e)}", False
agent/tools/utils_tools.py DELETED
@@ -1,203 +0,0 @@
1
- """
2
- Utils Tools - General utility operations
3
-
4
- Provides system information like current date/time with timezone support.
5
- """
6
-
7
- import zoneinfo
8
- from datetime import datetime
9
- from typing import Any, Dict, Literal
10
-
11
- from agent.tools.types import ToolResult
12
-
13
- # Operation names
14
- OperationType = Literal["get_datetime"]
15
-
16
-
17
- class UtilsTool:
18
- """Tool for general utility operations."""
19
-
20
- async def execute(self, params: Dict[str, Any]) -> ToolResult:
21
- """Execute the specified utility operation."""
22
- operation = params.get("operation")
23
- args = params.get("args", {})
24
-
25
- # If no operation provided, return usage instructions
26
- if not operation:
27
- return self._show_help()
28
-
29
- # Normalize operation name
30
- operation = operation.lower()
31
-
32
- # Check if help is requested
33
- if args.get("help"):
34
- return self._show_operation_help(operation)
35
-
36
- try:
37
- # Route to appropriate handler
38
- if operation == "get_datetime":
39
- return await self._get_datetime(args)
40
- else:
41
- return {
42
- "formatted": f'Unknown operation: "{operation}"\n\n'
43
- "Available operations: get_datetime\n\n"
44
- "Call this tool with no operation for full usage instructions.",
45
- "totalResults": 0,
46
- "resultsShared": 0,
47
- "isError": True,
48
- }
49
-
50
- except Exception as e:
51
- return {
52
- "formatted": f"Error executing {operation}: {str(e)}",
53
- "totalResults": 0,
54
- "resultsShared": 0,
55
- "isError": True,
56
- }
57
-
58
- def _show_help(self) -> ToolResult:
59
- """Show usage instructions when tool is called with no arguments."""
60
- usage_text = """# Utils Tool
61
-
62
- Utility operations for system information.
63
-
64
- ## Available Commands
65
-
66
- - **get_datetime** - Get current date and time with timezone support
67
-
68
- ## Examples
69
-
70
- ### Get current date and time (Paris timezone by default)
71
- Call this tool with:
72
- ```json
73
- {
74
- "operation": "get_datetime",
75
- "args": {}
76
- }
77
- ```
78
-
79
- ### Get current date and time in a specific timezone
80
- Call this tool with:
81
- ```json
82
- {
83
- "operation": "get_datetime",
84
- "args": {
85
- "timezone": "America/New_York"
86
- }
87
- }
88
- ```
89
-
90
- Common timezones: Europe/Paris, America/New_York, America/Los_Angeles, Asia/Tokyo, UTC
91
-
92
- ## Tips
93
-
94
- - **Default timezone**: Paris (Europe/Paris)
95
- - **Date format**: dd-mm-yyyy
96
- - **Time format**: HH:MM:SS.mmm (24-hour format with milliseconds)
97
- - **Timezone names**: Use IANA timezone database names (e.g., "Europe/Paris", "UTC")
98
- """
99
- return {"formatted": usage_text, "totalResults": 1, "resultsShared": 1}
100
-
101
- def _show_operation_help(self, operation: str) -> ToolResult:
102
- """Show help for a specific operation."""
103
- help_text = f"Help for operation: {operation}\n\nCall with appropriate arguments. Use the main help for examples."
104
- return {"formatted": help_text, "totalResults": 1, "resultsShared": 1}
105
-
106
- async def _get_datetime(self, args: Dict[str, Any]) -> ToolResult:
107
- """Get current date and time with timezone support."""
108
- timezone_name = args.get("timezone", "Europe/Paris")
109
-
110
- try:
111
- # Get timezone object
112
- tz = zoneinfo.ZoneInfo(timezone_name)
113
-
114
- # Get current datetime in specified timezone
115
- now = datetime.now(tz)
116
-
117
- # Format date as dd-mm-yyyy
118
- date_str = now.strftime("%d-%m-%Y")
119
-
120
- # Format time as HH:MM:SS.mmm
121
- time_str = now.strftime("%H:%M:%S.%f")[
122
- :-3
123
- ] # Remove last 3 digits to keep only milliseconds
124
-
125
- # Get timezone abbreviation/offset
126
- tz_offset = now.strftime("%z")
127
- tz_name = now.strftime("%Z")
128
-
129
- response = f"""✓ Current date and time
130
-
131
- **Date:** {date_str}
132
- **Time:** {time_str}
133
- **Timezone:** {timezone_name} ({tz_name}, UTC{tz_offset[:3]}:{tz_offset[3:]})
134
-
135
- **ISO Format:** {now.isoformat()}
136
- **Unix Timestamp:** {int(now.timestamp())}"""
137
-
138
- return {"formatted": response, "totalResults": 1, "resultsShared": 1}
139
-
140
- except zoneinfo.ZoneInfoNotFoundError:
141
- return {
142
- "formatted": f"Invalid timezone: {timezone_name}\n\n"
143
- "Use IANA timezone database names like:\n"
144
- "- Europe/Paris\n"
145
- "- America/New_York\n"
146
- "- Asia/Tokyo\n"
147
- "- UTC\n\n"
148
- "See: https://en.wikipedia.org/wiki/List_of_tz_database_time_zones",
149
- "totalResults": 0,
150
- "resultsShared": 0,
151
- "isError": True,
152
- }
153
- except Exception as e:
154
- return {
155
- "formatted": f"Failed to get date/time: {str(e)}",
156
- "totalResults": 0,
157
- "resultsShared": 0,
158
- "isError": True,
159
- }
160
-
161
-
162
- # Tool specification for agent registration
163
- UTILS_TOOL_SPEC = {
164
- "name": "utils",
165
- "description": (
166
- "System utility operations - currently provides date/time with timezone support. "
167
- "**Use when:** (1) Need current date for logging/timestamps, (2) User asks 'what time is it', "
168
- "(3) Need timezone-aware datetime for scheduling/coordination, (4) Creating timestamped filenames. "
169
- "**Operation:** get_datetime with optional timezone parameter (default: Europe/Paris). "
170
- "Returns: Date (dd-mm-yyyy), time (HH:MM:SS.mmm), timezone info, ISO format, Unix timestamp. "
171
- "**Pattern:** utils get_datetime → use timestamp in filename/log → upload to hf_private_repos. "
172
- "Supports IANA timezone names: 'Europe/Paris', 'America/New_York', 'Asia/Tokyo', 'UTC'."
173
- ),
174
- "parameters": {
175
- "type": "object",
176
- "properties": {
177
- "operation": {
178
- "type": "string",
179
- "enum": ["get_datetime"],
180
- "description": "Operation to execute. Valid values: [get_datetime]",
181
- },
182
- "args": {
183
- "type": "object",
184
- "description": (
185
- "Operation-specific arguments as a JSON object. "
186
- "For get_datetime: timezone (string, optional, default: Europe/Paris). "
187
- "Use IANA timezone names like 'America/New_York', 'Asia/Tokyo', 'UTC'."
188
- ),
189
- "additionalProperties": True,
190
- },
191
- },
192
- },
193
- }
194
-
195
-
196
- async def utils_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
197
- """Handler for agent tool router."""
198
- try:
199
- tool = UtilsTool()
200
- result = await tool.execute(arguments)
201
- return result["formatted"], not result.get("isError", False)
202
- except Exception as e:
203
- return f"Error executing Utils tool: {str(e)}", False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_dataset_tools.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Test script for unified dataset inspection tool
3
  """
4
 
5
  import asyncio
@@ -8,7 +8,7 @@ from typing import TypedDict
8
  from unittest.mock import MagicMock
9
 
10
 
11
- # Mock the types module before importing dataset_tools
12
  class ToolResult(TypedDict, total=False):
13
  formatted: str
14
  totalResults: int
@@ -20,60 +20,70 @@ mock_types = MagicMock()
20
  mock_types.ToolResult = ToolResult
21
  sys.modules["agent.tools.types"] = mock_types
22
 
23
- # Now import directly from the file
24
- sys.path.insert(0, "/Users/akseljoonas/Documents/hf-agent/agent/tools")
25
- from dataset_tools import hf_inspect_dataset_handler, inspect_dataset
26
 
27
 
28
- async def test_inspect_dataset():
29
- """Test the unified inspect_dataset function"""
30
- print("=" * 70)
31
- print("Testing inspect_dataset()")
32
- print("=" * 70)
33
 
34
- # Test with akseljoonas/hf-agent-sessions as specified
35
- print("\n→ inspect_dataset('akseljoonas/hf-agent-sessions'):")
36
- result = await inspect_dataset("akseljoonas/hf-agent-sessions")
37
- print(f" isError: {result['isError']}")
38
- print(f" Output:\n{result['formatted']}")
39
 
40
- print("\n" + "=" * 70)
41
-
42
- # # Test with stanfordnlp/imdb
43
- # print("\n→ inspect_dataset('stanfordnlp/imdb'):")
44
- # result = await inspect_dataset("stanfordnlp/imdb")
45
- # print(f" isError: {result['isError']}")
46
- # print(f" Output:\n{result['formatted']}")
47
-
48
- # print("\n" + "=" * 70)
 
 
 
 
 
 
 
 
 
 
49
 
50
- # # Test with multi-config dataset
51
- # print("\n→ inspect_dataset('nyu-mll/glue', config='mrpc'):")
52
- # result = await inspect_dataset("nyu-mll/glue", config="mrpc")
53
- # print(f" isError: {result['isError']}")
54
- # print(f" Output:\n{result['formatted']}")
55
 
 
 
 
 
 
56
 
57
- async def test_handler():
58
- """Test the handler (what the agent calls)"""
59
- print("\n" + "=" * 70)
60
- print("Testing hf_inspect_dataset_handler()")
61
- print("=" * 70)
62
 
63
- result, success = await hf_inspect_dataset_handler(
64
- {
65
- "dataset": "stanfordnlp/imdb",
66
- "sample_rows": 2,
67
- }
 
 
 
 
68
  )
69
- print("\n→ Handler result:")
70
- print(f" success: {success}")
71
- print(f" output:\n{result}")
 
 
 
 
72
 
73
 
74
  if __name__ == "__main__":
75
- print("\nUnified Dataset Inspection Tool Test\n")
76
- asyncio.run(test_inspect_dataset())
77
- # asyncio.run(test_handler())
78
- print("\n" + "=" * 70)
79
- print("Done!")
 
 
1
  """
2
+ Test script for hf_repo_files and hf_repo_git tools
3
  """
4
 
5
  import asyncio
 
8
  from unittest.mock import MagicMock
9
 
10
 
11
+ # Mock the types module before importing
12
  class ToolResult(TypedDict, total=False):
13
  formatted: str
14
  totalResults: int
 
20
  mock_types.ToolResult = ToolResult
21
  sys.modules["agent.tools.types"] = mock_types
22
 
23
+ from agent.tools.hf_repo_files_tool import HfRepoFilesTool
24
+ from agent.tools.hf_repo_git_tool import HfRepoGitTool
 
25
 
26
 
27
+ async def test_hf_repo_files():
28
+ """Test hf_repo_files tool"""
29
+ print("=" * 60)
30
+ print("Testing hf_repo_files")
31
+ print("=" * 60)
32
 
33
+ tool = HfRepoFilesTool()
 
 
 
 
34
 
35
+ # Test list
36
+ print("\n→ list files in gpt2:")
37
+ result = await tool.execute(
38
+ {"operation": "list", "repo_id": "openai-community/gpt2"}
39
+ )
40
+ print(f" isError: {result.get('isError', False)}")
41
+ print(f" totalResults: {result['totalResults']}")
42
+ # Just show first few lines
43
+ lines = result["formatted"].split("\n")
44
+ print(" Output (first 5 lines):\n" + "\n".join(f" {line}" for line in lines))
45
+
46
+ # Test read
47
+ print("\n→ read config.json from gpt2:")
48
+ result = await tool.execute(
49
+ {"operation": "read", "repo_id": "openai-community/gpt2", "path": "config.json"}
50
+ )
51
+ print(f" isError: {result.get('isError', False)}")
52
+ lines = result["formatted"].split("\n")
53
+ print(" Output (first 10 lines):\n" + "\n".join(f" {line}" for line in lines))
54
 
 
 
 
 
 
55
 
56
+ async def test_hf_repo_git():
57
+ """Test hf_repo_git tool"""
58
+ print("\n" + "=" * 60)
59
+ print("Testing hf_repo_git")
60
+ print("=" * 60)
61
 
62
+ tool = HfRepoGitTool()
 
 
 
 
63
 
64
+ # Test list_refs
65
+ print("\n→ list_refs for gpt2:")
66
+ result = await tool.execute(
67
+ {"operation": "list_refs", "repo_id": "openai-community/gpt2"}
68
+ )
69
+ print(f" isError: {result.get('isError', False)}")
70
+ print(
71
+ " Output:\n"
72
+ + "\n".join(f" {line}" for line in result["formatted"].split("\n"))
73
  )
74
+
75
+ # Test help (no operation)
76
+ print("\n→ help (no operation):")
77
+ result = await tool.execute({})
78
+ print(f" isError: {result.get('isError', False)}")
79
+ lines = result["formatted"].split("\n")[:6]
80
+ print(" Output (first 6 lines):\n" + "\n".join(f" {line}" for line in lines))
81
 
82
 
83
  if __name__ == "__main__":
84
+ print("\nHF Repo Tools Test\n")
85
+ asyncio.run(test_hf_repo_files())
86
+ asyncio.run(test_hf_repo_git())
87
+ print("\n" + "=" * 60)
88
+ print("Tests complete!")
89
+ print("=" * 60)