Aksel Joonas Reedi commited on
Commit
f29c8f5
·
2 Parent(s): af1a664927e50a

Merge pull request #3 from huggingface/job-tool-improvements

Browse files
agent/core/agent_loop.py CHANGED
@@ -15,6 +15,16 @@ from agent.core.tools import ToolRouter
15
  ToolCall = ChatCompletionMessageToolCall
16
 
17
 
 
 
 
 
 
 
 
 
 
 
18
  class Handlers:
19
  """Handler functions for each operation type"""
20
 
@@ -33,9 +43,10 @@ class Handlers:
33
 
34
  Laminar.set_trace_session_id(session_id=session.session_id)
35
 
36
- # Add user message to history
37
- user_msg = Message(role="user", content=text)
38
- session.context_manager.add_message(user_msg)
 
39
 
40
  # Send event that we're processing
41
  await session.send_event(
@@ -97,6 +108,28 @@ class Handlers:
97
  tool_name = tc.function.name
98
  tool_args = json.loads(tc.function.arguments)
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  await session.send_event(
101
  Event(
102
  event_type="tool_call",
@@ -191,6 +224,85 @@ class Handlers:
191
 
192
  await session.send_event(Event(event_type="undo_complete"))
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  @staticmethod
195
  async def shutdown(session: Session) -> bool:
196
  """Handle shutdown (like shutdown in codex.rs:1329)"""
@@ -226,6 +338,12 @@ async def process_submission(session: Session, submission) -> bool:
226
  await Handlers.undo(session)
227
  return True
228
 
 
 
 
 
 
 
229
  if op.op_type == OpType.SHUTDOWN:
230
  return not await Handlers.shutdown(session)
231
 
 
15
  ToolCall = ChatCompletionMessageToolCall
16
 
17
 
18
+ def _needs_approval(tool_name: str, tool_args: dict) -> bool:
19
+ """Check if a tool call requires user approval before execution"""
20
+ if tool_name != "hf_jobs":
21
+ return False
22
+
23
+ # Check if it's a run or uv operation
24
+ operation = tool_args.get("operation", "")
25
+ return operation in ["run", "uv"]
26
+
27
+
28
  class Handlers:
29
  """Handler functions for each operation type"""
30
 
 
43
 
44
  Laminar.set_trace_session_id(session_id=session.session_id)
45
 
46
+ # Add user message to history only if there's actual content
47
+ if text:
48
+ user_msg = Message(role="user", content=text)
49
+ session.context_manager.add_message(user_msg)
50
 
51
  # Send event that we're processing
52
  await session.send_event(
 
108
  tool_name = tc.function.name
109
  tool_args = json.loads(tc.function.arguments)
110
 
111
+ # Check if this tool requires user approval
112
+ if _needs_approval(tool_name, tool_args):
113
+ await session.send_event(
114
+ Event(
115
+ event_type="approval_required",
116
+ data={
117
+ "tool": tool_name,
118
+ "arguments": tool_args,
119
+ "tool_call_id": tc.id,
120
+ },
121
+ )
122
+ )
123
+
124
+ # Store pending approval and return early
125
+ session.pending_approval = {
126
+ "tool_call": tc,
127
+ "arguments": tool_args,
128
+ }
129
+
130
+ # Return early - wait for EXEC_APPROVAL operation
131
+ return None
132
+
133
  await session.send_event(
134
  Event(
135
  event_type="tool_call",
 
224
 
225
  await session.send_event(Event(event_type="undo_complete"))
226
 
227
+ @staticmethod
228
+ async def exec_approval(
229
+ session: Session, approved: bool, feedback: str | None = None
230
+ ) -> None:
231
+ """Handle job execution approval"""
232
+ if not session.pending_approval:
233
+ await session.send_event(
234
+ Event(
235
+ event_type="error",
236
+ data={"error": "No pending approval to process"},
237
+ )
238
+ )
239
+ return
240
+
241
+ tc = session.pending_approval["tool_call"]
242
+ tool_args = session.pending_approval["arguments"]
243
+ tool_name = tc.function.name
244
+
245
+ if approved:
246
+ # Execute the pending tool
247
+ await session.send_event(
248
+ Event(
249
+ event_type="tool_call",
250
+ data={"tool": tool_name, "arguments": tool_args},
251
+ )
252
+ )
253
+
254
+ output, success = await session.tool_router.call_tool(tool_name, tool_args)
255
+
256
+ # Add tool result to context
257
+ tool_msg = Message(
258
+ role="tool",
259
+ content=output,
260
+ tool_call_id=tc.id,
261
+ name=tool_name,
262
+ )
263
+ session.context_manager.add_message(tool_msg)
264
+
265
+ await session.send_event(
266
+ Event(
267
+ event_type="tool_output",
268
+ data={
269
+ "tool": tool_name,
270
+ "output": output,
271
+ "success": success,
272
+ },
273
+ )
274
+ )
275
+ else:
276
+ # User rejected - add cancellation message to context
277
+ cancellation_msg = "Job execution cancelled by user"
278
+ if feedback:
279
+ cancellation_msg += f". User feedback: {feedback}"
280
+
281
+ tool_msg = Message(
282
+ role="tool",
283
+ content=cancellation_msg,
284
+ tool_call_id=tc.id,
285
+ name=tool_name,
286
+ )
287
+ session.context_manager.add_message(tool_msg)
288
+
289
+ await session.send_event(
290
+ Event(
291
+ event_type="tool_output",
292
+ data={
293
+ "tool": tool_name,
294
+ "output": cancellation_msg,
295
+ "success": False,
296
+ },
297
+ )
298
+ )
299
+
300
+ # Clear pending approval
301
+ session.pending_approval = None
302
+
303
+ # Continue agent loop with empty input to process the tool result
304
+ await Handlers.run_agent(session, "")
305
+
306
  @staticmethod
307
  async def shutdown(session: Session) -> bool:
308
  """Handle shutdown (like shutdown in codex.rs:1329)"""
 
338
  await Handlers.undo(session)
339
  return True
340
 
341
+ if op.op_type == OpType.EXEC_APPROVAL:
342
+ approved = op.data.get("approved", False) if op.data else False
343
+ feedback = op.data.get("feedback") if op.data else None
344
+ await Handlers.exec_approval(session, approved, feedback)
345
+ return True
346
+
347
  if op.op_type == OpType.SHUTDOWN:
348
  return not await Handlers.shutdown(session)
349
 
agent/core/session.py CHANGED
@@ -53,6 +53,7 @@ class Session:
53
  )
54
  self.is_running = True
55
  self.current_task: asyncio.Task | None = None
 
56
 
57
  async def send_event(self, event: Event) -> None:
58
  """Send event back to client"""
 
53
  )
54
  self.is_running = True
55
  self.current_task: asyncio.Task | None = None
56
+ self.pending_approval: Optional[dict[str, Any]] = None
57
 
58
  async def send_event(self, event: Event) -> None:
59
  """Send event back to client"""
agent/core/tools.py CHANGED
@@ -8,6 +8,7 @@ from dataclasses import dataclass
8
  from typing import Any, Awaitable, Callable, Optional
9
 
10
  from fastmcp import Client
 
11
  from lmnr import observe
12
  from mcp.types import EmbeddedResource, ImageContent, TextContent
13
 
@@ -166,10 +167,14 @@ class ToolRouter:
166
 
167
  # Otherwise, use MCP client
168
  if self._mcp_initialized:
169
- result = await self.mcp_client.call_tool(tool_name, arguments)
170
- # Convert MCP content blocks to string
171
- output = convert_mcp_content_to_string(result.content)
172
- return output, not result.is_error
 
 
 
 
173
 
174
  return "MCP client not initialized", False
175
 
 
8
  from typing import Any, Awaitable, Callable, Optional
9
 
10
  from fastmcp import Client
11
+ from fastmcp.exceptions import ToolError
12
  from lmnr import observe
13
  from mcp.types import EmbeddedResource, ImageContent, TextContent
14
 
 
167
 
168
  # Otherwise, use MCP client
169
  if self._mcp_initialized:
170
+ try:
171
+ result = await self.mcp_client.call_tool(tool_name, arguments)
172
+ output = convert_mcp_content_to_string(result.content)
173
+ return output, not result.is_error
174
+ except ToolError as e:
175
+ # Catch MCP tool errors and return them to the agent
176
+ error_msg = f"Tool error: {str(e)}"
177
+ return error_msg, False
178
 
179
  return "MCP client not initialized", False
180
 
agent/main.py CHANGED
@@ -47,10 +47,13 @@ class Submission:
47
 
48
  async def event_listener(
49
  event_queue: asyncio.Queue,
 
50
  turn_complete_event: asyncio.Event,
51
  ready_event: asyncio.Event,
52
  ) -> None:
53
  """Background task that listens for events and displays them"""
 
 
54
  while True:
55
  try:
56
  event = await event_queue.get()
@@ -96,6 +99,69 @@ async def event_listener(
96
  old_tokens = event.data.get("old_tokens", 0) if event.data else 0
97
  new_tokens = event.data.get("new_tokens", 0) if event.data else 0
98
  print(f"📦 Compacted context: {old_tokens} → {new_tokens} tokens")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  # Silently ignore other events
100
 
101
  except asyncio.CancelledError:
@@ -145,7 +211,7 @@ async def main():
145
 
146
  # Start event listener in background
147
  listener_task = asyncio.create_task(
148
- event_listener(event_queue, turn_complete_event, ready_event)
149
  )
150
 
151
  # Wait for agent to initialize
 
47
 
48
  async def event_listener(
49
  event_queue: asyncio.Queue,
50
+ submission_queue: asyncio.Queue,
51
  turn_complete_event: asyncio.Event,
52
  ready_event: asyncio.Event,
53
  ) -> None:
54
  """Background task that listens for events and displays them"""
55
+ submission_id = [1000] # Use list to make it mutable in closure
56
+
57
  while True:
58
  try:
59
  event = await event_queue.get()
 
99
  old_tokens = event.data.get("old_tokens", 0) if event.data else 0
100
  new_tokens = event.data.get("new_tokens", 0) if event.data else 0
101
  print(f"📦 Compacted context: {old_tokens} → {new_tokens} tokens")
102
+ elif event.event_type == "approval_required":
103
+ # Display job details and prompt for approval
104
+ tool_name = event.data.get("tool", "") if event.data else ""
105
+ arguments = event.data.get("arguments", {}) if event.data else {}
106
+
107
+ print("\n" + "=" * 60)
108
+ print("⚠️ JOB EXECUTION APPROVAL REQUIRED")
109
+ print("=" * 60)
110
+
111
+ operation = arguments.get("operation", "")
112
+ args = arguments.get("args", {})
113
+
114
+ print(f"Operation: {operation}")
115
+
116
+ if operation == "uv":
117
+ script = args.get("script", "")
118
+ dependencies = args.get("dependencies", [])
119
+ print(f"Script to run:\n{script}")
120
+ if dependencies:
121
+ print(f"Dependencies: {', '.join(dependencies)}")
122
+ elif operation == "run":
123
+ image = args.get("image", "")
124
+ command = args.get("command", "")
125
+ print(f"Docker image: {image}")
126
+ print(f"Command: {command}")
127
+
128
+ # Common parameters
129
+ flavor = args.get("flavor", "cpu-basic")
130
+ detached = args.get("detached", False)
131
+ print(f"Hardware: {flavor}")
132
+ print(f"Detached mode: {detached}")
133
+
134
+ secrets = args.get("secrets", [])
135
+ if secrets:
136
+ print(f"Secrets: {', '.join(secrets)}")
137
+
138
+ print("=" * 60)
139
+
140
+ # Get user decision
141
+ loop = asyncio.get_event_loop()
142
+ response = await loop.run_in_executor(
143
+ None,
144
+ input,
145
+ "Approve? (y=yes, n=no, or provide feedback to reject): ",
146
+ )
147
+
148
+ response = response.strip()
149
+ approved = response.lower() in ["y", "yes"]
150
+ feedback = (
151
+ None if approved or response.lower() in ["n", "no"] else response
152
+ )
153
+
154
+ # Submit approval
155
+ submission_id[0] += 1
156
+ approval_submission = Submission(
157
+ id=f"approval_{submission_id[0]}",
158
+ operation=Operation(
159
+ op_type=OpType.EXEC_APPROVAL,
160
+ data={"approved": approved, "feedback": feedback},
161
+ ),
162
+ )
163
+ await submission_queue.put(approval_submission)
164
+ print("=" * 60 + "\n")
165
  # Silently ignore other events
166
 
167
  except asyncio.CancelledError:
 
211
 
212
  # Start event listener in background
213
  listener_task = asyncio.create_task(
214
+ event_listener(event_queue, submission_queue, turn_complete_event, ready_event)
215
  )
216
 
217
  # Wait for agent to initialize
agent/prompts/system_prompt.yaml CHANGED
@@ -85,6 +85,7 @@ system_prompt: |
85
  - Always search Hugging Face Hub for existing resources before suggesting custom implementations
86
  - When referencing models, datasets, or papers, include direct links from search results
87
  - Never assume a library is available - check documentation first
 
88
  - Follow ML best practices: proper train/val/test splits, reproducibility, evaluation metrics
89
  - For training tasks, consider compute requirements and suggest appropriate hardware
90
  - Never expose or log API keys, tokens, or secrets
 
85
  - Always search Hugging Face Hub for existing resources before suggesting custom implementations
86
  - When referencing models, datasets, or papers, include direct links from search results
87
  - Never assume a library is available - check documentation first
88
+ - Before processing any dataset: inspect its actual structure first using the mcp__hf-mcp-server__hub_repo_details tool. Never assume column names: verify them beforehand.
89
  - Follow ML best practices: proper train/val/test splits, reproducibility, evaluation metrics
90
  - For training tasks, consider compute requirements and suggest appropriate hardware
91
  - Never expose or log API keys, tokens, or secrets
agent/tools/jobs_tool.py CHANGED
@@ -6,6 +6,7 @@ Refactored to use official huggingface-hub library instead of custom HTTP client
6
 
7
  import asyncio
8
  import base64
 
9
  from typing import Any, Dict, Literal, Optional
10
 
11
  from huggingface_hub import HfApi
@@ -60,10 +61,32 @@ OperationType = Literal[
60
  ]
61
 
62
  # Constants
63
- DEFAULT_LOG_WAIT_SECONDS = 10
64
  UV_DEFAULT_IMAGE = "ghcr.io/astral-sh/uv:python3.12-bookworm"
65
 
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def _build_uv_command(
68
  script: str,
69
  with_deps: list[str] | None = None,
@@ -103,6 +126,20 @@ def _wrap_inline_script(
103
  return f'echo "{encoded}" | base64 -d | {uv_command_str}'
104
 
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def _resolve_uv_command(
107
  script: str,
108
  with_deps: list[str] | None = None,
@@ -316,7 +353,8 @@ Call this tool with:
316
  {{
317
  "operation": "uv",
318
  "args": {{
319
- "script": "import random\\nprint(42 + random.randint(1, 5))"
 
320
  }}
321
  }}
322
  ```
@@ -335,7 +373,6 @@ Call this tool with:
335
  **String format (simple cases only):**
336
  - Still accepted for backwards compatibility, parsed with POSIX shell semantics
337
  - Rejects shell operators and can mis-handle characters such as `&`; switch to arrays when things turn complex
338
- - `$HF_TOKEN` stays literal—forward it via `secrets: {{ "HF_TOKEN": "$HF_TOKEN" }}`
339
 
340
  ### Show command-specific help
341
  Call this tool with:
@@ -345,9 +382,10 @@ Call this tool with:
345
 
346
  ## Tips
347
 
348
- - Jobs default to non-detached mode (tail logs for up to {DEFAULT_LOG_WAIT_SECONDS}s or until completion). Set `detach: true` to return immediately.
349
  - Prefer array commands to avoid shell parsing surprises
350
- - To access private Hub assets, include `secrets: {{ "HF_TOKEN": "$HF_TOKEN" }}` to inject your auth token.
 
351
  """
352
  return {"formatted": usage_text, "totalResults": 1, "resultsShared": 1}
353
 
@@ -356,6 +394,33 @@ Call this tool with:
356
  help_text = f"Help for operation: {operation}\n\nCall with appropriate arguments. Use the main help for examples."
357
  return {"formatted": help_text, "totalResults": 1, "resultsShared": 1}
358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  async def _run_job(self, args: Dict[str, Any]) -> ToolResult:
360
  """Run a job using HfApi.run_job()"""
361
  try:
@@ -363,8 +428,8 @@ Call this tool with:
363
  self.api.run_job,
364
  image=args.get("image", "python:3.12"),
365
  command=args.get("command"),
366
- env=args.get("env"),
367
- secrets=args.get("secrets"),
368
  flavor=args.get("flavor", "cpu-basic"),
369
  timeout=args.get("timeout", "30m"),
370
  namespace=args.get("namespace") or self.namespace,
@@ -382,14 +447,28 @@ To check logs, call this tool with `{{"operation": "logs", "args": {{"job_id": "
382
  To inspect, call this tool with `{{"operation": "inspect", "args": {{"job_id": "{job.id}"}}}}`"""
383
  return {"formatted": response, "totalResults": 1, "resultsShared": 1}
384
 
385
- # Not detached - return job info
386
- response = f"""Job started: {job.id}
 
387
 
388
- **Status:** {job.status.stage}
389
- **View logs at:** {job.url}
 
 
390
 
391
- Note: Logs are being collected. Check the job page for real-time logs.
392
- """
 
 
 
 
 
 
 
 
 
 
 
393
  return {"formatted": response, "totalResults": 1, "resultsShared": 1}
394
 
395
  except Exception as e:
@@ -402,10 +481,18 @@ Note: Logs are being collected. Check the job page for real-time logs.
402
  if not script:
403
  raise ValueError("script is required")
404
 
 
 
 
 
 
 
 
 
405
  # Resolve the command based on script type (URL, inline, or file)
406
  command = _resolve_uv_command(
407
  script=script,
408
- with_deps=args.get("with_deps") or args.get("dependencies"),
409
  python=args.get("python"),
410
  script_args=args.get("script_args"),
411
  )
@@ -415,20 +502,46 @@ Note: Logs are being collected. Check the job page for real-time logs.
415
  self.api.run_job,
416
  image=UV_DEFAULT_IMAGE,
417
  command=command,
418
- env=args.get("env"),
419
- secrets=args.get("secrets"),
420
- flavor=args.get("flavor", "cpu-basic"),
421
  timeout=args.get("timeout", "30m"),
422
  namespace=args.get("namespace") or self.namespace,
423
  )
424
 
425
- response = f"""UV Job started: {job.id}
 
 
426
 
 
427
  **Status:** {job.status.stage}
428
  **View at:** {job.url}
429
 
430
- To check logs, call this tool with `{{"operation": "logs", "args": {{"job_id": "{job.id}"}}}}`
431
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  return {"formatted": response, "totalResults": 1, "resultsShared": 1}
433
 
434
  except Exception as e:
@@ -578,8 +691,8 @@ To verify, call this tool with `{{"operation": "inspect", "args": {{"job_id": "{
578
  image=args.get("image", "python:3.12"),
579
  command=args.get("command"),
580
  schedule=args.get("schedule"),
581
- env=args.get("env"),
582
- secrets=args.get("secrets"),
583
  flavor=args.get("flavor", "cpu-basic"),
584
  timeout=args.get("timeout", "30m"),
585
  namespace=args.get("namespace") or self.namespace,
@@ -613,10 +726,18 @@ To list all, call this tool with `{{"operation": "scheduled ps"}}`"""
613
  if not schedule:
614
  raise ValueError("schedule is required")
615
 
 
 
 
 
 
 
 
 
616
  # Resolve the command based on script type
617
  command = _resolve_uv_command(
618
  script=script,
619
- with_deps=args.get("with_deps") or args.get("dependencies"),
620
  python=args.get("python"),
621
  script_args=args.get("script_args"),
622
  )
@@ -627,9 +748,9 @@ To list all, call this tool with `{{"operation": "scheduled ps"}}`"""
627
  image=UV_DEFAULT_IMAGE,
628
  command=command,
629
  schedule=schedule,
630
- env=args.get("env"),
631
- secrets=args.get("secrets"),
632
- flavor=args.get("flavor", "cpu-basic"),
633
  timeout=args.get("timeout", "30m"),
634
  namespace=args.get("namespace") or self.namespace,
635
  )
@@ -788,6 +909,7 @@ HF_JOBS_TOOL_SPEC = {
788
  "description": (
789
  "Manage Hugging Face CPU/GPU compute jobs. Run commands in Docker containers, "
790
  "execute Python scripts with UV. List, schedule and monitor jobs/logs. "
 
791
  "Call this tool with no operation for full usage instructions and examples."
792
  ),
793
  "parameters": {
@@ -818,7 +940,12 @@ HF_JOBS_TOOL_SPEC = {
818
  },
819
  "args": {
820
  "type": "object",
821
- "description": "Operation-specific arguments as a JSON object",
 
 
 
 
 
822
  "additionalProperties": True,
823
  },
824
  },
 
6
 
7
  import asyncio
8
  import base64
9
+ import os
10
  from typing import Any, Dict, Literal, Optional
11
 
12
  from huggingface_hub import HfApi
 
61
  ]
62
 
63
  # Constants
 
64
  UV_DEFAULT_IMAGE = "ghcr.io/astral-sh/uv:python3.12-bookworm"
65
 
66
 
67
+ def _substitute_hf_token(params: Dict[str, Any] | None) -> Dict[str, Any] | None:
68
+ """
69
+ Substitute $HF_TOKEN with actual token value from environment.
70
+
71
+ Args:
72
+ params: Dictionary that may contain "$HF_TOKEN" in values
73
+
74
+ Returns:
75
+ Dictionary with $HF_TOKEN substituted
76
+ """
77
+ if params is None:
78
+ return None
79
+
80
+ result = {}
81
+ for key, value in params.items():
82
+ if value == "$HF_TOKEN":
83
+ result[key] = os.environ.get("HF_TOKEN", "")
84
+ else:
85
+ result[key] = value
86
+
87
+ return result
88
+
89
+
90
  def _build_uv_command(
91
  script: str,
92
  with_deps: list[str] | None = None,
 
126
  return f'echo "{encoded}" | base64 -d | {uv_command_str}'
127
 
128
 
129
+ def _ensure_hf_transfer_dependency(deps: list[str] | None) -> list[str]:
130
+ """Ensure hf-transfer is included in the dependencies list"""
131
+ if deps is None:
132
+ return ["hf-transfer"]
133
+
134
+ if isinstance(deps, list):
135
+ deps_copy = deps.copy() # Don't modify the original
136
+ if "hf-transfer" not in deps_copy:
137
+ deps_copy.append("hf-transfer")
138
+ return deps_copy
139
+
140
+ return ["hf-transfer"]
141
+
142
+
143
  def _resolve_uv_command(
144
  script: str,
145
  with_deps: list[str] | None = None,
 
353
  {{
354
  "operation": "uv",
355
  "args": {{
356
+ "script": "import random\\nprint(42 + random.randint(1, 5))",
357
+ "dependencies" : ["torch", "huggingface_hub"]
358
  }}
359
  }}
360
  ```
 
373
  **String format (simple cases only):**
374
  - Still accepted for backwards compatibility, parsed with POSIX shell semantics
375
  - Rejects shell operators and can mis-handle characters such as `&`; switch to arrays when things turn complex
 
376
 
377
  ### Show command-specific help
378
  Call this tool with:
 
382
 
383
  ## Tips
384
 
385
+ - Jobs default to non-detached mode (stream logs until completion). Set `detach: true` to return immediately.
386
  - Prefer array commands to avoid shell parsing surprises
387
+ - To access private Hub assets (spaces, private models, datasets, collections), pass `secrets: {{ "HF_TOKEN": "$HF_TOKEN" }}`
388
+ - Before calling a job, think about dependencies (they must be specified), which hardware flavor to run on (choose simplest for task), and whether to include secrets.
389
  """
390
  return {"formatted": usage_text, "totalResults": 1, "resultsShared": 1}
391
 
 
394
  help_text = f"Help for operation: {operation}\n\nCall with appropriate arguments. Use the main help for examples."
395
  return {"formatted": help_text, "totalResults": 1, "resultsShared": 1}
396
 
397
+ async def _wait_for_job_completion(
398
+ self, job_id: str, namespace: Optional[str] = None
399
+ ) -> tuple[str, list[str]]:
400
+ """
401
+ Stream job logs until completion, printing them in real-time.
402
+
403
+ Returns:
404
+ tuple: (final_status, all_logs)
405
+ """
406
+ all_logs = []
407
+
408
+ # Fetch logs - generator streams logs as they arrive and ends when job completes
409
+ logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=namespace)
410
+
411
+ # Stream logs in real-time
412
+ for log_line in logs_gen:
413
+ print("\t" + log_line)
414
+ all_logs.append(log_line)
415
+
416
+ # After logs complete, fetch final job status
417
+ job_info = await _async_call(
418
+ self.api.inspect_job, job_id=job_id, namespace=namespace
419
+ )
420
+ final_status = job_info.status.stage
421
+
422
+ return final_status, all_logs
423
+
424
  async def _run_job(self, args: Dict[str, Any]) -> ToolResult:
425
  """Run a job using HfApi.run_job()"""
426
  try:
 
428
  self.api.run_job,
429
  image=args.get("image", "python:3.12"),
430
  command=args.get("command"),
431
+ env=_substitute_hf_token(args.get("env")),
432
+ secrets=_substitute_hf_token(args.get("secrets")),
433
  flavor=args.get("flavor", "cpu-basic"),
434
  timeout=args.get("timeout", "30m"),
435
  namespace=args.get("namespace") or self.namespace,
 
447
  To inspect, call this tool with `{{"operation": "inspect", "args": {{"job_id": "{job.id}"}}}}`"""
448
  return {"formatted": response, "totalResults": 1, "resultsShared": 1}
449
 
450
+ # Not detached - wait for completion and stream logs
451
+ print(f"Job started: {job.id}")
452
+ print("Streaming logs...\n---\n")
453
 
454
+ final_status, all_logs = await self._wait_for_job_completion(
455
+ job_id=job.id,
456
+ namespace=args.get("namespace") or self.namespace,
457
+ )
458
 
459
+ # Format all logs for the agent
460
+ log_text = "\n".join(all_logs) if all_logs else "(no logs)"
461
+
462
+ response = f"""Job completed!
463
+
464
+ **Job ID:** {job.id}
465
+ **Final Status:** {final_status}
466
+ **View at:** {job.url}
467
+
468
+ **Logs:**
469
+ ```
470
+ {log_text}
471
+ ```"""
472
  return {"formatted": response, "totalResults": 1, "resultsShared": 1}
473
 
474
  except Exception as e:
 
481
  if not script:
482
  raise ValueError("script is required")
483
 
484
+ # Get dependencies and ensure hf-transfer is included
485
+ deps = (
486
+ args.get("with_deps")
487
+ or args.get("dependencies")
488
+ or args.get("packages")
489
+ )
490
+ deps = _ensure_hf_transfer_dependency(deps)
491
+
492
  # Resolve the command based on script type (URL, inline, or file)
493
  command = _resolve_uv_command(
494
  script=script,
495
+ with_deps=deps,
496
  python=args.get("python"),
497
  script_args=args.get("script_args"),
498
  )
 
502
  self.api.run_job,
503
  image=UV_DEFAULT_IMAGE,
504
  command=command,
505
+ env=_substitute_hf_token(args.get("env")),
506
+ secrets=_substitute_hf_token(args.get("secrets")),
507
+ flavor=args.get("flavor") or args.get("hardware") or "cpu-basic",
508
  timeout=args.get("timeout", "30m"),
509
  namespace=args.get("namespace") or self.namespace,
510
  )
511
 
512
+ # If detached, return immediately
513
+ if args.get("detach", False):
514
+ response = f"""UV Job started successfully!
515
 
516
+ **Job ID:** {job.id}
517
  **Status:** {job.status.stage}
518
  **View at:** {job.url}
519
 
520
+ To check logs, call this tool with `{{"operation": "logs", "args": {{"job_id": "{job.id}"}}}}`"""
521
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
522
+
523
+ # Not detached - wait for completion and stream logs
524
+ print(f"UV Job started: {job.id}")
525
+ print("Streaming logs...\n---\n")
526
+
527
+ final_status, all_logs = await self._wait_for_job_completion(
528
+ job_id=job.id,
529
+ namespace=args.get("namespace") or self.namespace,
530
+ )
531
+
532
+ # Format all logs for the agent
533
+ log_text = "\n".join(all_logs) if all_logs else "(no logs)"
534
+
535
+ response = f"""UV Job completed!
536
+
537
+ **Job ID:** {job.id}
538
+ **Final Status:** {final_status}
539
+ **View at:** {job.url}
540
+
541
+ **Logs:**
542
+ ```
543
+ {log_text}
544
+ ```"""
545
  return {"formatted": response, "totalResults": 1, "resultsShared": 1}
546
 
547
  except Exception as e:
 
691
  image=args.get("image", "python:3.12"),
692
  command=args.get("command"),
693
  schedule=args.get("schedule"),
694
+ env=_substitute_hf_token(args.get("env")),
695
+ secrets=_substitute_hf_token(args.get("secrets")),
696
  flavor=args.get("flavor", "cpu-basic"),
697
  timeout=args.get("timeout", "30m"),
698
  namespace=args.get("namespace") or self.namespace,
 
726
  if not schedule:
727
  raise ValueError("schedule is required")
728
 
729
+ # Get dependencies and ensure hf-transfer is included
730
+ deps = (
731
+ args.get("with_deps")
732
+ or args.get("dependencies")
733
+ or args.get("packages")
734
+ )
735
+ deps = _ensure_hf_transfer_dependency(deps)
736
+
737
  # Resolve the command based on script type
738
  command = _resolve_uv_command(
739
  script=script,
740
+ with_deps=deps,
741
  python=args.get("python"),
742
  script_args=args.get("script_args"),
743
  )
 
748
  image=UV_DEFAULT_IMAGE,
749
  command=command,
750
  schedule=schedule,
751
+ env=_substitute_hf_token(args.get("env")),
752
+ secrets=_substitute_hf_token(args.get("secrets")),
753
+ flavor=args.get("flavor") or args.get("hardware") or "cpu-basic",
754
  timeout=args.get("timeout", "30m"),
755
  namespace=args.get("namespace") or self.namespace,
756
  )
 
909
  "description": (
910
  "Manage Hugging Face CPU/GPU compute jobs. Run commands in Docker containers, "
911
  "execute Python scripts with UV. List, schedule and monitor jobs/logs. "
912
+ "Example hardware/flavor: cpu-basic, cpu-performance, t4-medium. "
913
  "Call this tool with no operation for full usage instructions and examples."
914
  ),
915
  "parameters": {
 
940
  },
941
  "args": {
942
  "type": "object",
943
+ "description": (
944
+ "Operation-specific arguments as a JSON object. "
945
+ "Common args: script (for uv), packages/dependencies (array), "
946
+ "flavor/hardware (e.g., a10g-large, cpu-basic), command (array), "
947
+ "image (string), env (object), secrets (object)."
948
+ ),
949
  "additionalProperties": True,
950
  },
951
  },