akseljoonas HF Staff commited on
Commit
927e50a
·
1 Parent(s): 160da13

adding a user approval step to jobs api + ruff formatting

Browse files
agent/core/agent_loop.py CHANGED
@@ -15,6 +15,16 @@ from agent.core.tools import ToolRouter
15
  ToolCall = ChatCompletionMessageToolCall
16
 
17
 
 
 
 
 
 
 
 
 
 
 
18
  class Handlers:
19
  """Handler functions for each operation type"""
20
 
@@ -33,9 +43,10 @@ class Handlers:
33
 
34
  Laminar.set_trace_session_id(session_id=session.session_id)
35
 
36
- # Add user message to history
37
- user_msg = Message(role="user", content=text)
38
- session.context_manager.add_message(user_msg)
 
39
 
40
  # Send event that we're processing
41
  await session.send_event(
@@ -97,6 +108,28 @@ class Handlers:
97
  tool_name = tc.function.name
98
  tool_args = json.loads(tc.function.arguments)
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  await session.send_event(
101
  Event(
102
  event_type="tool_call",
@@ -191,6 +224,85 @@ class Handlers:
191
 
192
  await session.send_event(Event(event_type="undo_complete"))
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  @staticmethod
195
  async def shutdown(session: Session) -> bool:
196
  """Handle shutdown (like shutdown in codex.rs:1329)"""
@@ -226,6 +338,12 @@ async def process_submission(session: Session, submission) -> bool:
226
  await Handlers.undo(session)
227
  return True
228
 
 
 
 
 
 
 
229
  if op.op_type == OpType.SHUTDOWN:
230
  return not await Handlers.shutdown(session)
231
 
 
15
  ToolCall = ChatCompletionMessageToolCall
16
 
17
 
18
+ def _needs_approval(tool_name: str, tool_args: dict) -> bool:
19
+ """Check if a tool call requires user approval before execution"""
20
+ if tool_name != "hf_jobs":
21
+ return False
22
+
23
+ # Check if it's a run or uv operation
24
+ operation = tool_args.get("operation", "")
25
+ return operation in ["run", "uv"]
26
+
27
+
28
  class Handlers:
29
  """Handler functions for each operation type"""
30
 
 
43
 
44
  Laminar.set_trace_session_id(session_id=session.session_id)
45
 
46
+ # Add user message to history only if there's actual content
47
+ if text:
48
+ user_msg = Message(role="user", content=text)
49
+ session.context_manager.add_message(user_msg)
50
 
51
  # Send event that we're processing
52
  await session.send_event(
 
108
  tool_name = tc.function.name
109
  tool_args = json.loads(tc.function.arguments)
110
 
111
+ # Check if this tool requires user approval
112
+ if _needs_approval(tool_name, tool_args):
113
+ await session.send_event(
114
+ Event(
115
+ event_type="approval_required",
116
+ data={
117
+ "tool": tool_name,
118
+ "arguments": tool_args,
119
+ "tool_call_id": tc.id,
120
+ },
121
+ )
122
+ )
123
+
124
+ # Store pending approval and return early
125
+ session.pending_approval = {
126
+ "tool_call": tc,
127
+ "arguments": tool_args,
128
+ }
129
+
130
+ # Return early - wait for EXEC_APPROVAL operation
131
+ return None
132
+
133
  await session.send_event(
134
  Event(
135
  event_type="tool_call",
 
224
 
225
  await session.send_event(Event(event_type="undo_complete"))
226
 
227
+ @staticmethod
228
+ async def exec_approval(
229
+ session: Session, approved: bool, feedback: str | None = None
230
+ ) -> None:
231
+ """Handle job execution approval"""
232
+ if not session.pending_approval:
233
+ await session.send_event(
234
+ Event(
235
+ event_type="error",
236
+ data={"error": "No pending approval to process"},
237
+ )
238
+ )
239
+ return
240
+
241
+ tc = session.pending_approval["tool_call"]
242
+ tool_args = session.pending_approval["arguments"]
243
+ tool_name = tc.function.name
244
+
245
+ if approved:
246
+ # Execute the pending tool
247
+ await session.send_event(
248
+ Event(
249
+ event_type="tool_call",
250
+ data={"tool": tool_name, "arguments": tool_args},
251
+ )
252
+ )
253
+
254
+ output, success = await session.tool_router.call_tool(tool_name, tool_args)
255
+
256
+ # Add tool result to context
257
+ tool_msg = Message(
258
+ role="tool",
259
+ content=output,
260
+ tool_call_id=tc.id,
261
+ name=tool_name,
262
+ )
263
+ session.context_manager.add_message(tool_msg)
264
+
265
+ await session.send_event(
266
+ Event(
267
+ event_type="tool_output",
268
+ data={
269
+ "tool": tool_name,
270
+ "output": output,
271
+ "success": success,
272
+ },
273
+ )
274
+ )
275
+ else:
276
+ # User rejected - add cancellation message to context
277
+ cancellation_msg = "Job execution cancelled by user"
278
+ if feedback:
279
+ cancellation_msg += f". User feedback: {feedback}"
280
+
281
+ tool_msg = Message(
282
+ role="tool",
283
+ content=cancellation_msg,
284
+ tool_call_id=tc.id,
285
+ name=tool_name,
286
+ )
287
+ session.context_manager.add_message(tool_msg)
288
+
289
+ await session.send_event(
290
+ Event(
291
+ event_type="tool_output",
292
+ data={
293
+ "tool": tool_name,
294
+ "output": cancellation_msg,
295
+ "success": False,
296
+ },
297
+ )
298
+ )
299
+
300
+ # Clear pending approval
301
+ session.pending_approval = None
302
+
303
+ # Continue agent loop with empty input to process the tool result
304
+ await Handlers.run_agent(session, "")
305
+
306
  @staticmethod
307
  async def shutdown(session: Session) -> bool:
308
  """Handle shutdown (like shutdown in codex.rs:1329)"""
 
338
  await Handlers.undo(session)
339
  return True
340
 
341
+ if op.op_type == OpType.EXEC_APPROVAL:
342
+ approved = op.data.get("approved", False) if op.data else False
343
+ feedback = op.data.get("feedback") if op.data else None
344
+ await Handlers.exec_approval(session, approved, feedback)
345
+ return True
346
+
347
  if op.op_type == OpType.SHUTDOWN:
348
  return not await Handlers.shutdown(session)
349
 
agent/core/session.py CHANGED
@@ -53,6 +53,7 @@ class Session:
53
  )
54
  self.is_running = True
55
  self.current_task: asyncio.Task | None = None
 
56
 
57
  async def send_event(self, event: Event) -> None:
58
  """Send event back to client"""
 
53
  )
54
  self.is_running = True
55
  self.current_task: asyncio.Task | None = None
56
+ self.pending_approval: Optional[dict[str, Any]] = None
57
 
58
  async def send_event(self, event: Event) -> None:
59
  """Send event back to client"""
agent/main.py CHANGED
@@ -47,10 +47,13 @@ class Submission:
47
 
48
  async def event_listener(
49
  event_queue: asyncio.Queue,
 
50
  turn_complete_event: asyncio.Event,
51
  ready_event: asyncio.Event,
52
  ) -> None:
53
  """Background task that listens for events and displays them"""
 
 
54
  while True:
55
  try:
56
  event = await event_queue.get()
@@ -96,6 +99,69 @@ async def event_listener(
96
  old_tokens = event.data.get("old_tokens", 0) if event.data else 0
97
  new_tokens = event.data.get("new_tokens", 0) if event.data else 0
98
  print(f"📦 Compacted context: {old_tokens} → {new_tokens} tokens")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  # Silently ignore other events
100
 
101
  except asyncio.CancelledError:
@@ -145,7 +211,7 @@ async def main():
145
 
146
  # Start event listener in background
147
  listener_task = asyncio.create_task(
148
- event_listener(event_queue, turn_complete_event, ready_event)
149
  )
150
 
151
  # Wait for agent to initialize
 
47
 
48
  async def event_listener(
49
  event_queue: asyncio.Queue,
50
+ submission_queue: asyncio.Queue,
51
  turn_complete_event: asyncio.Event,
52
  ready_event: asyncio.Event,
53
  ) -> None:
54
  """Background task that listens for events and displays them"""
55
+ submission_id = [1000] # Use list to make it mutable in closure
56
+
57
  while True:
58
  try:
59
  event = await event_queue.get()
 
99
  old_tokens = event.data.get("old_tokens", 0) if event.data else 0
100
  new_tokens = event.data.get("new_tokens", 0) if event.data else 0
101
  print(f"📦 Compacted context: {old_tokens} → {new_tokens} tokens")
102
+ elif event.event_type == "approval_required":
103
+ # Display job details and prompt for approval
104
+ tool_name = event.data.get("tool", "") if event.data else ""
105
+ arguments = event.data.get("arguments", {}) if event.data else {}
106
+
107
+ print("\n" + "=" * 60)
108
+ print("⚠️ JOB EXECUTION APPROVAL REQUIRED")
109
+ print("=" * 60)
110
+
111
+ operation = arguments.get("operation", "")
112
+ args = arguments.get("args", {})
113
+
114
+ print(f"Operation: {operation}")
115
+
116
+ if operation == "uv":
117
+ script = args.get("script", "")
118
+ dependencies = args.get("dependencies", [])
119
+ print(f"Script to run:\n{script}")
120
+ if dependencies:
121
+ print(f"Dependencies: {', '.join(dependencies)}")
122
+ elif operation == "run":
123
+ image = args.get("image", "")
124
+ command = args.get("command", "")
125
+ print(f"Docker image: {image}")
126
+ print(f"Command: {command}")
127
+
128
+ # Common parameters
129
+ flavor = args.get("flavor", "cpu-basic")
130
+ detached = args.get("detached", False)
131
+ print(f"Hardware: {flavor}")
132
+ print(f"Detached mode: {detached}")
133
+
134
+ secrets = args.get("secrets", [])
135
+ if secrets:
136
+ print(f"Secrets: {', '.join(secrets)}")
137
+
138
+ print("=" * 60)
139
+
140
+ # Get user decision
141
+ loop = asyncio.get_event_loop()
142
+ response = await loop.run_in_executor(
143
+ None,
144
+ input,
145
+ "Approve? (y=yes, n=no, or provide feedback to reject): ",
146
+ )
147
+
148
+ response = response.strip()
149
+ approved = response.lower() in ["y", "yes"]
150
+ feedback = (
151
+ None if approved or response.lower() in ["n", "no"] else response
152
+ )
153
+
154
+ # Submit approval
155
+ submission_id[0] += 1
156
+ approval_submission = Submission(
157
+ id=f"approval_{submission_id[0]}",
158
+ operation=Operation(
159
+ op_type=OpType.EXEC_APPROVAL,
160
+ data={"approved": approved, "feedback": feedback},
161
+ ),
162
+ )
163
+ await submission_queue.put(approval_submission)
164
+ print("=" * 60 + "\n")
165
  # Silently ignore other events
166
 
167
  except asyncio.CancelledError:
 
211
 
212
  # Start event listener in background
213
  listener_task = asyncio.create_task(
214
+ event_listener(event_queue, submission_queue, turn_complete_event, ready_event)
215
  )
216
 
217
  # Wait for agent to initialize
agent/tools/jobs_tool.py CHANGED
@@ -13,9 +13,12 @@ from huggingface_hub import HfApi
13
  from huggingface_hub.utils import HfHubHTTPError
14
 
15
  from agent.tools.types import ToolResult
16
- from agent.tools.utilities import (format_job_details, format_jobs_table,
17
- format_scheduled_job_details,
18
- format_scheduled_jobs_table)
 
 
 
19
 
20
  # Hardware flavors
21
  CPU_FLAVORS = ["cpu-basic", "cpu-upgrade", "cpu-performance", "cpu-xl"]
@@ -446,7 +449,7 @@ To inspect, call this tool with `{{"operation": "inspect", "args": {{"job_id": "
446
 
447
  # Not detached - wait for completion and stream logs
448
  print(f"Job started: {job.id}")
449
- print(f"Streaming logs...\n---\n")
450
 
451
  final_status, all_logs = await self._wait_for_job_completion(
452
  job_id=job.id,
@@ -519,7 +522,7 @@ To check logs, call this tool with `{{"operation": "logs", "args": {{"job_id": "
519
 
520
  # Not detached - wait for completion and stream logs
521
  print(f"UV Job started: {job.id}")
522
- print(f"Streaming logs...\n---\n")
523
 
524
  final_status, all_logs = await self._wait_for_job_completion(
525
  job_id=job.id,
 
13
  from huggingface_hub.utils import HfHubHTTPError
14
 
15
  from agent.tools.types import ToolResult
16
+ from agent.tools.utilities import (
17
+ format_job_details,
18
+ format_jobs_table,
19
+ format_scheduled_job_details,
20
+ format_scheduled_jobs_table,
21
+ )
22
 
23
  # Hardware flavors
24
  CPU_FLAVORS = ["cpu-basic", "cpu-upgrade", "cpu-performance", "cpu-xl"]
 
449
 
450
  # Not detached - wait for completion and stream logs
451
  print(f"Job started: {job.id}")
452
+ print("Streaming logs...\n---\n")
453
 
454
  final_status, all_logs = await self._wait_for_job_completion(
455
  job_id=job.id,
 
522
 
523
  # Not detached - wait for completion and stream logs
524
  print(f"UV Job started: {job.id}")
525
+ print("Streaming logs...\n---\n")
526
 
527
  final_status, all_logs = await self._wait_for_job_completion(
528
  job_id=job.id,