akseljoonas HF Staff committed on
Commit
8b2c9e3
·
1 Parent(s): 7efc07b

Add hf_repo_files and hf_repo_git tools

Browse files

New tools to replace private_hf_repo_tools with comprehensive HF repo management:

hf_repo_files (4 ops):
- list: List files with sizes
- read: Read file content
- upload: Upload content (approval required)
- delete: Delete files with wildcards (approval required)

hf_repo_git (13 ops):
- Branches: create_branch, delete_branch, list_refs
- Tags: create_tag, delete_tag
- PRs: create_pr, list_prs, get_pr, merge_pr, close_pr, comment_pr
- Repo: create_repo, update_repo

Approval required for destructive, overwriting, or repo-level operations:
- hf_repo_files: upload, delete
- hf_repo_git: delete_branch, delete_tag, merge_pr, create_repo, update_repo

agent/core/agent_loop.py CHANGED
@@ -76,7 +76,19 @@ def _needs_approval(tool_name: str, tool_args: dict, config: Config | None = Non
76
  # Other operations (create_repo, etc.) always require approval
77
  if operation in ["create_repo"]:
78
  return True
79
-
 
 
 
 
 
 
 
 
 
 
 
 
80
  return False
81
 
82
 
 
76
  # Other operations (create_repo, etc.) always require approval
77
  if operation in ["create_repo"]:
78
  return True
79
+
80
+ # hf_repo_files: upload (can overwrite) and delete require approval
81
+ if tool_name == "hf_repo_files":
82
+ operation = tool_args.get("operation", "")
83
+ if operation in ["upload", "delete"]:
84
+ return True
85
+
86
+ # hf_repo_git: destructive operations require approval
87
+ if tool_name == "hf_repo_git":
88
+ operation = tool_args.get("operation", "")
89
+ if operation in ["delete_branch", "delete_tag", "merge_pr", "create_repo", "update_repo"]:
90
+ return True
91
+
92
  return False
93
 
94
 
agent/core/tools.py CHANGED
@@ -37,8 +37,16 @@ from agent.tools.github_read_file import (
37
  )
38
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
39
  from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
 
 
 
 
 
 
 
 
40
 
41
- # NOTE: Private HF repo tool disabled - system prompt handles pushing to hub pretty well now
42
  # from agent.tools.private_hf_repo_tools import (
43
  # PRIVATE_HF_REPO_TOOL_SPEC,
44
  # private_hf_repo_handler,
@@ -280,13 +288,19 @@ def create_builtin_tools() -> list[ToolSpec]:
280
  parameters=HF_JOBS_TOOL_SPEC["parameters"],
281
  handler=hf_jobs_handler,
282
  ),
283
- # NOTE: Private HF repo tool disabled - system prompt now reliably instructs agent to push to hub
284
- # ToolSpec(
285
- # name=PRIVATE_HF_REPO_TOOL_SPEC["name"],
286
- # description=PRIVATE_HF_REPO_TOOL_SPEC["description"],
287
- # parameters=PRIVATE_HF_REPO_TOOL_SPEC["parameters"],
288
- # handler=private_hf_repo_handler,
289
- # ),
 
 
 
 
 
 
290
 
291
 
292
  # NOTE: Github search code tool disabled - a bit buggy
 
37
  )
38
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
39
  from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
40
+ from agent.tools.hf_repo_files_tool import (
41
+ HF_REPO_FILES_TOOL_SPEC,
42
+ hf_repo_files_handler,
43
+ )
44
+ from agent.tools.hf_repo_git_tool import (
45
+ HF_REPO_GIT_TOOL_SPEC,
46
+ hf_repo_git_handler,
47
+ )
48
 
49
+ # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git
50
  # from agent.tools.private_hf_repo_tools import (
51
  # PRIVATE_HF_REPO_TOOL_SPEC,
52
  # private_hf_repo_handler,
 
288
  parameters=HF_JOBS_TOOL_SPEC["parameters"],
289
  handler=hf_jobs_handler,
290
  ),
291
+ # HF Repo management tools
292
+ ToolSpec(
293
+ name=HF_REPO_FILES_TOOL_SPEC["name"],
294
+ description=HF_REPO_FILES_TOOL_SPEC["description"],
295
+ parameters=HF_REPO_FILES_TOOL_SPEC["parameters"],
296
+ handler=hf_repo_files_handler,
297
+ ),
298
+ ToolSpec(
299
+ name=HF_REPO_GIT_TOOL_SPEC["name"],
300
+ description=HF_REPO_GIT_TOOL_SPEC["description"],
301
+ parameters=HF_REPO_GIT_TOOL_SPEC["parameters"],
302
+ handler=hf_repo_git_handler,
303
+ ),
304
 
305
 
306
  # NOTE: Github search code tool disabled - a bit buggy
agent/tools/hf_repo_files_tool.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HF Repo Files Tool - File operations on Hugging Face repositories
3
+
4
+ Operations: list, read, upload, delete
5
+ """
6
+
7
+ import asyncio
8
+ from typing import Any, Dict, Literal, Optional
9
+
10
+ from huggingface_hub import HfApi, hf_hub_download
11
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
12
+
13
+ from agent.tools.types import ToolResult
14
+
15
+ OperationType = Literal["list", "read", "upload", "delete"]
16
+
17
+
18
async def _async_call(func, *args, **kwargs):
    """Run a blocking callable in a worker thread and return its result.

    Used for the synchronous ``HfApi`` calls so they do not block the event
    loop. Positional and keyword arguments are forwarded to ``func`` as given.
    """
    outcome = await asyncio.to_thread(func, *args, **kwargs)
    return outcome
21
+
22
+
23
def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
    """Return the huggingface.co URL of a repository.

    Model repos sit directly under the site root; every other repo type is
    placed under a ``<repo_type>s/`` prefix (e.g. ``datasets/``, ``spaces/``).
    """
    prefix = "" if repo_type == "model" else f"{repo_type}s/"
    return f"https://huggingface.co/{prefix}{repo_id}"
28
+
29
+
30
def _format_size(size_bytes: int) -> str:
    """Render a byte count with one decimal and a binary-scaled unit suffix.

    The value is divided by 1024 until it drops below 1024 or the units up to
    TB are exhausted; anything larger is reported in PB.
    """
    units = ("B", "KB", "MB", "GB", "TB")
    value = size_bytes
    idx = 0
    while idx < len(units):
        if value < 1024:
            return f"{value:.1f}{units[idx]}"
        value = value / 1024
        idx += 1
    return f"{value:.1f}PB"
37
+
38
+
39
class HfRepoFilesTool:
    """File operations (list, read, upload, delete) on Hugging Face repos.

    Each operation receives the raw tool-argument dict and returns a
    ``ToolResult`` dict with ``formatted``, ``totalResults`` and
    ``resultsShared`` keys; error results additionally carry ``isError``.
    """

    def __init__(self, hf_token: Optional[str] = None):
        """Create the Hub API client, optionally with an explicit token."""
        self.api = HfApi(token=hf_token)

    async def execute(self, args: Dict[str, Any]) -> ToolResult:
        """Dispatch ``args["operation"]`` to the matching handler.

        A missing operation returns the usage text. Exceptions raised while
        handling are converted into error results rather than propagated.
        """
        operation = args.get("operation")

        if not operation:
            return self._help()

        try:
            handlers = {
                "list": self._list,
                "read": self._read,
                "upload": self._upload,
                "delete": self._delete,
            }

            handler = handlers.get(operation)
            if handler:
                return await handler(args)
            else:
                return self._error(f"Unknown operation: {operation}. Valid: list, read, upload, delete")

        except RepositoryNotFoundError:
            return self._error(f"Repository not found: {args.get('repo_id')}")
        except EntryNotFoundError:
            return self._error(f"File not found: {args.get('path')}")
        except Exception as e:
            return self._error(f"Error: {str(e)}")

    def _help(self) -> ToolResult:
        """Return the usage text shown when no operation is given."""
        return {
            "formatted": """**hf_repo_files** - File operations on HF repos

**Operations:**
- `list` - List files: `{"operation": "list", "repo_id": "gpt2"}`
- `read` - Read file: `{"operation": "read", "repo_id": "gpt2", "path": "config.json"}`
- `upload` - Upload: `{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "..."}`
- `delete` - Delete: `{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp"]}`

**Common params:** repo_id (required), repo_type (model/dataset/space), revision (default: main)""",
            "totalResults": 1,
            "resultsShared": 1,
        }

    async def _list(self, args: Dict[str, Any]) -> ToolResult:
        """List the repository tree recursively, with per-file sizes.

        Args read: ``repo_id`` (required), ``repo_type``, ``revision``,
        ``path`` (subfolder to list, defaults to the repo root).
        """
        repo_id = args.get("repo_id")
        if not repo_id:
            return self._error("repo_id is required")

        repo_type = args.get("repo_type", "model")
        revision = args.get("revision", "main")
        path = args.get("path", "")

        def _collect() -> list:
            # list_repo_tree is lazy: the HTTP requests happen while iterating,
            # so the iterator must be drained inside the worker thread.
            return list(
                self.api.list_repo_tree(
                    repo_id=repo_id,
                    repo_type=repo_type,
                    revision=revision,
                    path_in_repo=path,
                    recursive=True,
                )
            )

        items = await _async_call(_collect)

        if not items:
            return {"formatted": f"No files in {repo_id}", "totalResults": 0, "resultsShared": 0}

        lines = []
        total_size = 0
        file_count = 0
        for item in sorted(items, key=lambda x: x.path):
            size = getattr(item, "size", None)
            if size is not None:
                # Entries carrying a size are files; zero-byte files included.
                file_count += 1
                total_size += size
                lines.append(f"{item.path} ({_format_size(size)})")
            else:
                lines.append(f"{item.path}/")

        url = _build_repo_url(repo_id, repo_type)
        response = (
            f"**{repo_id}** ({file_count} files, {_format_size(total_size)})\n"
            f"{url}/tree/{revision}\n\n" + "\n".join(lines)
        )

        return {"formatted": response, "totalResults": len(items), "resultsShared": len(items)}

    async def _read(self, args: Dict[str, Any]) -> ToolResult:
        """Download one file and return its text, truncated to ``max_chars``.

        Args read: ``repo_id`` and ``path`` (required), ``repo_type``,
        ``revision``, ``max_chars`` (default 50000). Files that are not valid
        UTF-8 are reported as binary together with their size.
        """
        repo_id = args.get("repo_id")
        path = args.get("path")

        if not repo_id:
            return self._error("repo_id is required")
        if not path:
            return self._error("path is required")

        repo_type = args.get("repo_type", "model")
        revision = args.get("revision", "main")
        # Coerce so a numeric string from the caller still works; a value that
        # cannot be converted raises and is reported by execute().
        max_chars = int(args.get("max_chars", 50000))

        file_path = await _async_call(
            hf_hub_download,
            repo_id=repo_id,
            filename=path,
            repo_type=repo_type,
            revision=revision,
            token=self.api.token,
        )

        try:
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()

            truncated = len(content) > max_chars
            if truncated:
                content = content[:max_chars]

            url = f"{_build_repo_url(repo_id, repo_type)}/blob/{revision}/{path}"
            response = f"**{path}**{' (truncated)' if truncated else ''}\n{url}\n\n```\n{content}\n```"

            return {"formatted": response, "totalResults": 1, "resultsShared": 1}

        except UnicodeDecodeError:
            import os

            size = os.path.getsize(file_path)
            return {"formatted": f"Binary file ({_format_size(size)})", "totalResults": 1, "resultsShared": 1}

    async def _upload(self, args: Dict[str, Any]) -> ToolResult:
        """Upload ``content`` to ``path`` in the repository.

        Args read: ``repo_id``, ``path`` and ``content`` (required),
        ``repo_type``, ``revision``, ``create_pr``, ``commit_message``.
        String content is encoded as UTF-8 before upload.
        """
        repo_id = args.get("repo_id")
        path = args.get("path")
        content = args.get("content")

        if not repo_id:
            return self._error("repo_id is required")
        if not path:
            return self._error("path is required")
        if content is None:
            return self._error("content is required")

        repo_type = args.get("repo_type", "model")
        revision = args.get("revision", "main")
        create_pr = args.get("create_pr", False)
        commit_message = args.get("commit_message", f"Upload {path}")

        file_bytes = content.encode("utf-8") if isinstance(content, str) else content

        result = await _async_call(
            self.api.upload_file,
            path_or_fileobj=file_bytes,
            path_in_repo=path,
            repo_id=repo_id,
            repo_type=repo_type,
            revision=revision,
            commit_message=commit_message,
            create_pr=create_pr,
        )

        url = _build_repo_url(repo_id, repo_type)
        # Check the value, not just the attribute: report a PR only when the
        # result actually carries a PR URL.
        pr_url = getattr(result, "pr_url", None)
        if create_pr and pr_url:
            response = f"**Uploaded as PR**\n{pr_url}"
        else:
            response = f"**Uploaded:** {path}\n{url}/blob/{revision}/{path}"

        return {"formatted": response, "totalResults": 1, "resultsShared": 1}

    async def _delete(self, args: Dict[str, Any]) -> ToolResult:
        """Delete files matching ``patterns`` from the repository.

        Args read: ``repo_id`` and ``patterns`` (required; a single string is
        treated as a one-element list), ``repo_type``, ``revision``,
        ``create_pr``, ``commit_message``.
        """
        repo_id = args.get("repo_id")
        patterns = args.get("patterns")

        if not repo_id:
            return self._error("repo_id is required")
        if not patterns:
            return self._error("patterns is required (list of paths/wildcards)")

        if isinstance(patterns, str):
            patterns = [patterns]

        repo_type = args.get("repo_type", "model")
        revision = args.get("revision", "main")
        create_pr = args.get("create_pr", False)
        commit_message = args.get("commit_message", f"Delete {', '.join(patterns)}")

        await _async_call(
            self.api.delete_files,
            repo_id=repo_id,
            delete_patterns=patterns,
            repo_type=repo_type,
            revision=revision,
            commit_message=commit_message,
            create_pr=create_pr,
        )

        response = f"**Deleted:** {', '.join(patterns)} from {repo_id}"
        return {"formatted": response, "totalResults": 1, "resultsShared": 1}

    def _error(self, message: str) -> ToolResult:
        """Build an error result carrying ``message``."""
        return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
239
+
240
+
241
+ # Tool specification
242
# Name, description and JSON-schema parameters advertised for the
# hf_repo_files tool. Only "operation" is required by the schema; the
# per-operation requirements are enforced by the tool itself.
HF_REPO_FILES_TOOL_SPEC = {
    "name": "hf_repo_files",
    "description": (
        "Read and write files in HF repos (models/datasets/spaces).\n\n"
        "## Operations\n"
        "- **list**: List files with sizes and structure\n"
        "- **read**: Read file content (text files only)\n"
        "- **upload**: Upload content to repo (can create PR)\n"
        "- **delete**: Delete files/folders (supports wildcards like *.tmp)\n\n"
        "## Use when\n"
        "- Need to see what files exist in a repo\n"
        "- Want to read config.json, README.md, or other text files\n"
        "- Uploading training scripts, configs, or results to a repo\n"
        "- Cleaning up temporary files from a repo\n\n"
        "## Examples\n"
        '{"operation": "list", "repo_id": "meta-llama/Llama-2-7b"}\n'
        '{"operation": "read", "repo_id": "gpt2", "path": "config.json"}\n'
        '{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "# My Model"}\n'
        '{"operation": "upload", "repo_id": "org/model", "path": "fix.py", "content": "...", "create_pr": true}\n'
        '{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp", "logs/"]}\n\n'
        "## Notes\n"
        "- For binary files (safetensors, bin), use list to see them but can't read content\n"
        "- upload/delete require approval (can overwrite/destroy data)\n"
        "- Use create_pr=true to propose changes instead of direct commit\n"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": ["list", "read", "upload", "delete"],
                "description": "Operation: list, read, upload, delete",
            },
            "repo_id": {
                "type": "string",
                "description": "Repository ID (e.g., 'username/repo-name')",
            },
            "repo_type": {
                "type": "string",
                "enum": ["model", "dataset", "space"],
                "description": "Repository type (default: model)",
            },
            "revision": {
                "type": "string",
                "description": "Branch/tag/commit (default: main)",
            },
            # The list operation also reads "path" (as the folder to list).
            "path": {
                "type": "string",
                "description": "File path for read/upload; optional subfolder for list",
            },
            "content": {
                "type": "string",
                "description": "File content for upload",
            },
            # Honoured by the read operation but previously not advertised.
            "max_chars": {
                "type": "integer",
                "description": "Max characters returned by read (default: 50000)",
            },
            "patterns": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Patterns to delete (e.g., ['*.tmp', 'logs/'])",
            },
            "create_pr": {
                "type": "boolean",
                "description": "Create PR instead of direct commit",
            },
            "commit_message": {
                "type": "string",
                "description": "Custom commit message",
            },
        },
        "required": ["operation"],
    },
}
313
+
314
+
315
async def hf_repo_files_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
    """Adapt ``HfRepoFilesTool`` to the agent tool router.

    Returns ``(text, ok)`` where ``ok`` is False when the result is flagged
    ``isError``. Any exception is reported as an error string, not raised.
    """
    try:
        outcome = await HfRepoFilesTool().execute(arguments)
        text = outcome["formatted"]
        succeeded = not outcome.get("isError", False)
        return text, succeeded
    except Exception as exc:
        return f"Error: {str(exc)}", False
agent/tools/hf_repo_git_tool.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HF Repo Git Tool - Git-like operations on Hugging Face repositories
3
+
4
+ Operations: branches, tags, PRs, repo management
5
+ """
6
+
7
+ import asyncio
8
+ from typing import Any, Dict, Literal, Optional
9
+
10
+ from huggingface_hub import HfApi
11
+ from huggingface_hub.utils import RepositoryNotFoundError
12
+
13
+ from agent.tools.types import ToolResult
14
+
15
+ OperationType = Literal[
16
+ "create_branch", "delete_branch",
17
+ "create_tag", "delete_tag",
18
+ "list_refs",
19
+ "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr",
20
+ "create_repo", "update_repo",
21
+ ]
22
+
23
+
24
async def _async_call(func, *args, **kwargs):
    """Await a synchronous callable by running it in a worker thread.

    All arguments are passed through to ``func`` unchanged; this keeps the
    blocking ``HfApi`` methods off the event loop.
    """
    pending = asyncio.to_thread(func, *args, **kwargs)
    return await pending
27
+
28
+
29
def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
    """Build the Hub web URL for ``repo_id``.

    Only model repos are addressed without a type segment; any other
    ``repo_type`` is inserted in plural form before the repo id.
    """
    base = "https://huggingface.co"
    if repo_type != "model":
        return f"{base}/{repo_type}s/{repo_id}"
    return f"{base}/{repo_id}"
34
+
35
+
36
+ class HfRepoGitTool:
37
+ """Tool for git-like operations on HF repos."""
38
+
39
+ def __init__(self, hf_token: Optional[str] = None):
40
+ self.api = HfApi(token=hf_token)
41
+
42
+ async def execute(self, args: Dict[str, Any]) -> ToolResult:
43
+ """Execute the specified operation."""
44
+ operation = args.get("operation")
45
+
46
+ if not operation:
47
+ return self._help()
48
+
49
+ try:
50
+ handlers = {
51
+ "create_branch": self._create_branch,
52
+ "delete_branch": self._delete_branch,
53
+ "create_tag": self._create_tag,
54
+ "delete_tag": self._delete_tag,
55
+ "list_refs": self._list_refs,
56
+ "create_pr": self._create_pr,
57
+ "list_prs": self._list_prs,
58
+ "get_pr": self._get_pr,
59
+ "merge_pr": self._merge_pr,
60
+ "close_pr": self._close_pr,
61
+ "comment_pr": self._comment_pr,
62
+ "create_repo": self._create_repo,
63
+ "update_repo": self._update_repo,
64
+ }
65
+
66
+ handler = handlers.get(operation)
67
+ if handler:
68
+ return await handler(args)
69
+ else:
70
+ ops = ", ".join(handlers.keys())
71
+ return self._error(f"Unknown operation: {operation}. Valid: {ops}")
72
+
73
+ except RepositoryNotFoundError:
74
+ return self._error(f"Repository not found: {args.get('repo_id')}")
75
+ except Exception as e:
76
+ return self._error(f"Error: {str(e)}")
77
+
78
+ def _help(self) -> ToolResult:
79
+ """Show usage instructions."""
80
+ return {
81
+ "formatted": """**hf_repo_git** - Git-like operations on HF repos
82
+
83
+ **Branch/Tag:**
84
+ - `create_branch`: `{"operation": "create_branch", "repo_id": "...", "branch": "dev"}`
85
+ - `delete_branch`: `{"operation": "delete_branch", "repo_id": "...", "branch": "dev"}`
86
+ - `create_tag`: `{"operation": "create_tag", "repo_id": "...", "tag": "v1.0"}`
87
+ - `delete_tag`: `{"operation": "delete_tag", "repo_id": "...", "tag": "v1.0"}`
88
+ - `list_refs`: `{"operation": "list_refs", "repo_id": "..."}`
89
+
90
+ **PRs:**
91
+ - `create_pr`: `{"operation": "create_pr", "repo_id": "...", "title": "..."}`
92
+ - `list_prs`: `{"operation": "list_prs", "repo_id": "..."}`
93
+ - `get_pr`: `{"operation": "get_pr", "repo_id": "...", "pr_num": 1}`
94
+ - `merge_pr`: `{"operation": "merge_pr", "repo_id": "...", "pr_num": 1}`
95
+ - `close_pr`: `{"operation": "close_pr", "repo_id": "...", "pr_num": 1}`
96
+ - `comment_pr`: `{"operation": "comment_pr", "repo_id": "...", "pr_num": 1, "comment": "..."}`
97
+
98
+ **Repo:**
99
+ - `create_repo`: `{"operation": "create_repo", "repo_id": "my-model", "private": true}`
100
+ - `update_repo`: `{"operation": "update_repo", "repo_id": "...", "private": false}`""",
101
+ "totalResults": 1,
102
+ "resultsShared": 1,
103
+ }
104
+
105
+ # =========================================================================
106
+ # BRANCH OPERATIONS
107
+ # =========================================================================
108
+
109
+ async def _create_branch(self, args: Dict[str, Any]) -> ToolResult:
110
+ """Create a new branch."""
111
+ repo_id = args.get("repo_id")
112
+ branch = args.get("branch")
113
+
114
+ if not repo_id:
115
+ return self._error("repo_id is required")
116
+ if not branch:
117
+ return self._error("branch is required")
118
+
119
+ repo_type = args.get("repo_type", "model")
120
+ from_rev = args.get("from_rev", "main")
121
+
122
+ await _async_call(
123
+ self.api.create_branch,
124
+ repo_id=repo_id,
125
+ branch=branch,
126
+ revision=from_rev,
127
+ repo_type=repo_type,
128
+ exist_ok=args.get("exist_ok", False),
129
+ )
130
+
131
+ url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}"
132
+ return {"formatted": f"**Branch created:** {branch}\n{url}", "totalResults": 1, "resultsShared": 1}
133
+
134
+ async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult:
135
+ """Delete a branch."""
136
+ repo_id = args.get("repo_id")
137
+ branch = args.get("branch")
138
+
139
+ if not repo_id:
140
+ return self._error("repo_id is required")
141
+ if not branch:
142
+ return self._error("branch is required")
143
+
144
+ repo_type = args.get("repo_type", "model")
145
+
146
+ await _async_call(
147
+ self.api.delete_branch,
148
+ repo_id=repo_id,
149
+ branch=branch,
150
+ repo_type=repo_type,
151
+ )
152
+
153
+ return {"formatted": f"**Branch deleted:** {branch}", "totalResults": 1, "resultsShared": 1}
154
+
155
+ # =========================================================================
156
+ # TAG OPERATIONS
157
+ # =========================================================================
158
+
159
+ async def _create_tag(self, args: Dict[str, Any]) -> ToolResult:
160
+ """Create a tag."""
161
+ repo_id = args.get("repo_id")
162
+ tag = args.get("tag")
163
+
164
+ if not repo_id:
165
+ return self._error("repo_id is required")
166
+ if not tag:
167
+ return self._error("tag is required")
168
+
169
+ repo_type = args.get("repo_type", "model")
170
+ revision = args.get("revision", "main")
171
+ tag_message = args.get("tag_message", "")
172
+
173
+ await _async_call(
174
+ self.api.create_tag,
175
+ repo_id=repo_id,
176
+ tag=tag,
177
+ revision=revision,
178
+ tag_message=tag_message,
179
+ repo_type=repo_type,
180
+ exist_ok=args.get("exist_ok", False),
181
+ )
182
+
183
+ url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}"
184
+ return {"formatted": f"**Tag created:** {tag}\n{url}", "totalResults": 1, "resultsShared": 1}
185
+
186
+ async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult:
187
+ """Delete a tag."""
188
+ repo_id = args.get("repo_id")
189
+ tag = args.get("tag")
190
+
191
+ if not repo_id:
192
+ return self._error("repo_id is required")
193
+ if not tag:
194
+ return self._error("tag is required")
195
+
196
+ repo_type = args.get("repo_type", "model")
197
+
198
+ await _async_call(
199
+ self.api.delete_tag,
200
+ repo_id=repo_id,
201
+ tag=tag,
202
+ repo_type=repo_type,
203
+ )
204
+
205
+ return {"formatted": f"**Tag deleted:** {tag}", "totalResults": 1, "resultsShared": 1}
206
+
207
+ # =========================================================================
208
+ # LIST REFS
209
+ # =========================================================================
210
+
211
+ async def _list_refs(self, args: Dict[str, Any]) -> ToolResult:
212
+ """List branches and tags."""
213
+ repo_id = args.get("repo_id")
214
+
215
+ if not repo_id:
216
+ return self._error("repo_id is required")
217
+
218
+ repo_type = args.get("repo_type", "model")
219
+
220
+ refs = await _async_call(
221
+ self.api.list_repo_refs,
222
+ repo_id=repo_id,
223
+ repo_type=repo_type,
224
+ )
225
+
226
+ branches = [b.name for b in refs.branches] if refs.branches else []
227
+ tags = [t.name for t in refs.tags] if hasattr(refs, 'tags') and refs.tags else []
228
+
229
+ url = _build_repo_url(repo_id, repo_type)
230
+ lines = [f"**{repo_id}**", url, ""]
231
+
232
+ if branches:
233
+ lines.append(f"**Branches ({len(branches)}):** " + ", ".join(branches))
234
+ else:
235
+ lines.append("**Branches:** none")
236
+
237
+ if tags:
238
+ lines.append(f"**Tags ({len(tags)}):** " + ", ".join(tags))
239
+ else:
240
+ lines.append("**Tags:** none")
241
+
242
+ return {"formatted": "\n".join(lines), "totalResults": len(branches) + len(tags), "resultsShared": len(branches) + len(tags)}
243
+
244
+ # =========================================================================
245
+ # PR OPERATIONS
246
+ # =========================================================================
247
+
248
+ async def _create_pr(self, args: Dict[str, Any]) -> ToolResult:
249
+ """Create a pull request."""
250
+ repo_id = args.get("repo_id")
251
+ title = args.get("title")
252
+
253
+ if not repo_id:
254
+ return self._error("repo_id is required")
255
+ if not title:
256
+ return self._error("title is required")
257
+
258
+ repo_type = args.get("repo_type", "model")
259
+ description = args.get("description", "")
260
+
261
+ result = await _async_call(
262
+ self.api.create_pull_request,
263
+ repo_id=repo_id,
264
+ title=title,
265
+ description=description,
266
+ repo_type=repo_type,
267
+ )
268
+
269
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}"
270
+ return {
271
+ "formatted": f"**PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision=\"refs/pr/{result.num}\"",
272
+ "totalResults": 1,
273
+ "resultsShared": 1,
274
+ }
275
+
276
+ async def _list_prs(self, args: Dict[str, Any]) -> ToolResult:
277
+ """List PRs and discussions."""
278
+ repo_id = args.get("repo_id")
279
+
280
+ if not repo_id:
281
+ return self._error("repo_id is required")
282
+
283
+ repo_type = args.get("repo_type", "model")
284
+ status = args.get("status", "all") # open, closed, all
285
+
286
+ discussions = list(self.api.get_repo_discussions(
287
+ repo_id=repo_id,
288
+ repo_type=repo_type,
289
+ discussion_status=status if status != "all" else None,
290
+ ))
291
+
292
+ if not discussions:
293
+ return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0}
294
+
295
+ url = _build_repo_url(repo_id, repo_type)
296
+ lines = [f"**{repo_id}** - {len(discussions)} discussions", f"{url}/discussions", ""]
297
+
298
+ for d in discussions[:20]:
299
+ emoji = "🟢" if d.status == "open" else "🔴"
300
+ type_label = "PR" if d.is_pull_request else "D"
301
+ lines.append(f"{emoji} #{d.num} [{type_label}] {d.title}")
302
+
303
+ return {"formatted": "\n".join(lines), "totalResults": len(discussions), "resultsShared": min(20, len(discussions))}
304
+
305
+ async def _get_pr(self, args: Dict[str, Any]) -> ToolResult:
306
+ """Get PR details."""
307
+ repo_id = args.get("repo_id")
308
+ pr_num = args.get("pr_num")
309
+
310
+ if not repo_id:
311
+ return self._error("repo_id is required")
312
+ if not pr_num:
313
+ return self._error("pr_num is required")
314
+
315
+ repo_type = args.get("repo_type", "model")
316
+
317
+ pr = await _async_call(
318
+ self.api.get_discussion_details,
319
+ repo_id=repo_id,
320
+ discussion_num=int(pr_num),
321
+ repo_type=repo_type,
322
+ )
323
+
324
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
325
+ status = "🟢 Open" if pr.status == "open" else "🔴 Closed"
326
+ type_label = "Pull Request" if pr.is_pull_request else "Discussion"
327
+
328
+ lines = [
329
+ f"**{type_label} #{pr_num}:** {pr.title}",
330
+ f"**Status:** {status}",
331
+ f"**Author:** {pr.author}",
332
+ url,
333
+ ]
334
+
335
+ if pr.is_pull_request:
336
+ lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"")
337
+
338
+ return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1}
339
+
340
+ async def _merge_pr(self, args: Dict[str, Any]) -> ToolResult:
341
+ """Merge a pull request."""
342
+ repo_id = args.get("repo_id")
343
+ pr_num = args.get("pr_num")
344
+
345
+ if not repo_id:
346
+ return self._error("repo_id is required")
347
+ if not pr_num:
348
+ return self._error("pr_num is required")
349
+
350
+ repo_type = args.get("repo_type", "model")
351
+ comment = args.get("comment", "")
352
+
353
+ await _async_call(
354
+ self.api.merge_pull_request,
355
+ repo_id=repo_id,
356
+ discussion_num=int(pr_num),
357
+ comment=comment,
358
+ repo_type=repo_type,
359
+ )
360
+
361
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
362
+ return {"formatted": f"**PR #{pr_num} merged**\n{url}", "totalResults": 1, "resultsShared": 1}
363
+
364
+ async def _close_pr(self, args: Dict[str, Any]) -> ToolResult:
365
+ """Close a PR/discussion."""
366
+ repo_id = args.get("repo_id")
367
+ pr_num = args.get("pr_num")
368
+
369
+ if not repo_id:
370
+ return self._error("repo_id is required")
371
+ if not pr_num:
372
+ return self._error("pr_num is required")
373
+
374
+ repo_type = args.get("repo_type", "model")
375
+ comment = args.get("comment", "")
376
+
377
+ await _async_call(
378
+ self.api.change_discussion_status,
379
+ repo_id=repo_id,
380
+ discussion_num=int(pr_num),
381
+ new_status="closed",
382
+ comment=comment,
383
+ repo_type=repo_type,
384
+ )
385
+
386
+ return {"formatted": f"**Discussion #{pr_num} closed**", "totalResults": 1, "resultsShared": 1}
387
+
388
+ async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult:
389
+ """Add a comment to a PR/discussion."""
390
+ repo_id = args.get("repo_id")
391
+ pr_num = args.get("pr_num")
392
+ comment = args.get("comment")
393
+
394
+ if not repo_id:
395
+ return self._error("repo_id is required")
396
+ if not pr_num:
397
+ return self._error("pr_num is required")
398
+ if not comment:
399
+ return self._error("comment is required")
400
+
401
+ repo_type = args.get("repo_type", "model")
402
+
403
+ await _async_call(
404
+ self.api.comment_discussion,
405
+ repo_id=repo_id,
406
+ discussion_num=int(pr_num),
407
+ comment=comment,
408
+ repo_type=repo_type,
409
+ )
410
+
411
+ url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
412
+ return {"formatted": f"**Comment added to #{pr_num}**\n{url}", "totalResults": 1, "resultsShared": 1}
413
+
414
+ # =========================================================================
415
+ # REPO MANAGEMENT
416
+ # =========================================================================
417
+
418
async def _create_repo(self, args: Dict[str, Any]) -> ToolResult:
    """Create a new repository through ``self.api.create_repo``.

    ``repo_id`` is required. ``repo_type`` defaults to ``"model"``,
    ``private`` defaults to ``True`` and ``exist_ok`` defaults to ``False``.
    When ``repo_type`` is ``"space"`` a ``space_sdk`` must also be given;
    it is only forwarded to the API when truthy.

    Returns an error result on missing arguments, otherwise a one-item
    result whose text includes the repo id, the privacy flag and whatever
    the API call returned.
    """
    name = args.get("repo_id")
    if not name:
        return self._error("repo_id is required")

    kind = args.get("repo_type", "model")
    is_private = args.get("private", True)
    sdk = args.get("space_sdk")

    missing_sdk = kind == "space" and not sdk
    if missing_sdk:
        return self._error("space_sdk required for spaces (gradio/streamlit/docker/static)")

    # space_sdk is optional for the API, so only pass it when provided.
    sdk_option = {"space_sdk": sdk} if sdk else {}
    created = await _async_call(
        self.api.create_repo,
        repo_id=name,
        repo_type=kind,
        private=is_private,
        exist_ok=args.get("exist_ok", False),
        **sdk_option,
    )

    report = f"**Repository created:** {name}\n**Private:** {is_private}\n{created}"
    return {"formatted": report, "totalResults": 1, "resultsShared": 1}
448
+
449
async def _update_repo(self, args: Dict[str, Any]) -> ToolResult:
    """Update repository settings (visibility and/or gated access).

    Args:
        args: Tool arguments. ``repo_id`` is required; ``repo_type``
            defaults to ``"model"``. At least one of ``private`` (bool) or
            ``gated`` (``"auto"``, ``"manual"``, or ``False``/``"false"``)
            must be supplied.

    Returns:
        An error result when ``repo_id`` is missing or no setting was
        given; otherwise a one-item result listing the applied changes and
        the repo's settings URL.
    """
    repo_id = args.get("repo_id")
    if not repo_id:
        return self._error("repo_id is required")

    repo_type = args.get("repo_type", "model")
    private = args.get("private")
    gated = args.get("gated")

    # The tool schema can only express "disable gating" as the string
    # "false", while HfApi.update_repo_settings expects the boolean False.
    # Normalize here so that option is actually usable.
    if isinstance(gated, str) and gated.strip().lower() == "false":
        gated = False

    if private is None and gated is None:
        return self._error("Specify private (bool) or gated ('auto'/'manual'/false)")

    # Only forward the settings the caller actually asked to change.
    settings: Dict[str, Any] = {}
    if private is not None:
        settings["private"] = private
    if gated is not None:
        settings["gated"] = gated

    await _async_call(
        self.api.update_repo_settings,
        repo_id=repo_id,
        repo_type=repo_type,
        **settings,
    )

    changes = [f"{key}={value}" for key, value in settings.items()]
    url = f"{_build_repo_url(repo_id, repo_type)}/settings"
    return {
        "formatted": f"**Settings updated:** {', '.join(changes)}\n{url}",
        "totalResults": 1,
        "resultsShared": 1,
    }
479
+
480
def _error(self, message: str) -> ToolResult:
    """Wrap ``message`` in a failed result: zero counts and ``isError`` set."""
    failure = dict(formatted=message, totalResults=0, resultsShared=0)
    failure["isError"] = True
    return failure
483
+
484
+
485
+ # Tool specification
486
# Tool specification exposed to the agent: the tool name, a usage guide shown
# to the model, and the JSON-schema for the arguments each operation reads.
HF_REPO_GIT_TOOL_SPEC = {
    "name": "hf_repo_git",
    "description": (
        "Git-like operations on HF repos: branches, tags, PRs, and repo management.\n\n"
        "## Operations\n"
        "**Branches:** create_branch, delete_branch, list_refs\n"
        "**Tags:** create_tag, delete_tag\n"
        "**PRs:** create_pr, list_prs, get_pr, merge_pr, close_pr, comment_pr\n"
        "**Repo:** create_repo, update_repo\n\n"
        "## Use when\n"
        "- Creating feature branches for experiments\n"
        "- Tagging model versions (v1.0, v2.0)\n"
        "- Opening PRs to contribute to repos you don't own\n"
        "- Reviewing and merging PRs on your repos\n"
        "- Creating new model/dataset/space repos\n"
        "- Changing repo visibility (public/private) or gated access\n\n"
        "## Examples\n"
        '{"operation": "list_refs", "repo_id": "my-model"}\n'
        '{"operation": "create_branch", "repo_id": "my-model", "branch": "experiment-v2"}\n'
        '{"operation": "create_tag", "repo_id": "my-model", "tag": "v1.0", "revision": "main"}\n'
        '{"operation": "create_pr", "repo_id": "org/model", "title": "Fix tokenizer config"}\n'
        '{"operation": "merge_pr", "repo_id": "my-model", "pr_num": 3}\n'
        '{"operation": "create_repo", "repo_id": "my-new-model", "private": true}\n'
        '{"operation": "update_repo", "repo_id": "my-model", "gated": "auto"}\n\n'
        "## PR Workflow\n"
        "1. create_pr → creates empty draft PR\n"
        "2. Upload files with revision='refs/pr/N' to add commits\n"
        "3. merge_pr when ready\n\n"
        "## Notes\n"
        "- delete_branch, delete_tag, merge_pr, create_repo, update_repo require approval\n"
        "- For spaces, create_repo needs space_sdk (gradio/streamlit/docker/static)\n"
        "- gated options: 'auto' (instant), 'manual' (review), false (open)\n"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": [
                    "create_branch", "delete_branch",
                    "create_tag", "delete_tag", "list_refs",
                    "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr",
                    "create_repo", "update_repo",
                ],
                "description": "Operation to execute",
            },
            "repo_id": {
                "type": "string",
                "description": "Repository ID (e.g., 'username/repo-name')",
            },
            "repo_type": {
                "type": "string",
                "enum": ["model", "dataset", "space"],
                "description": "Repository type (default: model)",
            },
            "branch": {
                "type": "string",
                "description": "Branch name (create_branch, delete_branch)",
            },
            "from_rev": {
                "type": "string",
                "description": "Create branch from this revision (default: main)",
            },
            "tag": {
                "type": "string",
                "description": "Tag name (create_tag, delete_tag)",
            },
            "revision": {
                "type": "string",
                "description": "Revision for tag (default: main)",
            },
            "tag_message": {
                "type": "string",
                "description": "Tag description",
            },
            "title": {
                "type": "string",
                "description": "PR title (create_pr)",
            },
            "description": {
                "type": "string",
                "description": "PR description (create_pr)",
            },
            "pr_num": {
                "type": "integer",
                "description": "PR/discussion number",
            },
            "comment": {
                "type": "string",
                "description": "Comment text",
            },
            "status": {
                "type": "string",
                "enum": ["open", "closed", "all"],
                "description": "Filter PRs by status (list_prs)",
            },
            "private": {
                "type": "boolean",
                "description": "Make repo private (create_repo, update_repo)",
            },
            # create_repo already reads this argument (default False); it is
            # declared here so callers can discover and set it.
            "exist_ok": {
                "type": "boolean",
                "description": "Do not fail if the repo already exists (create_repo, default: false)",
            },
            # JSON-schema string enums cannot carry a boolean, so "disable
            # gating" is spelled as the string "false".
            "gated": {
                "type": "string",
                "enum": ["auto", "manual", "false"],
                "description": "Gated access setting (update_repo); 'false' disables gating",
            },
            "space_sdk": {
                "type": "string",
                "enum": ["gradio", "streamlit", "docker", "static"],
                "description": "Space SDK (required for create_repo with space)",
            },
        },
        "required": ["operation"],
    },
}
600
+
601
+
602
async def hf_repo_git_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
    """Adapt ``HfRepoGitTool`` to the agent tool router.

    Runs the tool with ``arguments`` and returns ``(text, success)``, where
    ``success`` is False when the result is flagged ``isError``. Any
    exception raised while building or running the tool is reported as
    ``("Error: <message>", False)`` instead of propagating.
    """
    try:
        outcome = await HfRepoGitTool().execute(arguments)
        text = outcome["formatted"]
        failed = outcome.get("isError", False)
        return text, not failed
    except Exception as exc:
        # Router boundary: surface the failure as tool output.
        return f"Error: {exc}", False
test_dataset_tools.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Test script for unified dataset inspection tool
3
  """
4
 
5
  import asyncio
@@ -8,7 +8,7 @@ from typing import TypedDict
8
  from unittest.mock import MagicMock
9
 
10
 
11
- # Mock the types module before importing dataset_tools
12
  class ToolResult(TypedDict, total=False):
13
  formatted: str
14
  totalResults: int
@@ -20,60 +20,70 @@ mock_types = MagicMock()
20
  mock_types.ToolResult = ToolResult
21
  sys.modules["agent.tools.types"] = mock_types
22
 
23
- # Now import directly from the file
24
- sys.path.insert(0, "/Users/akseljoonas/Documents/hf-agent/agent/tools")
25
- from dataset_tools import hf_inspect_dataset_handler, inspect_dataset
26
 
27
 
28
- async def test_inspect_dataset():
29
- """Test the unified inspect_dataset function"""
30
- print("=" * 70)
31
- print("Testing inspect_dataset()")
32
- print("=" * 70)
33
 
34
- # Test with akseljoonas/hf-agent-sessions as specified
35
- print("\n→ inspect_dataset('akseljoonas/hf-agent-sessions'):")
36
- result = await inspect_dataset("akseljoonas/hf-agent-sessions")
37
- print(f" isError: {result['isError']}")
38
- print(f" Output:\n{result['formatted']}")
39
 
40
- print("\n" + "=" * 70)
41
-
42
- # # Test with stanfordnlp/imdb
43
- # print("\n→ inspect_dataset('stanfordnlp/imdb'):")
44
- # result = await inspect_dataset("stanfordnlp/imdb")
45
- # print(f" isError: {result['isError']}")
46
- # print(f" Output:\n{result['formatted']}")
47
-
48
- # print("\n" + "=" * 70)
 
 
 
 
 
 
 
 
 
 
49
 
50
- # # Test with multi-config dataset
51
- # print("\n→ inspect_dataset('nyu-mll/glue', config='mrpc'):")
52
- # result = await inspect_dataset("nyu-mll/glue", config="mrpc")
53
- # print(f" isError: {result['isError']}")
54
- # print(f" Output:\n{result['formatted']}")
55
 
 
 
 
 
 
56
 
57
- async def test_handler():
58
- """Test the handler (what the agent calls)"""
59
- print("\n" + "=" * 70)
60
- print("Testing hf_inspect_dataset_handler()")
61
- print("=" * 70)
62
 
63
- result, success = await hf_inspect_dataset_handler(
64
- {
65
- "dataset": "stanfordnlp/imdb",
66
- "sample_rows": 2,
67
- }
 
 
 
 
68
  )
69
- print("\n→ Handler result:")
70
- print(f" success: {success}")
71
- print(f" output:\n{result}")
 
 
 
 
72
 
73
 
74
  if __name__ == "__main__":
75
- print("\nUnified Dataset Inspection Tool Test\n")
76
- asyncio.run(test_inspect_dataset())
77
- # asyncio.run(test_handler())
78
- print("\n" + "=" * 70)
79
- print("Done!")
 
 
1
  """
2
+ Test script for hf_repo_files and hf_repo_git tools
3
  """
4
 
5
  import asyncio
 
8
  from unittest.mock import MagicMock
9
 
10
 
11
+ # Mock the types module before importing
12
  class ToolResult(TypedDict, total=False):
13
  formatted: str
14
  totalResults: int
 
20
  mock_types.ToolResult = ToolResult
21
  sys.modules["agent.tools.types"] = mock_types
22
 
23
+ from agent.tools.hf_repo_files_tool import HfRepoFilesTool
24
+ from agent.tools.hf_repo_git_tool import HfRepoGitTool
 
25
 
26
 
27
async def test_hf_repo_files():
    """Smoke-test the hf_repo_files tool's list and read operations.

    Runs both operations against ``openai-community/gpt2`` and prints the
    error flag plus a short preview of the formatted output; nothing is
    asserted.
    """
    print("=" * 60)
    print("Testing hf_repo_files")
    print("=" * 60)

    tool = HfRepoFilesTool()

    # Test list
    print("\n→ list files in gpt2:")
    result = await tool.execute(
        {"operation": "list", "repo_id": "openai-community/gpt2"}
    )
    print(f" isError: {result.get('isError', False)}")
    print(f" totalResults: {result['totalResults']}")
    # Only preview the first lines, as the label below promises.
    lines = result["formatted"].split("\n")[:5]
    print(" Output (first 5 lines):\n" + "\n".join(f" {line}" for line in lines))

    # Test read
    print("\n→ read config.json from gpt2:")
    result = await tool.execute(
        {"operation": "read", "repo_id": "openai-community/gpt2", "path": "config.json"}
    )
    print(f" isError: {result.get('isError', False)}")
    # Same here: cap the preview at the advertised ten lines.
    lines = result["formatted"].split("\n")[:10]
    print(" Output (first 10 lines):\n" + "\n".join(f" {line}" for line in lines))
54
 
 
 
 
 
 
55
 
56
async def test_hf_repo_git():
    """Smoke-test the hf_repo_git tool: list_refs, then a call with no operation.

    Prints the error flag and the formatted output of each call; nothing is
    asserted.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Testing hf_repo_git")
    print(rule)

    git_tool = HfRepoGitTool()

    # Test list_refs
    print("\n→ list_refs for gpt2:")
    refs_result = await git_tool.execute(
        {"operation": "list_refs", "repo_id": "openai-community/gpt2"}
    )
    print(f" isError: {refs_result.get('isError', False)}")
    indented_refs = [f" {line}" for line in refs_result["formatted"].split("\n")]
    print(" Output:\n" + "\n".join(indented_refs))

    # Test help (no operation)
    print("\n→ help (no operation):")
    help_result = await git_tool.execute({})
    print(f" isError: {help_result.get('isError', False)}")
    preview = help_result["formatted"].split("\n")[:6]
    indented_preview = [f" {line}" for line in preview]
    print(" Output (first 6 lines):\n" + "\n".join(indented_preview))
81
 
82
 
83
if __name__ == "__main__":
    # Script entry point: run each smoke test in its own event loop, in
    # order, between a header and a closing banner.
    print("\nHF Repo Tools Test\n")
    for scenario in (test_hf_repo_files, test_hf_repo_git):
        asyncio.run(scenario())
    closing_rule = "=" * 60
    print("\n" + closing_rule)
    print("Tests complete!")
    print(closing_rule)