akseljoonas HF Staff Claude Opus 4.6 committed on
Commit
3c91fc8
·
1 Parent(s): e229826

feat: kill sandbox processes and cancel HF jobs on user interrupt

Browse files

When the user clicks cancel, long-running sandbox commands and HF jobs
now actually stop instead of running to completion in the background.

- Sandbox server: track PIDs with Popen, add /api/kill endpoint
- Sandbox client: add kill_all() method
- HF jobs: track running job IDs in session, cancel on interrupt
- Agent loop: cleanup on cancel in both run_agent and exec_approval
- exec_approval: add cancellation support (was completely missing)
- Sandbox title: set descriptive name instead of inheriting template

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

agent/core/agent_loop.py CHANGED
@@ -146,6 +146,27 @@ def _needs_approval(
146
  return False
147
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  async def _compact_and_notify(session: Session) -> None:
150
  """Run compaction and send event if context was reduced."""
151
  old_length = session.context_manager.context_length
@@ -164,6 +185,32 @@ async def _compact_and_notify(session: Session) -> None:
164
  )
165
 
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  class Handlers:
168
  """Handler functions for each operation type"""
169
 
@@ -247,17 +294,37 @@ class Handlers:
247
  messages = session.context_manager.get_messages()
248
  tools = session.tool_router.get_tool_specs_for_llm()
249
  try:
250
- # ── Stream the LLM response ──────────────────────────
251
  llm_params = _resolve_hf_router_params(session.config.model_name)
252
- response = await acompletion(
253
- messages=messages,
254
- tools=tools,
255
- tool_choice="auto",
256
- stream=True,
257
- stream_options={"include_usage": True},
258
- timeout=600, # 10 min β€” long tool-use turns can take a while
259
- **llm_params,
260
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
  full_content = ""
263
  tool_calls_acc: dict[int, dict] = {}
@@ -355,8 +422,8 @@ class Handlers:
355
  )
356
  await session.send_event(
357
  Event(
358
- event_type="error",
359
- data={"error": f"Output truncated β€” retrying with smaller content ({dropped_names})"},
360
  )
361
  )
362
  iteration += 1
@@ -510,6 +577,7 @@ class Handlers:
510
  await gather_task
511
  except asyncio.CancelledError:
512
  pass
 
513
  break
514
 
515
  cancel_task.cancel()
@@ -593,6 +661,7 @@ class Handlers:
593
  break
594
 
595
  if session.is_cancelled:
 
596
  await session.send_event(Event(event_type="interrupted"))
597
  elif not errored:
598
  await session.send_event(
@@ -743,16 +812,37 @@ class Handlers:
743
 
744
  return (tc, tool_name, output, success, was_edited)
745
 
746
- # Execute all approved tools concurrently and wait for ALL to complete
747
  if approved_tasks:
748
- results = await asyncio.gather(
749
  *[
750
  execute_tool(tc, tool_name, tool_args, was_edited)
751
  for tc, tool_name, tool_args, was_edited in approved_tasks
752
  ],
753
  return_exceptions=True,
 
 
 
 
 
 
754
  )
755
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
756
  # Process results and add to context
757
  for result in results:
758
  if isinstance(result, Exception):
 
146
  return False
147
 
148
 
149
+ # -- LLM retry constants --------------------------------------------------
150
+ _MAX_LLM_RETRIES = 3
151
+ _LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries
152
+
153
+
154
+ def _is_transient_error(error: Exception) -> bool:
155
+ """Return True for errors that are likely transient and worth retrying."""
156
+ err_str = str(error).lower()
157
+ transient_patterns = [
158
+ "timeout", "timed out",
159
+ "429", "rate limit", "rate_limit",
160
+ "503", "service unavailable",
161
+ "502", "bad gateway",
162
+ "500", "internal server error",
163
+ "overloaded", "capacity",
164
+ "connection reset", "connection refused", "connection error",
165
+ "eof", "broken pipe",
166
+ ]
167
+ return any(pattern in err_str for pattern in transient_patterns)
168
+
169
+
170
  async def _compact_and_notify(session: Session) -> None:
171
  """Run compaction and send event if context was reduced."""
172
  old_length = session.context_manager.context_length
 
185
  )
186
 
187
 
188
async def _cleanup_on_cancel(session: Session) -> None:
    """Best-effort cleanup when the user interrupts a turn.

    Stops any bash processes still running in the sandbox, then cancels
    every Hugging Face job the session has in flight. Failures are logged
    and swallowed so cleanup never masks the interrupt itself.
    """
    # Stop anything still executing in the sandbox first.
    sb = getattr(session, "sandbox", None)
    if sb:
        try:
            await asyncio.to_thread(sb.kill_all)
        except Exception as exc:
            logger.warning("Failed to kill sandbox processes: %s", exc)
        else:
            logger.info("Killed sandbox processes on cancel")

    # Then cancel any HF jobs that were launched during this turn.
    pending = list(session._running_job_ids)
    if pending:
        from huggingface_hub import HfApi

        hf_api = HfApi(token=session.hf_token)
        for jid in pending:
            try:
                await asyncio.to_thread(hf_api.cancel_job, job_id=jid)
                logger.info("Cancelled HF job %s on interrupt", jid)
            except Exception as exc:
                logger.warning("Failed to cancel HF job %s: %s", jid, exc)
    session._running_job_ids.clear()
212
+
213
+
214
  class Handlers:
215
  """Handler functions for each operation type"""
216
 
 
294
  messages = session.context_manager.get_messages()
295
  tools = session.tool_router.get_tool_specs_for_llm()
296
  try:
297
+ # ── Stream the LLM response (with retry for transient errors) ──
298
  llm_params = _resolve_hf_router_params(session.config.model_name)
299
+ response = None
300
+ for _llm_attempt in range(_MAX_LLM_RETRIES):
301
+ try:
302
+ response = await acompletion(
303
+ messages=messages,
304
+ tools=tools,
305
+ tool_choice="auto",
306
+ stream=True,
307
+ stream_options={"include_usage": True},
308
+ timeout=600,
309
+ **llm_params,
310
+ )
311
+ break
312
+ except ContextWindowExceededError:
313
+ raise
314
+ except Exception as e:
315
+ if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e):
316
+ _delay = _LLM_RETRY_DELAYS[_llm_attempt]
317
+ logger.warning(
318
+ "Transient LLM error (attempt %d/%d): %s β€” retrying in %ds",
319
+ _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay,
320
+ )
321
+ await session.send_event(Event(
322
+ event_type="tool_log",
323
+ data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."},
324
+ ))
325
+ await asyncio.sleep(_delay)
326
+ continue
327
+ raise
328
 
329
  full_content = ""
330
  tool_calls_acc: dict[int, dict] = {}
 
422
  )
423
  await session.send_event(
424
  Event(
425
+ event_type="tool_log",
426
+ data={"tool": "system", "log": f"Output truncated β€” retrying with smaller content ({dropped_names})"},
427
  )
428
  )
429
  iteration += 1
 
577
  await gather_task
578
  except asyncio.CancelledError:
579
  pass
580
+ await _cleanup_on_cancel(session)
581
  break
582
 
583
  cancel_task.cancel()
 
661
  break
662
 
663
  if session.is_cancelled:
664
+ await _cleanup_on_cancel(session)
665
  await session.send_event(Event(event_type="interrupted"))
666
  elif not errored:
667
  await session.send_event(
 
812
 
813
  return (tc, tool_name, output, success, was_edited)
814
 
815
+ # Execute all approved tools concurrently (cancellable)
816
  if approved_tasks:
817
+ gather_task = asyncio.ensure_future(asyncio.gather(
818
  *[
819
  execute_tool(tc, tool_name, tool_args, was_edited)
820
  for tc, tool_name, tool_args, was_edited in approved_tasks
821
  ],
822
  return_exceptions=True,
823
+ ))
824
+ cancel_task = asyncio.ensure_future(session._cancelled.wait())
825
+
826
+ done, _ = await asyncio.wait(
827
+ [gather_task, cancel_task],
828
+ return_when=asyncio.FIRST_COMPLETED,
829
  )
830
 
831
+ if cancel_task in done:
832
+ gather_task.cancel()
833
+ try:
834
+ await gather_task
835
+ except asyncio.CancelledError:
836
+ pass
837
+ await _cleanup_on_cancel(session)
838
+ await session.send_event(Event(event_type="interrupted"))
839
+ session.increment_turn()
840
+ await session.auto_save_if_needed()
841
+ return
842
+
843
+ cancel_task.cancel()
844
+ results = gather_task.result()
845
+
846
  # Process results and add to context
847
  for result in results:
848
  if isinstance(result, Exception):
agent/core/session.py CHANGED
@@ -103,6 +103,7 @@ class Session:
103
  self._cancelled = asyncio.Event()
104
  self.pending_approval: Optional[dict[str, Any]] = None
105
  self.sandbox = None
 
106
 
107
  # Session trajectory logging
108
  self.logged_events: list[dict] = []
 
103
  self._cancelled = asyncio.Event()
104
  self.pending_approval: Optional[dict[str, Any]] = None
105
  self.sandbox = None
106
+ self._running_job_ids: set[str] = set() # HF job IDs currently executing
107
 
108
  # Session trajectory logging
109
  self.logged_events: list[dict] = []
agent/tools/jobs_tool.py CHANGED
@@ -531,6 +531,10 @@ class HfJobsTool:
531
  namespace=self.namespace,
532
  )
533
 
 
 
 
 
534
  # Send job URL immediately after job creation (before waiting for completion)
535
  if self.session and self.tool_call_id:
536
  await self.session.send_event(
@@ -554,6 +558,10 @@ class HfJobsTool:
554
  namespace=self.namespace,
555
  )
556
 
 
 
 
 
557
  # Notify frontend of final status
558
  if self.session and self.tool_call_id:
559
  await self.session.send_event(
 
531
  namespace=self.namespace,
532
  )
533
 
534
+ # Track job ID for cancellation on interrupt
535
+ if self.session:
536
+ self.session._running_job_ids.add(job.id)
537
+
538
  # Send job URL immediately after job creation (before waiting for completion)
539
  if self.session and self.tool_call_id:
540
  await self.session.send_event(
 
558
  namespace=self.namespace,
559
  )
560
 
561
+ # Untrack job ID (completed or failed, no longer needs cancellation)
562
+ if self.session:
563
+ self.session._running_job_ids.discard(job.id)
564
+
565
  # Notify frontend of final status
566
  if self.session and self.tool_call_id:
567
  await self.session.send_event(
agent/tools/sandbox_client.py CHANGED
@@ -97,7 +97,7 @@ CMD ["python", "sandbox_server.py"]
97
 
98
  _SANDBOX_SERVER = '''\
99
  """Minimal FastAPI server for sandbox operations."""
100
- import os, subprocess, pathlib
101
  from fastapi import FastAPI
102
  from pydantic import BaseModel
103
  from typing import Optional
@@ -105,6 +105,10 @@ import uvicorn
105
 
106
  app = FastAPI()
107
 
 
 
 
 
108
  class BashReq(BaseModel):
109
  command: str
110
  work_dir: str = "/app"
@@ -135,19 +139,49 @@ def health():
135
  @app.post("/api/bash")
136
  def bash(req: BashReq):
137
  try:
138
- r = subprocess.run(
139
- req.command, shell=True, capture_output=True, text=True,
140
- cwd=req.work_dir, timeout=req.timeout,
141
  )
142
- output = r.stdout + r.stderr
143
- if len(output) > 30000:
144
- output = output[:30000] + "\\n... (truncated)"
145
- return {"success": r.returncode == 0, "output": output, "error": "" if r.returncode == 0 else f"Exit code {r.returncode}"}
146
- except subprocess.TimeoutExpired:
147
- return {"success": False, "output": "", "error": f"Timeout after {req.timeout}s"}
 
 
 
 
 
 
 
 
 
 
 
 
148
  except Exception as e:
149
  return {"success": False, "output": "", "error": str(e)}
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  @app.post("/api/read")
152
  def read(req: ReadReq):
153
  try:
@@ -566,6 +600,10 @@ class Sandbox:
566
  },
567
  )
568
 
 
 
 
 
569
  # ── Tool schemas & dispatch ───────────────────────────────────
570
 
571
  TOOLS = {
 
97
 
98
  _SANDBOX_SERVER = '''\
99
  """Minimal FastAPI server for sandbox operations."""
100
+ import os, subprocess, pathlib, signal, threading
101
  from fastapi import FastAPI
102
  from pydantic import BaseModel
103
  from typing import Optional
 
105
 
106
  app = FastAPI()
107
 
108
+ # Track active bash processes so they can be killed on cancel
109
+ _active_procs = {} # pid -> subprocess.Popen
110
+ _proc_lock = threading.Lock()
111
+
112
  class BashReq(BaseModel):
113
  command: str
114
  work_dir: str = "/app"
 
139
@app.post("/api/bash")
def bash(req: BashReq):
    """Run a shell command, tracking its PID so /api/kill can stop it."""
    try:
        # start_new_session=True puts the command in its own process group,
        # so the whole group (command plus any children) can be signalled
        # at once via os.killpg.
        proc = subprocess.Popen(
            req.command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            text=True, cwd=req.work_dir, start_new_session=True,
        )
        # Register the process so a concurrent /api/kill request can find it.
        with _proc_lock:
            _active_procs[proc.pid] = proc
        try:
            stdout, stderr = proc.communicate(timeout=req.timeout)
            output = stdout + stderr
            # Cap the payload returned to the agent.
            if len(output) > 30000:
                output = output[:30000] + "\\n... (truncated)"
            return {"success": proc.returncode == 0, "output": output, "error": "" if proc.returncode == 0 else f"Exit code {proc.returncode}"}
        except subprocess.TimeoutExpired:
            # Kill the whole process group; fall back to just the child if
            # the group lookup fails (e.g. the process already exited).
            try:
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
            except OSError:
                proc.kill()
            proc.wait()  # reap the child so it doesn't linger as a zombie
            return {"success": False, "output": "", "error": f"Timeout after {req.timeout}s"}
        finally:
            # Always unregister, whether the command finished or timed out.
            with _proc_lock:
                _active_procs.pop(proc.pid, None)
    except Exception as e:
        return {"success": False, "output": "", "error": str(e)}
166
 
167
@app.post("/api/kill")
def kill_all():
    """Kill all active bash processes. Called when user cancels."""
    # Snapshot the PIDs under the lock; signalling happens afterwards so a
    # slow kill cannot block /api/bash handlers updating the registry.
    with _proc_lock:
        pids = list(_active_procs.keys())
    killed = []
    for pid in pids:
        try:
            # Commands run in their own session (start_new_session in
            # /api/bash), so killing the group takes down children too.
            os.killpg(os.getpgid(pid), signal.SIGTERM)
            killed.append(pid)
        except OSError:
            # Group lookup/signal failed -- fall back to the single PID.
            try:
                os.kill(pid, signal.SIGKILL)
                killed.append(pid)
            except OSError:
                pass  # process already gone; nothing to do
    return {"success": True, "output": f"Killed {len(killed)} process(es): {killed}", "error": ""}
184
+
185
  @app.post("/api/read")
186
  def read(req: ReadReq):
187
  try:
 
600
  },
601
  )
602
 
603
+ def kill_all(self) -> ToolResult:
604
+ """Kill all active bash processes on the sandbox. Used on cancellation."""
605
+ return self._call("kill", {})
606
+
607
  # ── Tool schemas & dispatch ───────────────────────────────────
608
 
609
  TOOLS = {
agent/tools/sandbox_tool.py CHANGED
@@ -27,11 +27,17 @@ def _looks_like_path(script: str) -> bool:
27
  isinstance(script, str)
28
  and script.strip() == script
29
  and not any(c in script for c in "\r\n\0")
30
- and (script.startswith("/") or script.startswith("./") or script.startswith("../"))
 
 
 
 
31
  )
32
 
33
 
34
- async def resolve_sandbox_script(sandbox: Any, script: str) -> tuple[str | None, str | None]:
 
 
35
  """Read a file from the sandbox if *script* looks like a path.
36
 
37
  Returns:
@@ -42,15 +48,14 @@ async def resolve_sandbox_script(sandbox: Any, script: str) -> tuple[str | None,
42
  if not sandbox or not _looks_like_path(script):
43
  return None, None
44
  try:
45
- result = await asyncio.to_thread(
46
- sandbox.bash, f"cat {shlex.quote(script)}"
47
- )
48
  if result.success and result.output:
49
  return result.output, None
50
  return None, f"Failed to read {script} from sandbox: {result.error}"
51
  except Exception as e:
52
  return None, f"Failed to read {script} from sandbox: {e}"
53
 
 
54
  # ── Tool name mapping (short agent names β†’ Sandbox client names) ──────
55
 
56
 
@@ -98,12 +103,30 @@ async def _ensure_sandbox(
98
  Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}),
99
  )
100
 
101
- kwargs = {"owner": owner, "hardware": hardware, "token": token, "log": _log, **create_kwargs}
 
 
 
 
 
 
102
  if hardware != "cpu-basic":
103
  kwargs["sleep_time"] = 2700
104
  sb = await asyncio.to_thread(Sandbox.create, **kwargs)
105
  session.sandbox = sb
106
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  # Inject the OAuth token into the sandbox so Hub operations work inside it
108
  await asyncio.to_thread(api.add_space_secret, sb.space_id, "HF_TOKEN", token)
109
 
 
27
  isinstance(script, str)
28
  and script.strip() == script
29
  and not any(c in script for c in "\r\n\0")
30
+ and (
31
+ script.startswith("/")
32
+ or script.startswith("./")
33
+ or script.startswith("../")
34
+ )
35
  )
36
 
37
 
38
+ async def resolve_sandbox_script(
39
+ sandbox: Any, script: str
40
+ ) -> tuple[str | None, str | None]:
41
  """Read a file from the sandbox if *script* looks like a path.
42
 
43
  Returns:
 
48
  if not sandbox or not _looks_like_path(script):
49
  return None, None
50
  try:
51
+ result = await asyncio.to_thread(sandbox.bash, f"cat {shlex.quote(script)}")
 
 
52
  if result.success and result.output:
53
  return result.output, None
54
  return None, f"Failed to read {script} from sandbox: {result.error}"
55
  except Exception as e:
56
  return None, f"Failed to read {script} from sandbox: {e}"
57
 
58
+
59
  # ── Tool name mapping (short agent names β†’ Sandbox client names) ──────
60
 
61
 
 
103
  Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}),
104
  )
105
 
106
+ kwargs = {
107
+ "owner": owner,
108
+ "hardware": hardware,
109
+ "token": token,
110
+ "log": _log,
111
+ **create_kwargs,
112
+ }
113
  if hardware != "cpu-basic":
114
  kwargs["sleep_time"] = 2700
115
  sb = await asyncio.to_thread(Sandbox.create, **kwargs)
116
  session.sandbox = sb
117
 
118
+ # Set a descriptive title (template title is inherited on duplicate)
119
+ from huggingface_hub import metadata_update
120
+
121
+ await asyncio.to_thread(
122
+ metadata_update,
123
+ sb.space_id,
124
+ {"title": "ml-agent sandbox"},
125
+ repo_type="space",
126
+ overwrite=True,
127
+ token=token,
128
+ )
129
+
130
  # Inject the OAuth token into the sandbox so Hub operations work inside it
131
  await asyncio.to_thread(api.add_space_secret, sb.space_id, "HF_TOKEN", token)
132