Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Commit ·
bea6687
1
Parent(s): 2a25e6b
Add hf_repo_files and hf_repo_git tools
Browse files
New tools to replace private_hf_repo_tools with comprehensive HF repo management:
hf_repo_files (4 ops):
- list: List files with sizes
- read: Read file content
- upload: Upload content (approval required)
- delete: Delete files with wildcards (approval required)
hf_repo_git (13 ops):
- Branches: create_branch, delete_branch, list_refs
- Tags: create_tag, delete_tag
- PRs: create_pr, list_prs, get_pr, merge_pr, close_pr, comment_pr
- Repo: create_repo, update_repo
Approval required for destructive/overwriting operations:
- hf_repo_files: upload, delete
- hf_repo_git: delete_branch, delete_tag, merge_pr, create_repo, update_repo
- agent/core/agent_loop.py +13 -1
- agent/core/tools.py +22 -8
- agent/tools/hf_repo_files_tool.py +322 -0
- agent/tools/hf_repo_git_tool.py +609 -0
- test_dataset_tools.py +57 -47
agent/core/agent_loop.py
CHANGED
|
@@ -76,7 +76,19 @@ def _needs_approval(tool_name: str, tool_args: dict, config: Config | None = Non
|
|
| 76 |
# Other operations (create_repo, etc.) always require approval
|
| 77 |
if operation in ["create_repo"]:
|
| 78 |
return True
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
return False
|
| 81 |
|
| 82 |
|
|
|
|
| 76 |
# Other operations (create_repo, etc.) always require approval
|
| 77 |
if operation in ["create_repo"]:
|
| 78 |
return True
|
| 79 |
+
|
| 80 |
+
# hf_repo_files: upload (can overwrite) and delete require approval
|
| 81 |
+
if tool_name == "hf_repo_files":
|
| 82 |
+
operation = tool_args.get("operation", "")
|
| 83 |
+
if operation in ["upload", "delete"]:
|
| 84 |
+
return True
|
| 85 |
+
|
| 86 |
+
# hf_repo_git: destructive operations require approval
|
| 87 |
+
if tool_name == "hf_repo_git":
|
| 88 |
+
operation = tool_args.get("operation", "")
|
| 89 |
+
if operation in ["delete_branch", "delete_tag", "merge_pr", "create_repo", "update_repo"]:
|
| 90 |
+
return True
|
| 91 |
+
|
| 92 |
return False
|
| 93 |
|
| 94 |
|
agent/core/tools.py
CHANGED
|
@@ -37,8 +37,16 @@ from agent.tools.github_read_file import (
|
|
| 37 |
)
|
| 38 |
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
|
| 39 |
from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
# NOTE: Private HF repo tool disabled -
|
| 42 |
# from agent.tools.private_hf_repo_tools import (
|
| 43 |
# PRIVATE_HF_REPO_TOOL_SPEC,
|
| 44 |
# private_hf_repo_handler,
|
|
@@ -280,13 +288,19 @@ def create_builtin_tools() -> list[ToolSpec]:
|
|
| 280 |
parameters=HF_JOBS_TOOL_SPEC["parameters"],
|
| 281 |
handler=hf_jobs_handler,
|
| 282 |
),
|
| 283 |
-
#
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
|
| 292 |
# NOTE: Github search code tool disabled - a bit buggy
|
|
|
|
| 37 |
)
|
| 38 |
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
|
| 39 |
from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
|
| 40 |
+
from agent.tools.hf_repo_files_tool import (
|
| 41 |
+
HF_REPO_FILES_TOOL_SPEC,
|
| 42 |
+
hf_repo_files_handler,
|
| 43 |
+
)
|
| 44 |
+
from agent.tools.hf_repo_git_tool import (
|
| 45 |
+
HF_REPO_GIT_TOOL_SPEC,
|
| 46 |
+
hf_repo_git_handler,
|
| 47 |
+
)
|
| 48 |
|
| 49 |
+
# NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git
|
| 50 |
# from agent.tools.private_hf_repo_tools import (
|
| 51 |
# PRIVATE_HF_REPO_TOOL_SPEC,
|
| 52 |
# private_hf_repo_handler,
|
|
|
|
| 288 |
parameters=HF_JOBS_TOOL_SPEC["parameters"],
|
| 289 |
handler=hf_jobs_handler,
|
| 290 |
),
|
| 291 |
+
# HF Repo management tools
|
| 292 |
+
ToolSpec(
|
| 293 |
+
name=HF_REPO_FILES_TOOL_SPEC["name"],
|
| 294 |
+
description=HF_REPO_FILES_TOOL_SPEC["description"],
|
| 295 |
+
parameters=HF_REPO_FILES_TOOL_SPEC["parameters"],
|
| 296 |
+
handler=hf_repo_files_handler,
|
| 297 |
+
),
|
| 298 |
+
ToolSpec(
|
| 299 |
+
name=HF_REPO_GIT_TOOL_SPEC["name"],
|
| 300 |
+
description=HF_REPO_GIT_TOOL_SPEC["description"],
|
| 301 |
+
parameters=HF_REPO_GIT_TOOL_SPEC["parameters"],
|
| 302 |
+
handler=hf_repo_git_handler,
|
| 303 |
+
),
|
| 304 |
|
| 305 |
|
| 306 |
# NOTE: Github search code tool disabled - a bit buggy
|
agent/tools/hf_repo_files_tool.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HF Repo Files Tool - File operations on Hugging Face repositories
|
| 3 |
+
|
| 4 |
+
Operations: list, read, upload, delete
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
from typing import Any, Dict, Literal, Optional
|
| 9 |
+
|
| 10 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 11 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
| 12 |
+
|
| 13 |
+
from agent.tools.types import ToolResult
|
| 14 |
+
|
| 15 |
+
# Names of the operations the hf_repo_files tool dispatches on.
OperationType = Literal["list", "read", "upload", "delete"]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
async def _async_call(func, *args, **kwargs):
|
| 19 |
+
"""Wrap synchronous HfApi calls for async context."""
|
| 20 |
+
return await asyncio.to_thread(func, *args, **kwargs)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
|
| 24 |
+
"""Build the Hub URL for a repository."""
|
| 25 |
+
if repo_type == "model":
|
| 26 |
+
return f"https://huggingface.co/{repo_id}"
|
| 27 |
+
return f"https://huggingface.co/{repo_type}s/{repo_id}"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _format_size(size_bytes: int) -> str:
|
| 31 |
+
"""Format file size in human-readable form."""
|
| 32 |
+
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
| 33 |
+
if size_bytes < 1024:
|
| 34 |
+
return f"{size_bytes:.1f}{unit}"
|
| 35 |
+
size_bytes /= 1024
|
| 36 |
+
return f"{size_bytes:.1f}PB"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class HfRepoFilesTool:
    """Tool for file operations (list, read, upload, delete) on HF repos.

    Every method returns a result dict with ``formatted``, ``totalResults``
    and ``resultsShared`` keys; error results additionally carry
    ``isError: True``.
    """

    def __init__(self, hf_token: Optional[str] = None):
        """Create the Hub client; ``hf_token`` is passed straight to ``HfApi``."""
        self.api = HfApi(token=hf_token)

    async def execute(self, args: Dict[str, Any]) -> ToolResult:
        """Dispatch ``args["operation"]`` to the matching handler.

        Returns the help text when no operation is given and an error result
        for an unknown operation. Exceptions raised by a handler are converted
        into error results instead of propagating.
        """
        operation = args.get("operation")

        if not operation:
            return self._help()

        try:
            handlers = {
                "list": self._list,
                "read": self._read,
                "upload": self._upload,
                "delete": self._delete,
            }

            handler = handlers.get(operation)
            if handler:
                return await handler(args)
            else:
                return self._error(f"Unknown operation: {operation}. Valid: list, read, upload, delete")

        except RepositoryNotFoundError:
            return self._error(f"Repository not found: {args.get('repo_id')}")
        except EntryNotFoundError:
            return self._error(f"File not found: {args.get('path')}")
        except Exception as e:
            # Last-resort boundary: surface any other failure as an error result.
            return self._error(f"Error: {str(e)}")

    def _help(self) -> ToolResult:
        """Return usage instructions (shown when no operation is supplied)."""
        return {
            "formatted": """**hf_repo_files** - File operations on HF repos

**Operations:**
- `list` - List files: `{"operation": "list", "repo_id": "gpt2"}`
- `read` - Read file: `{"operation": "read", "repo_id": "gpt2", "path": "config.json"}`
- `upload` - Upload: `{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "..."}`
- `delete` - Delete: `{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp"]}`

**Common params:** repo_id (required), repo_type (model/dataset/space), revision (default: main)""",
            "totalResults": 1,
            "resultsShared": 1,
        }

    async def _list(self, args: Dict[str, Any]) -> ToolResult:
        """List the entries of a repository (recursively) with file sizes.

        Reads ``repo_id`` (required), ``repo_type`` (default ``model``),
        ``revision`` (default ``main``) and ``path`` (default: repo root).
        """
        repo_id = args.get("repo_id")
        if not repo_id:
            return self._error("repo_id is required")

        repo_type = args.get("repo_type", "model")
        revision = args.get("revision", "main")
        path = args.get("path", "")

        def _collect_tree():
            # list_repo_tree returns an iterable that may fetch pages lazily;
            # drain it here, inside the worker thread, so that no network I/O
            # ends up running on the event loop.
            return list(
                self.api.list_repo_tree(
                    repo_id=repo_id,
                    repo_type=repo_type,
                    revision=revision,
                    path_in_repo=path,
                    recursive=True,
                )
            )

        items = await _async_call(_collect_tree)

        if not items:
            return {"formatted": f"No files in {repo_id}", "totalResults": 0, "resultsShared": 0}

        lines = []
        total_size = 0
        file_count = 0
        for item in sorted(items, key=lambda entry: entry.path):
            # Entries without a size are treated as folders. Test against
            # None (not truthiness) so zero-byte files are still listed as files.
            size = getattr(item, "size", None)
            if size is None:
                lines.append(f"{item.path}/")
                continue
            file_count += 1
            total_size += size
            lines.append(f"{item.path} ({_format_size(size)})")

        url = _build_repo_url(repo_id, repo_type)
        # The header counts files only; totalResults still counts every entry.
        response = (
            f"**{repo_id}** ({file_count} files, {_format_size(total_size)})\n"
            f"{url}/tree/{revision}\n\n" + "\n".join(lines)
        )

        return {"formatted": response, "totalResults": len(items), "resultsShared": len(items)}

    async def _read(self, args: Dict[str, Any]) -> ToolResult:
        """Download a file and return its text content.

        Reads ``repo_id`` and ``path`` (both required), ``repo_type``,
        ``revision`` and ``max_chars`` (default 50000). Content longer than
        ``max_chars`` is cut and flagged as truncated; files that are not
        valid UTF-8 are reported as binary with their size.
        """
        repo_id = args.get("repo_id")
        path = args.get("path")

        if not repo_id:
            return self._error("repo_id is required")
        if not path:
            return self._error("path is required")

        repo_type = args.get("repo_type", "model")
        revision = args.get("revision", "main")
        max_chars = args.get("max_chars")
        if max_chars is None:
            # An explicit null would otherwise break the length comparison below.
            max_chars = 50000

        file_path = await _async_call(
            hf_hub_download,
            repo_id=repo_id,
            filename=path,
            repo_type=repo_type,
            revision=revision,
            token=self.api.token,
        )

        try:
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()

            truncated = len(content) > max_chars
            if truncated:
                content = content[:max_chars]

            url = f"{_build_repo_url(repo_id, repo_type)}/blob/{revision}/{path}"
            response = f"**{path}**{' (truncated)' if truncated else ''}\n{url}\n\n```\n{content}\n```"

            return {"formatted": response, "totalResults": 1, "resultsShared": 1}

        except UnicodeDecodeError:
            import os

            size = os.path.getsize(file_path)
            return {"formatted": f"Binary file ({_format_size(size)})", "totalResults": 1, "resultsShared": 1}

    async def _upload(self, args: Dict[str, Any]) -> ToolResult:
        """Upload ``content`` to ``path`` in a repository.

        Reads ``repo_id``, ``path`` and ``content`` (all required), plus
        ``repo_type``, ``revision``, ``create_pr`` and ``commit_message``.
        String content is encoded as UTF-8 before upload.
        """
        repo_id = args.get("repo_id")
        path = args.get("path")
        content = args.get("content")

        if not repo_id:
            return self._error("repo_id is required")
        if not path:
            return self._error("path is required")
        if content is None:
            # Empty string is valid content; only a missing value is rejected.
            return self._error("content is required")

        repo_type = args.get("repo_type", "model")
        revision = args.get("revision", "main")
        create_pr = args.get("create_pr", False)
        commit_message = args.get("commit_message", f"Upload {path}")

        file_bytes = content.encode("utf-8") if isinstance(content, str) else content

        result = await _async_call(
            self.api.upload_file,
            path_or_fileobj=file_bytes,
            path_in_repo=path,
            repo_id=repo_id,
            repo_type=repo_type,
            revision=revision,
            commit_message=commit_message,
            create_pr=create_pr,
        )

        url = _build_repo_url(repo_id, repo_type)
        if create_pr and hasattr(result, "pr_url"):
            response = f"**Uploaded as PR**\n{result.pr_url}"
        else:
            response = f"**Uploaded:** {path}\n{url}/blob/{revision}/{path}"

        return {"formatted": response, "totalResults": 1, "resultsShared": 1}

    async def _delete(self, args: Dict[str, Any]) -> ToolResult:
        """Delete files matching ``patterns`` from a repository.

        Reads ``repo_id`` and ``patterns`` (both required; a single string is
        wrapped into a list), plus ``repo_type``, ``revision``, ``create_pr``
        and ``commit_message``.
        """
        repo_id = args.get("repo_id")
        patterns = args.get("patterns")

        if not repo_id:
            return self._error("repo_id is required")
        if not patterns:
            return self._error("patterns is required (list of paths/wildcards)")

        if isinstance(patterns, str):
            patterns = [patterns]

        repo_type = args.get("repo_type", "model")
        revision = args.get("revision", "main")
        create_pr = args.get("create_pr", False)
        commit_message = args.get("commit_message", f"Delete {', '.join(patterns)}")

        await _async_call(
            self.api.delete_files,
            repo_id=repo_id,
            delete_patterns=patterns,
            repo_type=repo_type,
            revision=revision,
            commit_message=commit_message,
            create_pr=create_pr,
        )

        response = f"**Deleted:** {', '.join(patterns)} from {repo_id}"
        return {"formatted": response, "totalResults": 1, "resultsShared": 1}

    def _error(self, message: str) -> ToolResult:
        """Build an error result carrying ``message`` and ``isError: True``."""
        return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# Tool specification
|
| 242 |
+
# Tool specification: name, model-facing description and JSON-schema parameters
# for the hf_repo_files tool.
HF_REPO_FILES_TOOL_SPEC = {
    "name": "hf_repo_files",
    "description": (
        "Read and write files in HF repos (models/datasets/spaces).\n\n"
        "## Operations\n"
        "- **list**: List files with sizes and structure\n"
        "- **read**: Read file content (text files only)\n"
        "- **upload**: Upload content to repo (can create PR)\n"
        "- **delete**: Delete files/folders (supports wildcards like *.tmp)\n\n"
        "## Use when\n"
        "- Need to see what files exist in a repo\n"
        "- Want to read config.json, README.md, or other text files\n"
        "- Uploading training scripts, configs, or results to a repo\n"
        "- Cleaning up temporary files from a repo\n\n"
        "## Examples\n"
        '{"operation": "list", "repo_id": "meta-llama/Llama-2-7b"}\n'
        '{"operation": "read", "repo_id": "gpt2", "path": "config.json"}\n'
        '{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "# My Model"}\n'
        '{"operation": "upload", "repo_id": "org/model", "path": "fix.py", "content": "...", "create_pr": true}\n'
        '{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp", "logs/"]}\n\n'
        "## Notes\n"
        "- For binary files (safetensors, bin), use list to see them but can't read content\n"
        "- upload/delete require approval (can overwrite/destroy data)\n"
        "- Use create_pr=true to propose changes instead of direct commit\n"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": ["list", "read", "upload", "delete"],
                "description": "Operation: list, read, upload, delete",
            },
            "repo_id": {
                "type": "string",
                "description": "Repository ID (e.g., 'username/repo-name')",
            },
            "repo_type": {
                "type": "string",
                "enum": ["model", "dataset", "space"],
                "description": "Repository type (default: model)",
            },
            "revision": {
                "type": "string",
                "description": "Branch/tag/commit (default: main)",
            },
            "path": {
                "type": "string",
                # The list operation also reads `path`, as the folder to list.
                "description": "File path for read/upload; optional folder to list (default: repo root)",
            },
            "content": {
                "type": "string",
                "description": "File content for upload",
            },
            "patterns": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Patterns to delete (e.g., ['*.tmp', 'logs/'])",
            },
            "create_pr": {
                "type": "boolean",
                "description": "Create PR instead of direct commit",
            },
            "commit_message": {
                "type": "string",
                "description": "Custom commit message",
            },
            # The read operation consumes max_chars, so it is declared here to
            # make it visible to callers of the tool.
            "max_chars": {
                "type": "integer",
                "description": "Max characters returned by read before truncation (default: 50000)",
            },
        },
        "required": ["operation"],
    },
}
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
async def hf_repo_files_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
    """Entry point used by the agent tool router.

    Builds a fresh ``HfRepoFilesTool``, runs the requested operation and
    returns ``(formatted_text, success)``. ``success`` is False when the
    result is flagged ``isError`` or when anything raises.
    """
    try:
        outcome = await HfRepoFilesTool().execute(arguments)
        text = outcome["formatted"]
        failed = outcome.get("isError", False)
        return text, not failed
    except Exception as exc:
        return f"Error: {str(exc)}", False
|
agent/tools/hf_repo_git_tool.py
ADDED
|
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HF Repo Git Tool - Git-like operations on Hugging Face repositories
|
| 3 |
+
|
| 4 |
+
Operations: branches, tags, PRs, repo management
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
from typing import Any, Dict, Literal, Optional
|
| 9 |
+
|
| 10 |
+
from huggingface_hub import HfApi
|
| 11 |
+
from huggingface_hub.utils import RepositoryNotFoundError
|
| 12 |
+
|
| 13 |
+
from agent.tools.types import ToolResult
|
| 14 |
+
|
| 15 |
+
# Names of the operations the hf_repo_git tool dispatches on.
OperationType = Literal[
    # branches
    "create_branch",
    "delete_branch",
    # tags
    "create_tag",
    "delete_tag",
    # refs listing
    "list_refs",
    # pull requests
    "create_pr",
    "list_prs",
    "get_pr",
    "merge_pr",
    "close_pr",
    "comment_pr",
    # repository management
    "create_repo",
    "update_repo",
]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
async def _async_call(func, *args, **kwargs):
|
| 25 |
+
"""Wrap synchronous HfApi calls for async context."""
|
| 26 |
+
return await asyncio.to_thread(func, *args, **kwargs)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
|
| 30 |
+
"""Build the Hub URL for a repository."""
|
| 31 |
+
if repo_type == "model":
|
| 32 |
+
return f"https://huggingface.co/{repo_id}"
|
| 33 |
+
return f"https://huggingface.co/{repo_type}s/{repo_id}"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class HfRepoGitTool:
|
| 37 |
+
"""Tool for git-like operations on HF repos."""
|
| 38 |
+
|
| 39 |
+
def __init__(self, hf_token: Optional[str] = None):
    """Create the Hub client; ``hf_token`` is passed straight to ``HfApi``."""
    client = HfApi(token=hf_token)
    self.api = client
|
| 41 |
+
|
| 42 |
+
async def execute(self, args: Dict[str, Any]) -> ToolResult:
    """Dispatch ``args["operation"]`` to the matching handler.

    Returns the help text when no operation is given and an error result
    listing the valid names for an unknown operation. Exceptions raised by
    a handler are converted into error results instead of propagating.
    """
    operation = args.get("operation")
    if not operation:
        return self._help()

    try:
        dispatch = {
            "create_branch": self._create_branch,
            "delete_branch": self._delete_branch,
            "create_tag": self._create_tag,
            "delete_tag": self._delete_tag,
            "list_refs": self._list_refs,
            "create_pr": self._create_pr,
            "list_prs": self._list_prs,
            "get_pr": self._get_pr,
            "merge_pr": self._merge_pr,
            "close_pr": self._close_pr,
            "comment_pr": self._comment_pr,
            "create_repo": self._create_repo,
            "update_repo": self._update_repo,
        }

        target = dispatch.get(operation)
        if not target:
            # Unknown name: report every operation this tool understands.
            ops = ", ".join(dispatch)
            return self._error(f"Unknown operation: {operation}. Valid: {ops}")
        return await target(args)

    except RepositoryNotFoundError:
        return self._error(f"Repository not found: {args.get('repo_id')}")
    except Exception as exc:
        # Last-resort boundary: surface any other failure as an error result.
        return self._error(f"Error: {str(exc)}")
|
| 77 |
+
|
| 78 |
+
def _help(self) -> ToolResult:
    """Return usage instructions (shown when no operation is supplied)."""
    usage = """**hf_repo_git** - Git-like operations on HF repos

**Branch/Tag:**
- `create_branch`: `{"operation": "create_branch", "repo_id": "...", "branch": "dev"}`
- `delete_branch`: `{"operation": "delete_branch", "repo_id": "...", "branch": "dev"}`
- `create_tag`: `{"operation": "create_tag", "repo_id": "...", "tag": "v1.0"}`
- `delete_tag`: `{"operation": "delete_tag", "repo_id": "...", "tag": "v1.0"}`
- `list_refs`: `{"operation": "list_refs", "repo_id": "..."}`

**PRs:**
- `create_pr`: `{"operation": "create_pr", "repo_id": "...", "title": "..."}`
- `list_prs`: `{"operation": "list_prs", "repo_id": "..."}`
- `get_pr`: `{"operation": "get_pr", "repo_id": "...", "pr_num": 1}`
- `merge_pr`: `{"operation": "merge_pr", "repo_id": "...", "pr_num": 1}`
- `close_pr`: `{"operation": "close_pr", "repo_id": "...", "pr_num": 1}`
- `comment_pr`: `{"operation": "comment_pr", "repo_id": "...", "pr_num": 1, "comment": "..."}`

**Repo:**
- `create_repo`: `{"operation": "create_repo", "repo_id": "my-model", "private": true}`
- `update_repo`: `{"operation": "update_repo", "repo_id": "...", "private": false}`"""
    result = {"formatted": usage}
    result["totalResults"] = 1
    result["resultsShared"] = 1
    return result
|
| 104 |
+
|
| 105 |
+
# =========================================================================
|
| 106 |
+
# BRANCH OPERATIONS
|
| 107 |
+
# =========================================================================
|
| 108 |
+
|
| 109 |
+
async def _create_branch(self, args: Dict[str, Any]) -> ToolResult:
    """Create branch ``branch`` in ``repo_id``, starting from ``from_rev``.

    ``repo_id`` and ``branch`` are required. ``repo_type`` defaults to
    ``model``, ``from_rev`` to ``main`` and ``exist_ok`` to False.
    """
    repo_id = args.get("repo_id")
    branch = args.get("branch")
    if not repo_id:
        return self._error("repo_id is required")
    if not branch:
        return self._error("branch is required")

    repo_type = args.get("repo_type", "model")
    request = {
        "repo_id": repo_id,
        "branch": branch,
        "revision": args.get("from_rev", "main"),
        "repo_type": repo_type,
        "exist_ok": args.get("exist_ok", False),
    }
    await _async_call(self.api.create_branch, **request)

    base_url = _build_repo_url(repo_id, repo_type)
    url = f"{base_url}/tree/{branch}"
    message = f"**Branch created:** {branch}\n{url}"
    return {"formatted": message, "totalResults": 1, "resultsShared": 1}
|
| 133 |
+
|
| 134 |
+
async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult:
    """Delete branch ``branch`` from ``repo_id``.

    ``repo_id`` and ``branch`` are required; ``repo_type`` defaults to
    ``model``.
    """
    repo_id = args.get("repo_id")
    branch = args.get("branch")
    if not repo_id:
        return self._error("repo_id is required")
    if not branch:
        return self._error("branch is required")

    request = {
        "repo_id": repo_id,
        "branch": branch,
        "repo_type": args.get("repo_type", "model"),
    }
    await _async_call(self.api.delete_branch, **request)

    message = f"**Branch deleted:** {branch}"
    return {"formatted": message, "totalResults": 1, "resultsShared": 1}
|
| 154 |
+
|
| 155 |
+
# =========================================================================
|
| 156 |
+
# TAG OPERATIONS
|
| 157 |
+
# =========================================================================
|
| 158 |
+
|
| 159 |
+
async def _create_tag(self, args: Dict[str, Any]) -> ToolResult:
    """Create tag ``tag`` in ``repo_id`` pointing at ``revision``.

    ``repo_id`` and ``tag`` are required. ``repo_type`` defaults to
    ``model``, ``revision`` to ``main``, ``tag_message`` to an empty string
    and ``exist_ok`` to False.
    """
    repo_id = args.get("repo_id")
    tag = args.get("tag")
    if not repo_id:
        return self._error("repo_id is required")
    if not tag:
        return self._error("tag is required")

    repo_type = args.get("repo_type", "model")
    request = {
        "repo_id": repo_id,
        "tag": tag,
        "revision": args.get("revision", "main"),
        "tag_message": args.get("tag_message", ""),
        "repo_type": repo_type,
        "exist_ok": args.get("exist_ok", False),
    }
    await _async_call(self.api.create_tag, **request)

    base_url = _build_repo_url(repo_id, repo_type)
    url = f"{base_url}/tree/{tag}"
    message = f"**Tag created:** {tag}\n{url}"
    return {"formatted": message, "totalResults": 1, "resultsShared": 1}
|
| 185 |
+
|
| 186 |
+
async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult:
    """Delete tag ``tag`` from ``repo_id``.

    ``repo_id`` and ``tag`` are required; ``repo_type`` defaults to
    ``model``.
    """
    repo_id = args.get("repo_id")
    tag = args.get("tag")
    if not repo_id:
        return self._error("repo_id is required")
    if not tag:
        return self._error("tag is required")

    request = {
        "repo_id": repo_id,
        "tag": tag,
        "repo_type": args.get("repo_type", "model"),
    }
    await _async_call(self.api.delete_tag, **request)

    message = f"**Tag deleted:** {tag}"
    return {"formatted": message, "totalResults": 1, "resultsShared": 1}
|
| 206 |
+
|
| 207 |
+
# =========================================================================
|
| 208 |
+
# LIST REFS
|
| 209 |
+
# =========================================================================
|
| 210 |
+
|
| 211 |
+
async def _list_refs(self, args: Dict[str, Any]) -> ToolResult:
    """List the branches and tags of a repository.

    ``args`` must contain ``repo_id``; ``repo_type`` defaults to ``"model"``.
    The result lists branch names and tag names (or ``none`` for an empty
    group); both result counters equal the number of branches plus tags.
    """
    repo_id = args.get("repo_id")
    if not repo_id:
        return self._error("repo_id is required")

    repo_type = args.get("repo_type", "model")
    refs = await _async_call(
        self.api.list_repo_refs, repo_id=repo_id, repo_type=repo_type
    )

    # A missing or empty collection is treated as "no refs of that kind".
    branch_names = [ref.name for ref in (refs.branches or [])]
    tag_names = [ref.name for ref in (getattr(refs, "tags", None) or [])]

    if branch_names:
        branch_line = f"**Branches ({len(branch_names)}):** " + ", ".join(branch_names)
    else:
        branch_line = "**Branches:** none"

    if tag_names:
        tag_line = f"**Tags ({len(tag_names)}):** " + ", ".join(tag_names)
    else:
        tag_line = "**Tags:** none"

    report = [
        f"**{repo_id}**",
        _build_repo_url(repo_id, repo_type),
        "",
        branch_line,
        tag_line,
    ]
    ref_count = len(branch_names) + len(tag_names)
    return {
        "formatted": "\n".join(report),
        "totalResults": ref_count,
        "resultsShared": ref_count,
    }
|
| 243 |
+
|
| 244 |
+
# =========================================================================
|
| 245 |
+
# PR OPERATIONS
|
| 246 |
+
# =========================================================================
|
| 247 |
+
|
| 248 |
+
async def _create_pr(self, args: Dict[str, Any]) -> ToolResult:
    """Open a pull request on a repository.

    Required keys in ``args``: ``repo_id`` and ``title``. Optional keys:
    ``repo_type`` (default ``"model"``) and ``description`` (default ``""``).
    The result reports the new PR number, its discussion URL, and the
    ``refs/pr/<num>`` revision to use when adding commits.
    """
    repo_id, title = args.get("repo_id"), args.get("title")
    if not repo_id:
        return self._error("repo_id is required")
    if not title:
        return self._error("title is required")

    repo_type = args.get("repo_type", "model")
    call_kwargs = {
        "repo_id": repo_id,
        "title": title,
        "description": args.get("description", ""),
        "repo_type": repo_type,
    }
    result = await _async_call(self.api.create_pull_request, **call_kwargs)

    url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}"
    summary = f"**PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision=\"refs/pr/{result.num}\""
    return {"formatted": summary, "totalResults": 1, "resultsShared": 1}
|
| 275 |
+
|
| 276 |
+
async def _list_prs(self, args: Dict[str, Any]) -> ToolResult:
    """List PRs and discussions of a repository.

    ``args`` must contain ``repo_id``. ``repo_type`` defaults to ``"model"``.
    ``status`` may be ``"open"``, ``"closed"`` or ``"all"`` (default); ``"all"``
    is sent to the API as ``discussion_status=None``.

    At most 20 entries are rendered; ``totalResults`` is the full count and
    ``resultsShared`` the number actually rendered.
    """
    repo_id = args.get("repo_id")

    if not repo_id:
        return self._error("repo_id is required")

    repo_type = args.get("repo_type", "model")
    status = args.get("status", "all")  # open, closed, all
    max_shown = 20

    def _fetch_discussions():
        # The iterable is consumed here so the whole listing runs inside the
        # worker call instead of on the event loop.
        return list(self.api.get_repo_discussions(
            repo_id=repo_id,
            repo_type=repo_type,
            discussion_status=status if status != "all" else None,
        ))

    # Go through _async_call like every other hub call in this class, so the
    # listing does not block the event loop.
    # NOTE(review): assumes _async_call accepts a zero-argument callable - confirm.
    discussions = await _async_call(_fetch_discussions)

    if not discussions:
        return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0}

    url = _build_repo_url(repo_id, repo_type)
    lines = [f"**{repo_id}** - {len(discussions)} discussions", f"{url}/discussions", ""]

    for d in discussions[:max_shown]:
        emoji = "🟢" if d.status == "open" else "🔴"
        type_label = "PR" if d.is_pull_request else "D"
        lines.append(f"{emoji} #{d.num} [{type_label}] {d.title}")

    return {
        "formatted": "\n".join(lines),
        "totalResults": len(discussions),
        "resultsShared": min(max_shown, len(discussions)),
    }
|
| 304 |
+
|
| 305 |
+
async def _get_pr(self, args: Dict[str, Any]) -> ToolResult:
    """Get the details of a PR or discussion.

    Required keys in ``args``: ``repo_id`` and ``pr_num`` (converted with
    ``int()`` before the API call). ``repo_type`` defaults to ``"model"``.

    The result shows the title, status, author and discussion URL. For pull
    requests it also shows the ``refs/pr/<num>`` revision for adding commits.
    """
    repo_id = args.get("repo_id")
    pr_num = args.get("pr_num")

    if not repo_id:
        return self._error("repo_id is required")
    if not pr_num:
        return self._error("pr_num is required")

    repo_type = args.get("repo_type", "model")

    pr = await _async_call(
        self.api.get_discussion_details,
        repo_id=repo_id,
        discussion_num=int(pr_num),
        repo_type=repo_type,
    )

    url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"

    # Report merged and draft distinctly instead of folding every non-open
    # status into "Closed". Unrecognized statuses keep the previous fallback.
    status_labels = {
        "open": "🟢 Open",
        "closed": "🔴 Closed",
        "merged": "🟣 Merged",
        "draft": "📝 Draft",
    }
    status = status_labels.get(pr.status, "🔴 Closed")
    type_label = "Pull Request" if pr.is_pull_request else "Discussion"

    lines = [
        f"**{type_label} #{pr_num}:** {pr.title}",
        f"**Status:** {status}",
        f"**Author:** {pr.author}",
        url,
    ]

    if pr.is_pull_request:
        lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"")

    return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1}
|
| 339 |
+
|
| 340 |
+
async def _merge_pr(self, args: Dict[str, Any]) -> ToolResult:
    """Merge a pull request.

    Required keys in ``args``: ``repo_id`` and ``pr_num`` (converted with
    ``int()`` for the API call). Optional keys: ``repo_type`` (default
    ``"model"``) and ``comment`` (default ``""``).
    """
    repo_id, pr_num = args.get("repo_id"), args.get("pr_num")
    if not repo_id:
        return self._error("repo_id is required")
    if not pr_num:
        return self._error("pr_num is required")

    repo_type = args.get("repo_type", "model")
    call_kwargs = {
        "repo_id": repo_id,
        "discussion_num": int(pr_num),
        "comment": args.get("comment", ""),
        "repo_type": repo_type,
    }
    await _async_call(self.api.merge_pull_request, **call_kwargs)

    url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
    summary = f"**PR #{pr_num} merged**\n{url}"
    return {"formatted": summary, "totalResults": 1, "resultsShared": 1}
|
| 363 |
+
|
| 364 |
+
async def _close_pr(self, args: Dict[str, Any]) -> ToolResult:
    """Close a PR or discussion by setting its status to ``"closed"``.

    Required keys in ``args``: ``repo_id`` and ``pr_num`` (converted with
    ``int()`` for the API call). Optional keys: ``repo_type`` (default
    ``"model"``) and ``comment`` (default ``""``).
    """
    repo_id, pr_num = args.get("repo_id"), args.get("pr_num")
    if not repo_id:
        return self._error("repo_id is required")
    if not pr_num:
        return self._error("pr_num is required")

    call_kwargs = {
        "repo_id": repo_id,
        "discussion_num": int(pr_num),
        "new_status": "closed",
        "comment": args.get("comment", ""),
        "repo_type": args.get("repo_type", "model"),
    }
    await _async_call(self.api.change_discussion_status, **call_kwargs)

    summary = f"**Discussion #{pr_num} closed**"
    return {"formatted": summary, "totalResults": 1, "resultsShared": 1}
|
| 387 |
+
|
| 388 |
+
async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult:
    """Post a comment on a PR or discussion.

    Required keys in ``args``: ``repo_id``, ``pr_num`` (converted with
    ``int()`` for the API call) and ``comment``. ``repo_type`` defaults to
    ``"model"``. The result links to the discussion page.
    """
    repo_id = args.get("repo_id")
    pr_num = args.get("pr_num")
    comment = args.get("comment")

    # Validate in a fixed order so the first missing field is the one reported.
    if not repo_id:
        return self._error("repo_id is required")
    if not pr_num:
        return self._error("pr_num is required")
    if not comment:
        return self._error("comment is required")

    repo_type = args.get("repo_type", "model")
    call_kwargs = {
        "repo_id": repo_id,
        "discussion_num": int(pr_num),
        "comment": comment,
        "repo_type": repo_type,
    }
    await _async_call(self.api.comment_discussion, **call_kwargs)

    url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
    summary = f"**Comment added to #{pr_num}**\n{url}"
    return {"formatted": summary, "totalResults": 1, "resultsShared": 1}
|
| 413 |
+
|
| 414 |
+
# =========================================================================
|
| 415 |
+
# REPO MANAGEMENT
|
| 416 |
+
# =========================================================================
|
| 417 |
+
|
| 418 |
+
async def _create_repo(self, args: Dict[str, Any]) -> ToolResult:
    """Create a new repository.

    ``args`` must contain ``repo_id``. Optional keys: ``repo_type`` (default
    ``"model"``), ``private`` (default ``True``), ``exist_ok`` (default
    ``False``) and ``space_sdk``. ``space_sdk`` is mandatory when
    ``repo_type`` is ``"space"`` and is only forwarded when it is set.
    """
    repo_id = args.get("repo_id")
    if not repo_id:
        return self._error("repo_id is required")

    repo_type = args.get("repo_type", "model")
    private = args.get("private", True)
    space_sdk = args.get("space_sdk")

    if not space_sdk and repo_type == "space":
        return self._error("space_sdk required for spaces (gradio/streamlit/docker/static)")

    create_kwargs = {
        "repo_id": repo_id,
        "repo_type": repo_type,
        "private": private,
        "exist_ok": args.get("exist_ok", False),
        # Only pass space_sdk through when the caller supplied one.
        **({"space_sdk": space_sdk} if space_sdk else {}),
    }
    result = await _async_call(self.api.create_repo, **create_kwargs)

    summary = f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}"
    return {"formatted": summary, "totalResults": 1, "resultsShared": 1}
|
| 448 |
+
|
| 449 |
+
async def _update_repo(self, args: Dict[str, Any]) -> ToolResult:
    """Update repository settings (visibility and/or gated access).

    ``args`` must contain ``repo_id`` and at least one of ``private`` (bool)
    or ``gated`` (``"auto"``, ``"manual"`` or false). ``repo_type`` defaults
    to ``"model"``. Only the settings that were supplied are forwarded to
    ``self.api.update_repo_settings``.

    The tool schema declares ``gated`` as a string enum, so "disable gating"
    can arrive as the string ``"false"``; it is converted to the boolean
    ``False`` before the API call.
    """
    repo_id = args.get("repo_id")

    if not repo_id:
        return self._error("repo_id is required")

    repo_type = args.get("repo_type", "model")
    private = args.get("private")
    gated = args.get("gated")

    # Normalize the schema's string spelling of "no gating" to a real boolean.
    if isinstance(gated, str) and gated.strip().lower() == "false":
        gated = False

    if private is None and gated is None:
        return self._error("Specify private (bool) or gated ('auto'/'manual'/false)")

    kwargs = {"repo_id": repo_id, "repo_type": repo_type}
    changes = []
    if private is not None:
        kwargs["private"] = private
        changes.append(f"private={private}")
    if gated is not None:
        kwargs["gated"] = gated
        changes.append(f"gated={gated}")

    await _async_call(self.api.update_repo_settings, **kwargs)

    url = f"{_build_repo_url(repo_id, repo_type)}/settings"
    return {"formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", "totalResults": 1, "resultsShared": 1}
|
| 479 |
+
|
| 480 |
+
def _error(self, message: str) -> ToolResult:
|
| 481 |
+
"""Return an error result."""
|
| 482 |
+
return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
# Tool specification
|
| 486 |
+
# Schema advertised to the agent for the hf_repo_git tool: the tool name, a
# markdown usage guide, and the JSON-schema description of accepted arguments.
HF_REPO_GIT_TOOL_SPEC = {
    "name": "hf_repo_git",
    "description": (
        # Summary and operation overview
        "Git-like operations on HF repos: branches, tags, PRs, and repo management.\n\n"
        "## Operations\n"
        "**Branches:** create_branch, delete_branch, list_refs\n"
        "**Tags:** create_tag, delete_tag\n"
        "**PRs:** create_pr, list_prs, get_pr, merge_pr, close_pr, comment_pr\n"
        "**Repo:** create_repo, update_repo\n\n"
        # Typical use cases
        "## Use when\n"
        "- Creating feature branches for experiments\n"
        "- Tagging model versions (v1.0, v2.0)\n"
        "- Opening PRs to contribute to repos you don't own\n"
        "- Reviewing and merging PRs on your repos\n"
        "- Creating new model/dataset/space repos\n"
        "- Changing repo visibility (public/private) or gated access\n\n"
        # Example argument payloads
        "## Examples\n"
        '{"operation": "list_refs", "repo_id": "my-model"}\n'
        '{"operation": "create_branch", "repo_id": "my-model", "branch": "experiment-v2"}\n'
        '{"operation": "create_tag", "repo_id": "my-model", "tag": "v1.0", "revision": "main"}\n'
        '{"operation": "create_pr", "repo_id": "org/model", "title": "Fix tokenizer config"}\n'
        '{"operation": "merge_pr", "repo_id": "my-model", "pr_num": 3}\n'
        '{"operation": "create_repo", "repo_id": "my-new-model", "private": true}\n'
        '{"operation": "update_repo", "repo_id": "my-model", "gated": "auto"}\n\n'
        # PR workflow and caveats
        "## PR Workflow\n"
        "1. create_pr → creates empty draft PR\n"
        "2. Upload files with revision='refs/pr/N' to add commits\n"
        "3. merge_pr when ready\n\n"
        "## Notes\n"
        "- delete_branch, delete_tag, merge_pr, create_repo, update_repo require approval\n"
        "- For spaces, create_repo needs space_sdk (gradio/streamlit/docker/static)\n"
        "- gated options: 'auto' (instant), 'manual' (review), false (open)\n"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": [
                    "create_branch",
                    "delete_branch",
                    "create_tag",
                    "delete_tag",
                    "list_refs",
                    "create_pr",
                    "list_prs",
                    "get_pr",
                    "merge_pr",
                    "close_pr",
                    "comment_pr",
                    "create_repo",
                    "update_repo",
                ],
                "description": "Operation to execute",
            },
            "repo_id": {"type": "string", "description": "Repository ID (e.g., 'username/repo-name')"},
            "repo_type": {"type": "string", "enum": ["model", "dataset", "space"], "description": "Repository type (default: model)"},
            "branch": {"type": "string", "description": "Branch name (create_branch, delete_branch)"},
            "from_rev": {"type": "string", "description": "Create branch from this revision (default: main)"},
            "tag": {"type": "string", "description": "Tag name (create_tag, delete_tag)"},
            "revision": {"type": "string", "description": "Revision for tag (default: main)"},
            "tag_message": {"type": "string", "description": "Tag description"},
            "title": {"type": "string", "description": "PR title (create_pr)"},
            "description": {"type": "string", "description": "PR description (create_pr)"},
            "pr_num": {"type": "integer", "description": "PR/discussion number"},
            "comment": {"type": "string", "description": "Comment text"},
            "status": {"type": "string", "enum": ["open", "closed", "all"], "description": "Filter PRs by status (list_prs)"},
            "private": {"type": "boolean", "description": "Make repo private (create_repo, update_repo)"},
            "gated": {"type": "string", "enum": ["auto", "manual", "false"], "description": "Gated access setting (update_repo)"},
            "space_sdk": {"type": "string", "enum": ["gradio", "streamlit", "docker", "static"], "description": "Space SDK (required for create_repo with space)"},
        },
        "required": ["operation"],
    },
}
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
async def hf_repo_git_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
    """Entry point used by the agent tool router.

    Runs ``HfRepoGitTool().execute(arguments)`` and returns a
    ``(text, success)`` pair: the result's ``formatted`` text and ``True``
    unless the result carries a truthy ``isError``. Any exception is turned
    into ``("Error: <message>", False)`` rather than propagated.
    """
    try:
        outcome = await HfRepoGitTool().execute(arguments)
        text = outcome["formatted"]
        succeeded = not outcome.get("isError", False)
        return text, succeeded
    except Exception as exc:
        return f"Error: {str(exc)}", False
|
test_dataset_tools.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
"""
|
| 2 |
-
Test script for
|
| 3 |
"""
|
| 4 |
|
| 5 |
import asyncio
|
|
@@ -8,7 +8,7 @@ from typing import TypedDict
|
|
| 8 |
from unittest.mock import MagicMock
|
| 9 |
|
| 10 |
|
| 11 |
-
# Mock the types module before importing
|
| 12 |
class ToolResult(TypedDict, total=False):
|
| 13 |
formatted: str
|
| 14 |
totalResults: int
|
|
@@ -20,60 +20,70 @@ mock_types = MagicMock()
|
|
| 20 |
mock_types.ToolResult = ToolResult
|
| 21 |
sys.modules["agent.tools.types"] = mock_types
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
from dataset_tools import hf_inspect_dataset_handler, inspect_dataset
|
| 26 |
|
| 27 |
|
| 28 |
-
async def
|
| 29 |
-
"""Test
|
| 30 |
-
print("=" *
|
| 31 |
-
print("Testing
|
| 32 |
-
print("=" *
|
| 33 |
|
| 34 |
-
|
| 35 |
-
print("\n→ inspect_dataset('akseljoonas/hf-agent-sessions'):")
|
| 36 |
-
result = await inspect_dataset("akseljoonas/hf-agent-sessions")
|
| 37 |
-
print(f" isError: {result['isError']}")
|
| 38 |
-
print(f" Output:\n{result['formatted']}")
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
# # Test with multi-config dataset
|
| 51 |
-
# print("\n→ inspect_dataset('nyu-mll/glue', config='mrpc'):")
|
| 52 |
-
# result = await inspect_dataset("nyu-mll/glue", config="mrpc")
|
| 53 |
-
# print(f" isError: {result['isError']}")
|
| 54 |
-
# print(f" Output:\n{result['formatted']}")
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
|
| 58 |
-
"""Test the handler (what the agent calls)"""
|
| 59 |
-
print("\n" + "=" * 70)
|
| 60 |
-
print("Testing hf_inspect_dataset_handler()")
|
| 61 |
-
print("=" * 70)
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
)
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
if __name__ == "__main__":
|
| 75 |
-
print("\
|
| 76 |
-
asyncio.run(
|
| 77 |
-
|
| 78 |
-
print("\n" + "=" *
|
| 79 |
-
print("
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Test script for hf_repo_files and hf_repo_git tools
|
| 3 |
"""
|
| 4 |
|
| 5 |
import asyncio
|
|
|
|
| 8 |
from unittest.mock import MagicMock
|
| 9 |
|
| 10 |
|
| 11 |
+
# Mock the types module before importing
|
| 12 |
class ToolResult(TypedDict, total=False):
|
| 13 |
formatted: str
|
| 14 |
totalResults: int
|
|
|
|
| 20 |
mock_types.ToolResult = ToolResult
|
| 21 |
sys.modules["agent.tools.types"] = mock_types
|
| 22 |
|
| 23 |
+
from agent.tools.hf_repo_files_tool import HfRepoFilesTool
|
| 24 |
+
from agent.tools.hf_repo_git_tool import HfRepoGitTool
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
+
async def test_hf_repo_files():
    """Smoke-test the read-only hf_repo_files operations.

    Runs ``list`` and ``read`` against ``openai-community/gpt2`` and prints
    the error flag plus a preview of the formatted output. The preview is
    truncated to the number of lines announced in its label (5 for ``list``,
    10 for ``read``).
    """
    print("=" * 60)
    print("Testing hf_repo_files")
    print("=" * 60)

    tool = HfRepoFilesTool()

    # Test list
    print("\n→ list files in gpt2:")
    result = await tool.execute(
        {"operation": "list", "repo_id": "openai-community/gpt2"}
    )
    print(f" isError: {result.get('isError', False)}")
    print(f" totalResults: {result['totalResults']}")
    # Just show first few lines
    lines = result["formatted"].split("\n")[:5]
    print(" Output (first 5 lines):\n" + "\n".join(f" {line}" for line in lines))

    # Test read
    print("\n→ read config.json from gpt2:")
    result = await tool.execute(
        {"operation": "read", "repo_id": "openai-community/gpt2", "path": "config.json"}
    )
    print(f" isError: {result.get('isError', False)}")
    lines = result["formatted"].split("\n")[:10]
    print(" Output (first 10 lines):\n" + "\n".join(f" {line}" for line in lines))
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
+
async def test_hf_repo_git():
    """Smoke-test hf_repo_git: ``list_refs`` on a public repo, then a call with no operation.

    Prints the error flag and the formatted output of each call; the output
    of the second call is limited to its first six lines.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Testing hf_repo_git")
    print(banner)

    tool = HfRepoGitTool()

    # Test list_refs
    print("\n→ list_refs for gpt2:")
    refs_result = await tool.execute(
        {"operation": "list_refs", "repo_id": "openai-community/gpt2"}
    )
    print(f" isError: {refs_result.get('isError', False)}")
    indented = "\n".join(f" {line}" for line in refs_result["formatted"].split("\n"))
    print(" Output:\n" + indented)

    # Test help (no operation)
    print("\n→ help (no operation):")
    help_result = await tool.execute({})
    print(f" isError: {help_result.get('isError', False)}")
    preview = help_result["formatted"].split("\n")[:6]
    print(" Output (first 6 lines):\n" + "\n".join(f" {line}" for line in preview))
|
| 81 |
|
| 82 |
|
| 83 |
if __name__ == "__main__":
    # Script entry point: run both smoke tests, each in its own asyncio.run
    # call, then print a closing banner.
    print("\nHF Repo Tools Test\n")
    for suite in (test_hf_repo_files, test_hf_repo_git):
        asyncio.run(suite())
    rule = "=" * 60
    print("\n" + rule)
    print("Tests complete!")
    print(rule)
|