akseljoonas HF Staff commited on
Commit
d0d08fc
·
2 Parent(s): 471f346ff8c636

Deploy 2026-04-26

Browse files
Files changed (42) hide show
  1. .github/workflows/claude-review.yml +13 -2
  2. README.md +5 -2
  3. agent/context_manager/manager.py +36 -32
  4. agent/core/agent_loop.py +140 -6
  5. agent/core/doom_loop.py +24 -4
  6. agent/core/effort_probe.py +8 -4
  7. agent/core/hf_access.py +181 -0
  8. agent/core/llm_params.py +2 -2
  9. agent/core/model_switcher.py +5 -2
  10. agent/core/telemetry.py +38 -0
  11. agent/main.py +19 -9
  12. agent/tools/jobs_tool.py +20 -2
  13. agent/tools/research_tool.py +3 -1
  14. agent/tools/sandbox_tool.py +11 -1
  15. agent/utils/terminal_display.py +2 -20
  16. backend/dependencies.py +6 -50
  17. backend/models.py +1 -0
  18. backend/routes/agent.py +154 -2
  19. backend/session_manager.py +1 -1
  20. configs/cli_agent_config.json +14 -0
  21. configs/{main_agent_config.json → frontend_agent_config.json} +0 -0
  22. frontend/src/components/Chat/ChatInput.tsx +58 -4
  23. frontend/src/components/ClaudeCapDialog.tsx +3 -0
  24. frontend/src/components/JobsUpgradeDialog.tsx +191 -0
  25. frontend/src/components/SessionChat.tsx +3 -1
  26. frontend/src/hooks/useAgentChat.ts +130 -1
  27. frontend/src/lib/sse-chat-transport.ts +26 -0
  28. frontend/src/store/agentStore.ts +37 -0
  29. frontend/src/types/agent.ts +1 -0
  30. frontend/src/utils/model.ts +3 -4
  31. pyproject.toml +6 -2
  32. scripts/build_kpis.py +24 -2
  33. tests/unit/test_build_kpis.py +32 -0
  34. tests/unit/test_cli_rendering.py +44 -0
  35. tests/unit/test_dangling_tool_calls.py +121 -0
  36. tests/unit/test_doom_loop_polling.py +96 -0
  37. tests/unit/test_hf_access.py +39 -0
  38. tests/unit/test_llm_error_classification.py +100 -0
  39. tests/unit/test_llm_params.py +25 -0
  40. tests/unit/test_malformed_args_recovery.py +66 -0
  41. tests/unit/test_sandbox_already_active_message.py +47 -0
  42. uv.lock +114 -72
.github/workflows/claude-review.yml CHANGED
@@ -1,8 +1,8 @@
1
  name: Claude PR Review
2
 
3
  on:
4
- pull_request:
5
- types: [opened, synchronize, ready_for_review]
6
 
7
  permissions:
8
  contents: read
@@ -22,6 +22,10 @@ jobs:
22
  - uses: actions/checkout@v4
23
  with:
24
  fetch-depth: 0
 
 
 
 
25
 
26
  - name: Compose review prompt
27
  id: compose
@@ -58,5 +62,12 @@ jobs:
58
  - uses: anthropics/claude-code-action@v1
59
  with:
60
  anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
 
 
 
 
 
 
 
61
  track_progress: true
62
  prompt: ${{ steps.compose.outputs.prompt }}
 
1
  name: Claude PR Review
2
 
3
  on:
4
+ pull_request_target:
5
+ types: [opened, synchronize, ready_for_review, reopened]
6
 
7
  permissions:
8
  contents: read
 
22
  - uses: actions/checkout@v4
23
  with:
24
  fetch-depth: 0
25
+ # On pull_request_target, keep checkout on the trusted base-repo ref.
26
+ # The Claude action can review the PR via GitHub context/API without
27
+ # executing untrusted fork code with repository secrets.
28
+ persist-credentials: false
29
 
30
  - name: Compose review prompt
31
  id: compose
 
62
  - uses: anthropics/claude-code-action@v1
63
  with:
64
  anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
65
+ # Bypass the OIDC -> Claude GitHub App token exchange. That exchange
66
+ # rejects OIDC tokens minted for pull_request_target events with
67
+ # "401 Invalid OIDC token", which broke every review after the switch
68
+ # away from pull_request. Using the workflow's GITHUB_TOKEN works for
69
+ # both same-repo and fork PRs; comments post as github-actions[bot]
70
+ # instead of claude[bot], which is the documented trade-off.
71
+ github_token: ${{ secrets.GITHUB_TOKEN }}
72
  track_progress: true
73
  prompt: ${{ steps.compose.outputs.prompt }}
README.md CHANGED
@@ -23,7 +23,7 @@ hf_oauth_scopes:
23
 
24
  # ML Intern
25
 
26
- An ML intern that autonomously researches, writes, and ships good quality ML releated code using the Hugging Face ecosystem — with deep access to docs, papers, datasets, and cloud compute.
27
 
28
  ## Quick Start
29
 
@@ -46,6 +46,7 @@ Create a `.env` file in the project root (or export these in your shell):
46
 
47
  ```bash
48
  ANTHROPIC_API_KEY=<your-anthropic-api-key> # if using anthropic models
 
49
  HF_TOKEN=<your-hugging-face-token>
50
  GITHUB_TOKEN=<github-personal-access-token>
51
  ```
@@ -69,6 +70,7 @@ ml-intern "fine-tune llama on my dataset"
69
 
70
  ```bash
71
  ml-intern --model anthropic/claude-opus-4-6 "your prompt"
 
72
  ml-intern --max-iterations 100 "your prompt"
73
  ml-intern --no-stream "your prompt"
74
  ```
@@ -229,7 +231,8 @@ def create_builtin_tools() -> list[ToolSpec]:
229
 
230
  ### Adding MCP Servers
231
 
232
- Edit `configs/main_agent_config.json`:
 
233
 
234
  ```json
235
  {
 
23
 
24
  # ML Intern
25
 
26
+ An ML intern that autonomously researches, writes, and ships good quality ML related code using the Hugging Face ecosystem — with deep access to docs, papers, datasets, and cloud compute.
27
 
28
  ## Quick Start
29
 
 
46
 
47
  ```bash
48
  ANTHROPIC_API_KEY=<your-anthropic-api-key> # if using anthropic models
49
+ OPENAI_API_KEY=<your-openai-api-key> # if using openai models
50
  HF_TOKEN=<your-hugging-face-token>
51
  GITHUB_TOKEN=<github-personal-access-token>
52
  ```
 
70
 
71
  ```bash
72
  ml-intern --model anthropic/claude-opus-4-6 "your prompt"
73
+ ml-intern --model openai/gpt-5.5 "your prompt"
74
  ml-intern --max-iterations 100 "your prompt"
75
  ml-intern --no-stream "your prompt"
76
  ```
 
231
 
232
  ### Adding MCP Servers
233
 
234
+ Edit `configs/cli_agent_config.json` for CLI defaults, or
235
+ `configs/frontend_agent_config.json` for web-session defaults:
236
 
237
  ```json
238
  {
agent/context_manager/manager.py CHANGED
@@ -253,45 +253,49 @@ class ContextManager:
253
  def _patch_dangling_tool_calls(self) -> None:
254
  """Add stub tool results for any tool_calls that lack a matching result.
255
 
256
- Scans backwards to find the last assistant message with tool_calls,
257
- which may not be items[-1] if some tool results were already added.
 
 
258
  """
259
  if not self.items:
260
  return
261
 
262
- # Find the last assistant message with tool_calls
263
- assistant_msg = None
264
- for i in range(len(self.items) - 1, -1, -1):
265
  msg = self.items[i]
266
- if getattr(msg, "role", None) == "assistant" and getattr(
267
- msg, "tool_calls", None
268
- ):
269
- assistant_msg = msg
270
- break
271
- # Stop scanning once we hit a user message — anything before
272
- # that belongs to a previous (complete) turn.
273
- if getattr(msg, "role", None) == "user":
274
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
- if not assistant_msg:
277
- return
 
278
 
279
- self._normalize_tool_calls(assistant_msg)
280
- answered_ids = {
281
- getattr(m, "tool_call_id", None)
282
- for m in self.items
283
- if getattr(m, "role", None) == "tool"
284
- }
285
- for tc in assistant_msg.tool_calls:
286
- if tc.id not in answered_ids:
287
- self.items.append(
288
- Message(
289
- role="tool",
290
- content="Tool was not executed (interrupted or error).",
291
- tool_call_id=tc.id,
292
- name=tc.function.name,
293
- )
294
- )
295
 
296
  def undo_last_turn(self) -> bool:
297
  """Remove the last complete turn (user msg + all assistant/tool msgs that follow).
 
253
  def _patch_dangling_tool_calls(self) -> None:
254
  """Add stub tool results for any tool_calls that lack a matching result.
255
 
256
+ Ensures each assistant message's tool_calls are followed immediately
257
+ by matching tool-result messages. This has to work across the whole
258
+ history, not just the most recent turn, because a cancelled tool use
259
+ in an earlier turn can still poison the next provider request.
260
  """
261
  if not self.items:
262
  return
263
 
264
+ i = 0
265
+ while i < len(self.items):
 
266
  msg = self.items[i]
267
+ if getattr(msg, "role", None) != "assistant" or not getattr(msg, "tool_calls", None):
268
+ i += 1
269
+ continue
270
+
271
+ self._normalize_tool_calls(msg)
272
+
273
+ # Consume the contiguous tool-result block that immediately follows
274
+ # this assistant message. Any missing tool ids must be inserted
275
+ # before the next non-tool message to satisfy provider ordering.
276
+ j = i + 1
277
+ immediate_ids: set[str | None] = set()
278
+ while j < len(self.items) and getattr(self.items[j], "role", None) == "tool":
279
+ immediate_ids.add(getattr(self.items[j], "tool_call_id", None))
280
+ j += 1
281
+
282
+ missing: list[Message] = []
283
+ for tc in msg.tool_calls:
284
+ if tc.id not in immediate_ids:
285
+ missing.append(
286
+ Message(
287
+ role="tool",
288
+ content="Tool was not executed (interrupted or error).",
289
+ tool_call_id=tc.id,
290
+ name=tc.function.name,
291
+ )
292
+ )
293
 
294
+ if missing:
295
+ self.items[j:j] = missing
296
+ j += len(missing)
297
 
298
+ i = j
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
  def undo_last_turn(self) -> bool:
301
  """Remove the last complete turn (user msg + all assistant/tool msgs that follow).
agent/core/agent_loop.py CHANGED
@@ -25,6 +25,61 @@ logger = logging.getLogger(__name__)
25
 
26
  ToolCall = ChatCompletionMessageToolCall
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
30
  """
@@ -121,6 +176,54 @@ def _needs_approval(
121
  # -- LLM retry constants --------------------------------------------------
122
  _MAX_LLM_RETRIES = 3
123
  _LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
 
126
  def _is_transient_error(error: Exception) -> bool:
@@ -128,7 +231,6 @@ def _is_transient_error(error: Exception) -> bool:
128
  err_str = str(error).lower()
129
  transient_patterns = [
130
  "timeout", "timed out",
131
- "429", "rate limit", "rate_limit",
132
  "503", "service unavailable",
133
  "502", "bad gateway",
134
  "500", "internal server error",
@@ -136,7 +238,7 @@ def _is_transient_error(error: Exception) -> bool:
136
  "connection reset", "connection refused", "connection error",
137
  "eof", "broken pipe",
138
  ]
139
- return any(pattern in err_str for pattern in transient_patterns)
140
 
141
 
142
  def _is_effort_config_error(error: Exception) -> bool:
@@ -317,6 +419,8 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
317
  except ContextWindowExceededError:
318
  raise
319
  except Exception as e:
 
 
320
  if not _healed_effort and _is_effort_config_error(e):
321
  _healed_effort = True
322
  llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params)
@@ -325,8 +429,8 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
325
  data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
326
  ))
327
  continue
328
- if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e):
329
- _delay = _LLM_RETRY_DELAYS[_llm_attempt]
330
  logger.warning(
331
  "Transient LLM error (attempt %d/%d): %s — retrying in %ds",
332
  _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay,
@@ -424,6 +528,8 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
424
  except ContextWindowExceededError:
425
  raise
426
  except Exception as e:
 
 
427
  if not _healed_effort and _is_effort_config_error(e):
428
  _healed_effort = True
429
  llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params)
@@ -432,8 +538,8 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
432
  data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
433
  ))
434
  continue
435
- if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e):
436
- _delay = _LLM_RETRY_DELAYS[_llm_attempt]
437
  logger.warning(
438
  "Transient LLM error (attempt %d/%d): %s — retrying in %ds",
439
  _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay,
@@ -585,6 +691,31 @@ class Handlers:
585
  )
586
  )
587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588
  messages = session.context_manager.get_messages()
589
  tools = session.tool_router.get_tool_specs_for_llm()
590
  try:
@@ -1006,6 +1137,9 @@ class Handlers:
1006
  tool_args["script"] = edited_script
1007
  was_edited = True
1008
  logger.info(f"Using user-edited script for {tool_name} ({tc.id})")
 
 
 
1009
  approved_tasks.append((tc, tool_name, tool_args, was_edited))
1010
  else:
1011
  rejected_tasks.append((tc, tool_name, approval_decision))
 
25
 
26
  ToolCall = ChatCompletionMessageToolCall
27
 
28
+ _MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '"
29
+ _MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments"
30
+
31
+
32
+ def _malformed_tool_name(message: Message) -> str | None:
33
+ """Return the tool name for malformed-json tool-result messages."""
34
+ if getattr(message, "role", None) != "tool":
35
+ return None
36
+ content = getattr(message, "content", None)
37
+ if not isinstance(content, str):
38
+ return None
39
+ if not content.startswith(_MALFORMED_TOOL_PREFIX):
40
+ return None
41
+ end = content.find(_MALFORMED_TOOL_SUFFIX, len(_MALFORMED_TOOL_PREFIX))
42
+ if end == -1:
43
+ return None
44
+ return content[len(_MALFORMED_TOOL_PREFIX):end]
45
+
46
+
47
+ def _detect_repeated_malformed(
48
+ items: list[Message], threshold: int = 2,
49
+ ) -> str | None:
50
+ """Return the repeated malformed tool name if the tail contains a streak.
51
+
52
+ Walk backward over the current conversation tail. A streak counts only
53
+ consecutive malformed tool-result messages for the same tool; any other
54
+ tool result breaks it.
55
+ """
56
+ if threshold <= 0:
57
+ return None
58
+
59
+ streak_tool: str | None = None
60
+ streak = 0
61
+
62
+ for item in reversed(items):
63
+ if getattr(item, "role", None) != "tool":
64
+ continue
65
+
66
+ malformed_tool = _malformed_tool_name(item)
67
+ if malformed_tool is None:
68
+ break
69
+
70
+ if streak_tool is None:
71
+ streak_tool = malformed_tool
72
+ streak = 1
73
+ elif malformed_tool == streak_tool:
74
+ streak += 1
75
+ else:
76
+ break
77
+
78
+ if streak >= threshold:
79
+ return streak_tool
80
+
81
+ return None
82
+
83
 
84
  def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
85
  """
 
176
  # -- LLM retry constants --------------------------------------------------
177
  _MAX_LLM_RETRIES = 3
178
  _LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries
179
+ _LLM_RATE_LIMIT_RETRY_DELAYS = [30, 60] # exceed Bedrock's ~60s TPM bucket window
180
+
181
+
182
+ def _is_rate_limit_error(error: Exception) -> bool:
183
+ """Return True for rate-limit / quota-bucket style provider errors."""
184
+ err_str = str(error).lower()
185
+ rate_limit_patterns = [
186
+ "429",
187
+ "rate limit",
188
+ "rate_limit",
189
+ "too many requests",
190
+ "too many tokens",
191
+ "request limit",
192
+ "throttl",
193
+ ]
194
+ return any(pattern in err_str for pattern in rate_limit_patterns)
195
+
196
+
197
+ def _is_context_overflow_error(error: Exception) -> bool:
198
+ """Return True when the prompt exceeded the model's context window."""
199
+ if isinstance(error, ContextWindowExceededError):
200
+ return True
201
+
202
+ err_str = str(error).lower()
203
+ overflow_patterns = [
204
+ "context window exceeded",
205
+ "maximum context length",
206
+ "max context length",
207
+ "prompt is too long",
208
+ "context length exceeded",
209
+ "too many input tokens",
210
+ "input is too long",
211
+ ]
212
+ return any(pattern in err_str for pattern in overflow_patterns)
213
+
214
+
215
+ def _retry_delay_for(error: Exception, attempt_index: int) -> int | None:
216
+ """Return the delay for this retry attempt, or None if it should not retry."""
217
+ if _is_rate_limit_error(error):
218
+ schedule = _LLM_RATE_LIMIT_RETRY_DELAYS
219
+ elif _is_transient_error(error):
220
+ schedule = _LLM_RETRY_DELAYS
221
+ else:
222
+ return None
223
+
224
+ if attempt_index >= len(schedule):
225
+ return None
226
+ return schedule[attempt_index]
227
 
228
 
229
  def _is_transient_error(error: Exception) -> bool:
 
231
  err_str = str(error).lower()
232
  transient_patterns = [
233
  "timeout", "timed out",
 
234
  "503", "service unavailable",
235
  "502", "bad gateway",
236
  "500", "internal server error",
 
238
  "connection reset", "connection refused", "connection error",
239
  "eof", "broken pipe",
240
  ]
241
+ return _is_rate_limit_error(error) or any(pattern in err_str for pattern in transient_patterns)
242
 
243
 
244
  def _is_effort_config_error(error: Exception) -> bool:
 
419
  except ContextWindowExceededError:
420
  raise
421
  except Exception as e:
422
+ if _is_context_overflow_error(e):
423
+ raise ContextWindowExceededError(str(e)) from e
424
  if not _healed_effort and _is_effort_config_error(e):
425
  _healed_effort = True
426
  llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params)
 
429
  data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
430
  ))
431
  continue
432
+ _delay = _retry_delay_for(e, _llm_attempt)
433
+ if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
434
  logger.warning(
435
  "Transient LLM error (attempt %d/%d): %s — retrying in %ds",
436
  _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay,
 
528
  except ContextWindowExceededError:
529
  raise
530
  except Exception as e:
531
+ if _is_context_overflow_error(e):
532
+ raise ContextWindowExceededError(str(e)) from e
533
  if not _healed_effort and _is_effort_config_error(e):
534
  _healed_effort = True
535
  llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params)
 
538
  data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
539
  ))
540
  continue
541
+ _delay = _retry_delay_for(e, _llm_attempt)
542
+ if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
543
  logger.warning(
544
  "Transient LLM error (attempt %d/%d): %s — retrying in %ds",
545
  _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay,
 
691
  )
692
  )
693
 
694
+ malformed_tool = _detect_repeated_malformed(session.context_manager.items)
695
+ if malformed_tool:
696
+ recovery_prompt = (
697
+ "[SYSTEM: Repeated malformed tool arguments detected for "
698
+ f"'{malformed_tool}'. Stop retrying the same tool call shape. "
699
+ "Use a different strategy that produces smaller, valid JSON. "
700
+ "For large file writes, prefer bash with a heredoc or split the "
701
+ "edit into multiple smaller tool calls.]"
702
+ )
703
+ session.context_manager.add_message(
704
+ Message(role="user", content=recovery_prompt)
705
+ )
706
+ await session.send_event(
707
+ Event(
708
+ event_type="tool_log",
709
+ data={
710
+ "tool": "system",
711
+ "log": (
712
+ "Repeated malformed tool arguments detected — "
713
+ f"forcing a different strategy for {malformed_tool}"
714
+ ),
715
+ },
716
+ )
717
+ )
718
+
719
  messages = session.context_manager.get_messages()
720
  tools = session.tool_router.get_tool_specs_for_llm()
721
  try:
 
1137
  tool_args["script"] = edited_script
1138
  was_edited = True
1139
  logger.info(f"Using user-edited script for {tool_name} ({tc.id})")
1140
+ selected_namespace = approval_decision.get("namespace")
1141
+ if selected_namespace and tool_name == "hf_jobs":
1142
+ tool_args["namespace"] = selected_namespace
1143
  approved_tasks.append((tc, tool_name, tool_args, was_edited))
1144
  else:
1145
  rejected_tasks.append((tc, tool_name, approval_decision))
agent/core/doom_loop.py CHANGED
@@ -17,10 +17,11 @@ logger = logging.getLogger(__name__)
17
 
18
  @dataclass(frozen=True)
19
  class ToolCallSignature:
20
- """Hashable signature for a single tool call (name + args hash)."""
21
 
22
  name: str
23
  args_hash: str
 
24
 
25
 
26
  def _hash_args(args_str: str) -> str:
@@ -31,11 +32,16 @@ def _hash_args(args_str: str) -> str:
31
  def extract_recent_tool_signatures(
32
  messages: list[Message], lookback: int = 30
33
  ) -> list[ToolCallSignature]:
34
- """Extract tool call signatures from recent assistant messages."""
 
 
 
 
 
35
  signatures: list[ToolCallSignature] = []
36
  recent = messages[-lookback:] if len(messages) > lookback else messages
37
 
38
- for msg in recent:
39
  if getattr(msg, "role", None) != "assistant":
40
  continue
41
  tool_calls = getattr(msg, "tool_calls", None)
@@ -47,7 +53,21 @@ def extract_recent_tool_signatures(
47
  continue
48
  name = getattr(fn, "name", "") or ""
49
  args_str = getattr(fn, "arguments", "") or ""
50
- signatures.append(ToolCallSignature(name=name, args_hash=_hash_args(args_str)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  return signatures
53
 
 
17
 
18
  @dataclass(frozen=True)
19
  class ToolCallSignature:
20
+ """Hashable signature for a single tool call plus its observed result."""
21
 
22
  name: str
23
  args_hash: str
24
+ result_hash: str | None = None
25
 
26
 
27
  def _hash_args(args_str: str) -> str:
 
32
  def extract_recent_tool_signatures(
33
  messages: list[Message], lookback: int = 30
34
  ) -> list[ToolCallSignature]:
35
+ """Extract tool call signatures from recent assistant messages.
36
+
37
+ Includes the immediate tool result hash when present. This prevents
38
+ legitimate polling from being classified as a doom loop when the poll
39
+ arguments stay constant but the observed result keeps changing.
40
+ """
41
  signatures: list[ToolCallSignature] = []
42
  recent = messages[-lookback:] if len(messages) > lookback else messages
43
 
44
+ for idx, msg in enumerate(recent):
45
  if getattr(msg, "role", None) != "assistant":
46
  continue
47
  tool_calls = getattr(msg, "tool_calls", None)
 
53
  continue
54
  name = getattr(fn, "name", "") or ""
55
  args_str = getattr(fn, "arguments", "") or ""
56
+ result_hash = None
57
+ for follow in recent[idx + 1:]:
58
+ role = getattr(follow, "role", None)
59
+ if role == "tool" and getattr(follow, "tool_call_id", None) == getattr(tc, "id", None):
60
+ result_hash = _hash_args(str(getattr(follow, "content", "") or ""))
61
+ break
62
+ if role in {"assistant", "user"}:
63
+ break
64
+ signatures.append(
65
+ ToolCallSignature(
66
+ name=name,
67
+ args_hash=_hash_args(args_str),
68
+ result_hash=result_hash,
69
+ )
70
+ )
71
 
72
  return signatures
73
 
agent/core/effort_probe.py CHANGED
@@ -32,9 +32,10 @@ logger = logging.getLogger(__name__)
32
 
33
 
34
  # Cascade: for each user-stated preference, the ordered list of levels to
35
- # try. First success wins. ``max`` / ``xhigh`` are Anthropic-only; providers
36
- # that don't accept them raise ``UnsupportedEffortError`` synchronously (no
37
- # wasted network round-trip) and we advance to the next level.
 
38
  _EFFORT_CASCADE: dict[str, list[str]] = {
39
  "max": ["max", "xhigh", "high", "medium", "low"],
40
  "xhigh": ["xhigh", "high", "medium", "low"],
@@ -45,7 +46,10 @@ _EFFORT_CASCADE: dict[str, list[str]] = {
45
  }
46
 
47
  _PROBE_TIMEOUT = 15.0
48
- _PROBE_MAX_TOKENS = 16
 
 
 
49
 
50
 
51
  class ProbeInconclusive(Exception):
 
32
 
33
 
34
  # Cascade: for each user-stated preference, the ordered list of levels to
35
+ # try. First success wins. ``max`` is Anthropic-only; ``xhigh`` is also
36
+ # supported on current OpenAI GPT-5 models. Providers that don't accept a
37
+ # requested level raise ``UnsupportedEffortError`` synchronously (no wasted
38
+ # network round-trip) and we advance to the next level.
39
  _EFFORT_CASCADE: dict[str, list[str]] = {
40
  "max": ["max", "xhigh", "high", "medium", "low"],
41
  "xhigh": ["xhigh", "high", "medium", "low"],
 
46
  }
47
 
48
  _PROBE_TIMEOUT = 15.0
49
+ # Keep the probe cheap, but high enough that frontier reasoning models can
50
+ # finish a trivial reply instead of tripping a false "output limit reached"
51
+ # error during capability detection.
52
+ _PROBE_MAX_TOKENS = 64
53
 
54
 
55
  class ProbeInconclusive(Exception):
agent/core/hf_access.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Helpers for Hugging Face account / org access decisions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import os
7
+ from dataclasses import dataclass
8
+ from typing import Any
9
+
10
+ import httpx
11
+
12
+ OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
13
+
14
+
15
+ @dataclass(frozen=True)
16
+ class JobsAccess:
17
+ """Jobs entitlement derived from whoami-v2."""
18
+
19
+ username: str | None
20
+ plan: str
21
+ personal_can_run_jobs: bool
22
+ paid_org_names: list[str]
23
+ eligible_namespaces: list[str]
24
+ default_namespace: str | None
25
+ access_known: bool = True
26
+
27
+ @property
28
+ def can_run_jobs(self) -> bool:
29
+ return bool(self.default_namespace)
30
+
31
+
32
+ class JobsAccessError(Exception):
33
+ """Structured jobs access error for upgrade / namespace gating."""
34
+
35
+ def __init__(
36
+ self,
37
+ message: str,
38
+ *,
39
+ access: JobsAccess | None = None,
40
+ upgrade_required: bool = False,
41
+ namespace_required: bool = False,
42
+ ) -> None:
43
+ super().__init__(message)
44
+ self.access = access
45
+ self.upgrade_required = upgrade_required
46
+ self.namespace_required = namespace_required
47
+
48
+
49
+ def _extract_username(whoami: dict[str, Any]) -> str | None:
50
+ for key in ("name", "user", "preferred_username"):
51
+ value = whoami.get(key)
52
+ if isinstance(value, str) and value:
53
+ return value
54
+ return None
55
+
56
+
57
+ def _normalize_personal_plan(whoami: dict[str, Any]) -> str:
58
+ plan_str = ""
59
+ for key in ("plan", "type", "accountType"):
60
+ value = whoami.get(key)
61
+ if isinstance(value, str) and value:
62
+ plan_str = value.lower()
63
+ break
64
+
65
+ if not plan_str and (whoami.get("isPro") is True or whoami.get("is_pro") is True):
66
+ return "pro"
67
+
68
+ if any(tag in plan_str for tag in ("pro", "enterprise", "team")):
69
+ return "pro"
70
+ return "free"
71
+
72
+
73
+ def _paid_org_names(whoami: dict[str, Any]) -> list[str]:
74
+ names: list[str] = []
75
+ orgs = whoami.get("orgs") or []
76
+ if not isinstance(orgs, list):
77
+ return names
78
+
79
+ for org in orgs:
80
+ if not isinstance(org, dict):
81
+ continue
82
+ name = org.get("name")
83
+ if not isinstance(name, str) or not name:
84
+ continue
85
+ org_plan = str(org.get("plan") or org.get("type") or "").lower()
86
+ if any(tag in org_plan for tag in ("pro", "enterprise", "team")):
87
+ names.append(name)
88
+ return sorted(set(names))
89
+
90
+
91
+ def jobs_access_from_whoami(whoami: dict[str, Any]) -> JobsAccess:
92
+ username = _extract_username(whoami)
93
+ personal_plan = _normalize_personal_plan(whoami)
94
+ paid_orgs = _paid_org_names(whoami)
95
+ personal_can_run = personal_plan == "pro"
96
+
97
+ eligible_namespaces: list[str] = []
98
+ if personal_can_run and username:
99
+ eligible_namespaces.append(username)
100
+ eligible_namespaces.extend(paid_orgs)
101
+
102
+ plan = "pro" if personal_can_run else ("org" if paid_orgs else "free")
103
+ default_namespace = username if personal_can_run and username else None
104
+
105
+ return JobsAccess(
106
+ username=username,
107
+ plan=plan,
108
+ personal_can_run_jobs=personal_can_run,
109
+ paid_org_names=paid_orgs,
110
+ eligible_namespaces=eligible_namespaces,
111
+ default_namespace=default_namespace,
112
+ )
113
+
114
+
115
+ async def fetch_whoami_v2(token: str, timeout: float = 5.0) -> dict[str, Any] | None:
116
+ if not token:
117
+ return None
118
+ async with httpx.AsyncClient(timeout=timeout) as client:
119
+ try:
120
+ response = await client.get(
121
+ f"{OPENID_PROVIDER_URL}/api/whoami-v2",
122
+ headers={"Authorization": f"Bearer {token}"},
123
+ )
124
+ if response.status_code != 200:
125
+ return None
126
+ payload = response.json()
127
+ return payload if isinstance(payload, dict) else None
128
+ except (httpx.HTTPError, ValueError):
129
+ return None
130
+
131
+
132
+ async def get_jobs_access(token: str) -> JobsAccess | None:
133
+ whoami = await fetch_whoami_v2(token)
134
+ if whoami is None:
135
+ return None
136
+ return jobs_access_from_whoami(whoami)
137
+
138
+
139
+ async def resolve_jobs_namespace(
140
+ token: str,
141
+ requested_namespace: str | None = None,
142
+ ) -> tuple[str, JobsAccess | None]:
143
+ """Return the namespace to use for jobs.
144
+
145
+ If whoami-v2 is unavailable, fall back to the token owner's username.
146
+ """
147
+ access = await get_jobs_access(token)
148
+ if access:
149
+ if requested_namespace:
150
+ if requested_namespace in access.eligible_namespaces:
151
+ return requested_namespace, access
152
+ raise JobsAccessError(
153
+ f"You can only run jobs under your own Pro account or a paid org you belong to. "
154
+ f"Allowed namespaces: {', '.join(access.eligible_namespaces) or '(none)'}",
155
+ access=access,
156
+ )
157
+ if access.default_namespace:
158
+ return access.default_namespace, access
159
+ if access.paid_org_names:
160
+ raise JobsAccessError(
161
+ "Choose which paid organization should own this job run.",
162
+ access=access,
163
+ namespace_required=True,
164
+ )
165
+ raise JobsAccessError(
166
+ "Hugging Face Jobs are available only to Pro users and Team or Enterprise organizations. "
167
+ "Upgrade to Pro, or run the job under a paid org you belong to.",
168
+ access=access,
169
+ upgrade_required=True,
170
+ )
171
+
172
+ # Fallback: whoami-v2 unavailable. Do not block the call pre-emptively.
173
+ from huggingface_hub import HfApi
174
+
175
+ username = None
176
+ if token:
177
+ whoami = await asyncio.to_thread(HfApi(token=token).whoami)
178
+ username = whoami.get("name")
179
+ if not username:
180
+ raise JobsAccessError("No HF token available to resolve a jobs namespace.")
181
+ return requested_namespace or username, None
agent/core/llm_params.py CHANGED
@@ -66,13 +66,13 @@ _patch_litellm_effort_validation()
66
 
67
  # Effort levels accepted on the wire.
68
  # Anthropic (4.6+): low | medium | high | xhigh | max (output_config.effort)
69
- # OpenAI direct: minimal | low | medium | high (reasoning_effort top-level)
70
  # HF router: low | medium | high (extra_body.reasoning_effort)
71
  #
72
  # We validate *shape* here and let the probe cascade walk down on rejection;
73
  # we deliberately do NOT maintain a per-model capability table.
74
  _ANTHROPIC_EFFORTS = {"low", "medium", "high", "xhigh", "max"}
75
- _OPENAI_EFFORTS = {"minimal", "low", "medium", "high"}
76
  _HF_EFFORTS = {"low", "medium", "high"}
77
 
78
 
 
66
 
67
  # Effort levels accepted on the wire.
68
  # Anthropic (4.6+): low | medium | high | xhigh | max (output_config.effort)
69
+ # OpenAI direct: minimal | low | medium | high | xhigh (reasoning_effort top-level)
70
  # HF router: low | medium | high (extra_body.reasoning_effort)
71
  #
72
  # We validate *shape* here and let the probe cascade walk down on rejection;
73
  # we deliberately do NOT maintain a per-model capability table.
74
  _ANTHROPIC_EFFORTS = {"low", "medium", "high", "xhigh", "max"}
75
+ _OPENAI_EFFORTS = {"minimal", "low", "medium", "high", "xhigh"}
76
  _HF_EFFORTS = {"low", "medium", "high"}
77
 
78
 
agent/core/model_switcher.py CHANGED
@@ -24,8 +24,11 @@ from agent.core.effort_probe import ProbeInconclusive, probe_effort
24
  # ":cheapest" / ":preferred" / ":<provider>" to override the default
25
  # routing policy (auto = fastest with failover).
26
  SUGGESTED_MODELS = [
27
- {"id": "bedrock/us.anthropic.claude-opus-4-7", "label": "Claude Opus 4.7"},
28
- {"id": "bedrock/us.anthropic.claude-opus-4-6-v1", "label": "Claude Opus 4.6"},
 
 
 
29
  {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"},
30
  {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"},
31
  {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"},
 
24
  # ":cheapest" / ":preferred" / ":<provider>" to override the default
25
  # routing policy (auto = fastest with failover).
26
  SUGGESTED_MODELS = [
27
+ {"id": "openai/gpt-5.5", "label": "GPT-5.5"},
28
+ {"id": "openai/gpt-5.4", "label": "GPT-5.4"},
29
+ {"id": "anthropic/claude-opus-4-7", "label": "Claude Opus 4.7"},
30
+ {"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"},
31
+ {"id": "bedrock/us.anthropic.claude-opus-4-6-v1", "label": "Claude Opus 4.6 via Bedrock"},
32
  {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"},
33
  {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"},
34
  {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"},
agent/core/telemetry.py CHANGED
@@ -141,6 +141,7 @@ async def record_hf_job_submit(
141
  "timeout": args.get("timeout", "30m"),
142
  "job_type": job_type,
143
  "image": image,
 
144
  "push_to_hub": _infer_push_to_hub(script_text),
145
  },
146
  ))
@@ -239,6 +240,43 @@ async def record_feedback(
239
  logger.debug("record_feedback failed (non-fatal): %s", e)
240
 
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  # ── heartbeat ──────────────────────────────────────────────────────────────
243
 
244
  # Module-level reference set for fire-and-forget heartbeat tasks. asyncio only
 
141
  "timeout": args.get("timeout", "30m"),
142
  "job_type": job_type,
143
  "image": image,
144
+ "namespace": args.get("namespace"),
145
  "push_to_hub": _infer_push_to_hub(script_text),
146
  },
147
  ))
 
240
  logger.debug("record_feedback failed (non-fatal): %s", e)
241
 
242
 
243
+ async def record_jobs_access_blocked(
244
+ session: Any,
245
+ *,
246
+ tool_call_ids: list[str],
247
+ plan: str,
248
+ eligible_namespaces: list[str],
249
+ ) -> None:
250
+ from agent.core.session import Event
251
+ try:
252
+ await session.send_event(Event(
253
+ event_type="jobs_access_blocked",
254
+ data={
255
+ "tool_call_ids": tool_call_ids,
256
+ "plan": plan,
257
+ "eligible_namespaces": eligible_namespaces,
258
+ },
259
+ ))
260
+ except Exception as e:
261
+ logger.debug("record_jobs_access_blocked failed (non-fatal): %s", e)
262
+
263
+
264
+ async def record_pro_cta_click(
265
+ session: Any,
266
+ *,
267
+ source: str,
268
+ target: str = "pro_pricing",
269
+ ) -> None:
270
+ from agent.core.session import Event
271
+ try:
272
+ await session.send_event(Event(
273
+ event_type="pro_cta_click",
274
+ data={"source": source, "target": target},
275
+ ))
276
+ except Exception as e:
277
+ logger.debug("record_pro_cta_click failed (non-fatal): %s", e)
278
+
279
+
280
  # ── heartbeat ──────────────────────────────────────────────────────────────
281
 
282
  # Module-level reference set for fire-and-forget heartbeat tasks. asyncio only
agent/main.py CHANGED
@@ -50,6 +50,16 @@ litellm.drop_params = True
50
  # on every error — users don't need it, and our friendly errors cover the case.
51
  litellm.suppress_debug_info = True
52
 
 
 
 
 
 
 
 
 
 
 
53
  def _safe_get_args(arguments: dict) -> dict:
54
  """Safely extract args dict from arguments, handling cases where LLM passes string."""
55
  args = arguments.get("args", {})
@@ -771,8 +781,9 @@ async def _handle_slash_command(
771
  console.print(f" [dim]{m}: {eff or 'off'}[/dim]")
772
  console.print(
773
  "[dim]Set with '/effort minimal|low|medium|high|xhigh|max|off'. "
774
- "'max' and 'xhigh' are Anthropic-only; the cascade falls back "
775
- "to whatever the model actually accepts.[/dim]"
 
776
  )
777
  return None
778
  level = arg.lower()
@@ -820,6 +831,8 @@ async def main():
820
  if not hf_token:
821
  hf_token = await _prompt_and_save_hf_token(prompt_session)
822
 
 
 
823
  # Resolve username for banner
824
  hf_user = None
825
  try:
@@ -828,7 +841,7 @@ async def main():
828
  except Exception:
829
  pass
830
 
831
- print_banner(hf_user=hf_user)
832
 
833
  # Pre-warm the HF router catalog in the background so /model switches
834
  # don't block on a network fetch.
@@ -844,10 +857,6 @@ async def main():
844
  turn_complete_event.set()
845
  ready_event = asyncio.Event()
846
 
847
- # Start agent loop in background
848
- config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json"
849
- config = load_config(config_path)
850
-
851
  # Create tool router with local mode
852
  tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
853
 
@@ -1036,6 +1045,7 @@ async def headless_main(
1036
  import logging
1037
 
1038
  logging.basicConfig(level=logging.WARNING)
 
1039
 
1040
  hf_token = _get_hf_token()
1041
  if not hf_token:
@@ -1044,8 +1054,7 @@ async def headless_main(
1044
 
1045
  print(f"HF token loaded", file=sys.stderr)
1046
 
1047
- config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json"
1048
- config = load_config(config_path)
1049
  config.yolo_mode = True # Auto-approve everything in headless mode
1050
 
1051
  if model:
@@ -1221,6 +1230,7 @@ def cli():
1221
  import warnings
1222
  # Suppress aiohttp "Unclosed client session" noise during event loop teardown
1223
  _logging.getLogger("asyncio").setLevel(_logging.CRITICAL)
 
1224
  # Suppress litellm pydantic deprecation warnings
1225
  warnings.filterwarnings("ignore", category=DeprecationWarning, module="litellm")
1226
  # Suppress whoosh invalid escape sequence warnings (third-party, unfixed upstream)
 
50
  # on every error — users don't need it, and our friendly errors cover the case.
51
  litellm.suppress_debug_info = True
52
 
53
+ CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.json"
54
+
55
+
56
+ def _configure_runtime_logging() -> None:
57
+ """Keep third-party warning spam from punching through the interactive UI."""
58
+ import logging
59
+
60
+ logging.getLogger("LiteLLM").setLevel(logging.ERROR)
61
+ logging.getLogger("litellm").setLevel(logging.ERROR)
62
+
63
  def _safe_get_args(arguments: dict) -> dict:
64
  """Safely extract args dict from arguments, handling cases where LLM passes string."""
65
  args = arguments.get("args", {})
 
781
  console.print(f" [dim]{m}: {eff or 'off'}[/dim]")
782
  console.print(
783
  "[dim]Set with '/effort minimal|low|medium|high|xhigh|max|off'. "
784
+ "'max' is Anthropic-only; 'xhigh' is also supported by current "
785
+ "OpenAI GPT-5 models. The cascade falls back to whatever the "
786
+ "model actually accepts.[/dim]"
787
  )
788
  return None
789
  level = arg.lower()
 
831
  if not hf_token:
832
  hf_token = await _prompt_and_save_hf_token(prompt_session)
833
 
834
+ config = load_config(CLI_CONFIG_PATH)
835
+
836
  # Resolve username for banner
837
  hf_user = None
838
  try:
 
841
  except Exception:
842
  pass
843
 
844
+ print_banner(model=config.model_name, hf_user=hf_user)
845
 
846
  # Pre-warm the HF router catalog in the background so /model switches
847
  # don't block on a network fetch.
 
857
  turn_complete_event.set()
858
  ready_event = asyncio.Event()
859
 
 
 
 
 
860
  # Create tool router with local mode
861
  tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
862
 
 
1045
  import logging
1046
 
1047
  logging.basicConfig(level=logging.WARNING)
1048
+ _configure_runtime_logging()
1049
 
1050
  hf_token = _get_hf_token()
1051
  if not hf_token:
 
1054
 
1055
  print(f"HF token loaded", file=sys.stderr)
1056
 
1057
+ config = load_config(CLI_CONFIG_PATH)
 
1058
  config.yolo_mode = True # Auto-approve everything in headless mode
1059
 
1060
  if model:
 
1230
  import warnings
1231
  # Suppress aiohttp "Unclosed client session" noise during event loop teardown
1232
  _logging.getLogger("asyncio").setLevel(_logging.CRITICAL)
1233
+ _configure_runtime_logging()
1234
  # Suppress litellm pydantic deprecation warnings
1235
  warnings.filterwarnings("ignore", category=DeprecationWarning, module="litellm")
1236
  # Suppress whoosh invalid escape sequence warnings (third-party, unfixed upstream)
agent/tools/jobs_tool.py CHANGED
@@ -17,6 +17,7 @@ import httpx
17
  from huggingface_hub import HfApi
18
  from huggingface_hub.utils import HfHubHTTPError
19
 
 
20
  from agent.core.session import Event
21
  from agent.tools.types import ToolResult
22
 
@@ -298,6 +299,7 @@ class HfJobsTool:
298
  self,
299
  hf_token: Optional[str] = None,
300
  namespace: Optional[str] = None,
 
301
  log_callback: Optional[Callable[[str], Awaitable[None]]] = None,
302
  session: Any = None,
303
  tool_call_id: Optional[str] = None,
@@ -305,6 +307,7 @@ class HfJobsTool:
305
  self.hf_token = hf_token
306
  self.api = HfApi(token=hf_token)
307
  self.namespace = namespace
 
308
  self.log_callback = log_callback
309
  self.session = session
310
  self.tool_call_id = tool_call_id
@@ -565,7 +568,7 @@ class HfJobsTool:
565
  from agent.core import telemetry
566
  submit_ts = await telemetry.record_hf_job_submit(
567
  self.session, job,
568
- {**args, "hardware_flavor": flavor, "timeout": timeout_str},
569
  image=image, job_type=job_type,
570
  )
571
 
@@ -1057,6 +1060,14 @@ HF_JOBS_TOOL_SPEC = {
1057
  "type": "object",
1058
  "description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.",
1059
  },
 
 
 
 
 
 
 
 
1060
  "job_id": {
1061
  "type": "string",
1062
  "description": "Job ID. Required for: logs, inspect, cancel.",
@@ -1099,11 +1110,18 @@ async def hf_jobs_handler(
1099
  arguments = {**arguments, "script": content}
1100
 
1101
  hf_token = session.hf_token if session else None
1102
- namespace = os.environ.get("HF_NAMESPACE") or (HfApi(token=hf_token).whoami().get("name") if hf_token else None)
 
 
 
 
 
 
1103
 
1104
  tool = HfJobsTool(
1105
  namespace=namespace,
1106
  hf_token=hf_token,
 
1107
  log_callback=log_callback if session else None,
1108
  session=session,
1109
  tool_call_id=tool_call_id,
 
17
  from huggingface_hub import HfApi
18
  from huggingface_hub.utils import HfHubHTTPError
19
 
20
+ from agent.core.hf_access import JobsAccessError, resolve_jobs_namespace
21
  from agent.core.session import Event
22
  from agent.tools.types import ToolResult
23
 
 
299
  self,
300
  hf_token: Optional[str] = None,
301
  namespace: Optional[str] = None,
302
+ jobs_access: Any = None,
303
  log_callback: Optional[Callable[[str], Awaitable[None]]] = None,
304
  session: Any = None,
305
  tool_call_id: Optional[str] = None,
 
307
  self.hf_token = hf_token
308
  self.api = HfApi(token=hf_token)
309
  self.namespace = namespace
310
+ self.jobs_access = jobs_access
311
  self.log_callback = log_callback
312
  self.session = session
313
  self.tool_call_id = tool_call_id
 
568
  from agent.core import telemetry
569
  submit_ts = await telemetry.record_hf_job_submit(
570
  self.session, job,
571
+ {**args, "hardware_flavor": flavor, "timeout": timeout_str, "namespace": self.namespace},
572
  image=image, job_type=job_type,
573
  )
574
 
 
1060
  "type": "object",
1061
  "description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.",
1062
  },
1063
+ "namespace": {
1064
+ "type": "string",
1065
+ "description": (
1066
+ "Optional namespace to run the job under. Must be your own Pro account "
1067
+ "or a paid org you belong to. If omitted, the tool prefers your personal "
1068
+ "account when eligible, otherwise the first eligible paid org."
1069
+ ),
1070
+ },
1071
  "job_id": {
1072
  "type": "string",
1073
  "description": "Job ID. Required for: logs, inspect, cancel.",
 
1110
  arguments = {**arguments, "script": content}
1111
 
1112
  hf_token = session.hf_token if session else None
1113
+ try:
1114
+ namespace, jobs_access = await resolve_jobs_namespace(
1115
+ hf_token or "",
1116
+ arguments.get("namespace"),
1117
+ )
1118
+ except JobsAccessError as e:
1119
+ return str(e), False
1120
 
1121
  tool = HfJobsTool(
1122
  namespace=namespace,
1123
  hf_token=hf_token,
1124
+ jobs_access=jobs_access,
1125
  log_callback=log_callback if session else None,
1126
  session=session,
1127
  tool_call_id=tool_call_id,
agent/tools/research_tool.py CHANGED
@@ -216,7 +216,9 @@ RESEARCH_TOOL_SPEC = {
216
 
217
  def _get_research_model(main_model: str) -> str:
218
  """Pick a cheaper model for research based on the main model."""
219
- if "anthropic" in main_model:
 
 
220
  return "bedrock/us.anthropic.claude-sonnet-4-6"
221
  # For non-Anthropic models (HF router etc.), use the same model
222
  return main_model
 
216
 
217
  def _get_research_model(main_model: str) -> str:
218
  """Pick a cheaper model for research based on the main model."""
219
+ if main_model.startswith("anthropic/"):
220
+ return "anthropic/claude-sonnet-4-6"
221
+ if main_model.startswith("bedrock/") and "anthropic" in main_model:
222
  return "bedrock/us.anthropic.claude-sonnet-4-6"
223
  # For non-Anthropic models (HF router etc.), use the same model
224
  return main_model
agent/tools/sandbox_tool.py CHANGED
@@ -213,16 +213,26 @@ async def sandbox_create_handler(
213
  args: dict[str, Any], session: Any = None
214
  ) -> tuple[str, bool]:
215
  """Handle sandbox_create tool calls."""
 
 
216
  # If sandbox already exists, return its info
217
  if session and getattr(session, "sandbox", None):
218
  sb = session.sandbox
 
 
 
 
 
 
 
 
219
  return (
220
  f"Sandbox already active: {sb.space_id}\n"
221
  f"URL: {sb.url}\n"
 
222
  f"Use bash/read/write/edit to interact with it."
223
  ), True
224
 
225
- hardware = args.get("hardware", "cpu-basic")
226
  create_kwargs = {}
227
  if "private" in args:
228
  create_kwargs["private"] = args["private"]
 
213
  args: dict[str, Any], session: Any = None
214
  ) -> tuple[str, bool]:
215
  """Handle sandbox_create tool calls."""
216
+ hardware = args.get("hardware", "cpu-basic")
217
+
218
  # If sandbox already exists, return its info
219
  if session and getattr(session, "sandbox", None):
220
  sb = session.sandbox
221
+ requested_hardware = args.get("hardware")
222
+ lockout_note = ""
223
+ if requested_hardware:
224
+ lockout_note = (
225
+ f"\nRequested hardware: {requested_hardware}\n"
226
+ "Hardware cannot be changed by calling sandbox_create again. "
227
+ "Delete the existing sandbox first if you need a different tier."
228
+ )
229
  return (
230
  f"Sandbox already active: {sb.space_id}\n"
231
  f"URL: {sb.url}\n"
232
+ f"{lockout_note}\n"
233
  f"Use bash/read/write/edit to interact with it."
234
  ), True
235
 
 
236
  create_kwargs = {}
237
  if "private" in args:
238
  create_kwargs["private"] = args["private"]
agent/utils/terminal_display.py CHANGED
@@ -99,7 +99,7 @@ def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
99
  _console.file.write("\033[2J\033[H")
100
  _console.file.flush()
101
 
102
- model_label = model or "bedrock/us.anthropic.claude-opus-4-6-v1"
103
  user_label = hf_user or "not logged in"
104
 
105
  # Warm gold palette matching the shimmer highlight (255, 200, 80)
@@ -180,10 +180,8 @@ class SubAgentDisplayManager:
180
  def __init__(self):
181
  self._agents: dict[str, dict] = {} # agent_id -> state dict
182
  self._lines_on_screen = 0
183
- self._ticker_task = None
184
 
185
  def start(self, agent_id: str, label: str = "research") -> None:
186
- import asyncio
187
  import time
188
  self._agents[agent_id] = {
189
  "label": label,
@@ -192,8 +190,6 @@ class SubAgentDisplayManager:
192
  "token_count": 0,
193
  "start_time": time.monotonic(),
194
  }
195
- if not self._ticker_task:
196
- self._ticker_task = asyncio.ensure_future(self._tick())
197
  self._redraw()
198
 
199
  def set_tokens(self, agent_id: str, tokens: int) -> None:
@@ -222,11 +218,7 @@ class SubAgentDisplayManager:
222
  _console.file.write(line + "\n")
223
  _console.file.flush()
224
  self._lines_on_screen = 0
225
- if not self._agents:
226
- if self._ticker_task:
227
- self._ticker_task.cancel()
228
- self._ticker_task = None
229
- else:
230
  self._redraw()
231
 
232
  @staticmethod
@@ -239,16 +231,6 @@ class SubAgentDisplayManager:
239
  line += f" \033[2m({stats})\033[0m"
240
  return line
241
 
242
- async def _tick(self) -> None:
243
- import asyncio
244
- try:
245
- while True:
246
- await asyncio.sleep(1.0)
247
- if self._agents:
248
- self._redraw()
249
- except asyncio.CancelledError:
250
- pass
251
-
252
  @staticmethod
253
  def _format_stats(agent: dict) -> str:
254
  import time
 
99
  _console.file.write("\033[2J\033[H")
100
  _console.file.flush()
101
 
102
+ model_label = model or "unknown"
103
  user_label = hf_user or "not logged in"
104
 
105
  # Warm gold palette matching the shimmer highlight (255, 200, 80)
 
180
  def __init__(self):
181
  self._agents: dict[str, dict] = {} # agent_id -> state dict
182
  self._lines_on_screen = 0
 
183
 
184
  def start(self, agent_id: str, label: str = "research") -> None:
 
185
  import time
186
  self._agents[agent_id] = {
187
  "label": label,
 
190
  "token_count": 0,
191
  "start_time": time.monotonic(),
192
  }
 
 
193
  self._redraw()
194
 
195
  def set_tokens(self, agent_id: str, tokens: int) -> None:
 
218
  _console.file.write(line + "\n")
219
  _console.file.flush()
220
  self._lines_on_screen = 0
221
+ if self._agents:
 
 
 
 
222
  self._redraw()
223
 
224
  @staticmethod
 
231
  line += f" \033[2m({stats})\033[0m"
232
  return line
233
 
 
 
 
 
 
 
 
 
 
 
234
  @staticmethod
235
  def _format_stats(agent: dict) -> str:
236
  import time
backend/dependencies.py CHANGED
@@ -12,6 +12,8 @@ from typing import Any
12
  import httpx
13
  from fastapi import HTTPException, Request, status
14
 
 
 
15
  logger = logging.getLogger(__name__)
16
 
17
  OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
@@ -80,41 +82,6 @@ def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]:
80
  }
81
 
82
 
83
- def _normalize_plan(whoami: dict[str, Any]) -> str:
84
- """Map an HF /api/whoami-v2 payload to one of: 'free' | 'pro' | 'org'.
85
-
86
- The exact field shape in whoami-v2 isn't documented for our purposes,
87
- so we try a handful of likely keys and fall back to 'free'. The first
88
- call logs the raw shape at DEBUG (see `_fetch_user_plan`) so we can
89
- pin the real key post-deploy.
90
- """
91
- plan_str = ""
92
- for key in ("plan", "type", "accountType"):
93
- val = whoami.get(key)
94
- if isinstance(val, str) and val:
95
- plan_str = val.lower()
96
- break
97
-
98
- if not plan_str:
99
- if whoami.get("isPro") is True or whoami.get("is_pro") is True:
100
- return "pro"
101
-
102
- if "pro" in plan_str or "enterprise" in plan_str or "team" in plan_str:
103
- return "pro"
104
-
105
- # Org tier: anyone in a paid / enterprise org. We don't pay for this
106
- # right now, but the "pro" cap applies identically.
107
- orgs = whoami.get("orgs") or []
108
- if isinstance(orgs, list):
109
- for org in orgs:
110
- if isinstance(org, dict):
111
- org_plan = str(org.get("plan") or org.get("type") or "").lower()
112
- if "pro" in org_plan or "enterprise" in org_plan or "team" in org_plan:
113
- return "org"
114
-
115
- return "free"
116
-
117
-
118
  async def _fetch_user_plan(token: str) -> str:
119
  """Look up the user's HF plan via /api/whoami-v2.
120
 
@@ -123,19 +90,9 @@ async def _fetch_user_plan(token: str) -> str:
123
  grant the Pro cap than over-grant it on bad data.
124
  """
125
  global _WHOAMI_SHAPE_LOGGED
126
- async with httpx.AsyncClient(timeout=5.0) as client:
127
- try:
128
- resp = await client.get(
129
- f"{OPENID_PROVIDER_URL}/api/whoami-v2",
130
- headers={"Authorization": f"Bearer {token}"},
131
- )
132
- if resp.status_code != 200:
133
- return "free"
134
- whoami = resp.json()
135
- except httpx.HTTPError:
136
- return "free"
137
- except ValueError:
138
- return "free"
139
 
140
  if not _WHOAMI_SHAPE_LOGGED:
141
  _WHOAMI_SHAPE_LOGGED = True
@@ -149,7 +106,7 @@ async def _fetch_user_plan(token: str) -> str:
149
 
150
  if not isinstance(whoami, dict):
151
  return "free"
152
- return _normalize_plan(whoami)
153
 
154
 
155
  async def _extract_user_from_token(token: str) -> dict[str, Any] | None:
@@ -246,4 +203,3 @@ async def require_huggingface_org_member(request: Request) -> bool:
246
  return False
247
  return await check_org_membership(token, HF_EMPLOYEE_ORG)
248
 
249
-
 
12
  import httpx
13
  from fastapi import HTTPException, Request, status
14
 
15
+ from agent.core.hf_access import fetch_whoami_v2, jobs_access_from_whoami
16
+
17
  logger = logging.getLogger(__name__)
18
 
19
  OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
 
82
  }
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  async def _fetch_user_plan(token: str) -> str:
86
  """Look up the user's HF plan via /api/whoami-v2.
87
 
 
90
  grant the Pro cap than over-grant it on bad data.
91
  """
92
  global _WHOAMI_SHAPE_LOGGED
93
+ whoami = await fetch_whoami_v2(token)
94
+ if whoami is None:
95
+ return "free"
 
 
 
 
 
 
 
 
 
 
96
 
97
  if not _WHOAMI_SHAPE_LOGGED:
98
  _WHOAMI_SHAPE_LOGGED = True
 
106
 
107
  if not isinstance(whoami, dict):
108
  return "free"
109
+ return jobs_access_from_whoami(whoami).plan
110
 
111
 
112
  async def _extract_user_from_token(token: str) -> dict[str, Any] | None:
 
203
  return False
204
  return await check_org_membership(token, HF_EMPLOYEE_ORG)
205
 
 
backend/models.py CHANGED
@@ -38,6 +38,7 @@ class ToolApproval(BaseModel):
38
  approved: bool
39
  feedback: str | None = None
40
  edited_script: str | None = None
 
41
 
42
 
43
  class ApprovalRequest(BaseModel):
 
38
  approved: bool
39
  feedback: str | None = None
40
  edited_script: str | None = None
41
+ namespace: str | None = None
42
 
43
 
44
  class ApprovalRequest(BaseModel):
backend/routes/agent.py CHANGED
@@ -32,6 +32,7 @@ from session_manager import MAX_SESSIONS, AgentSession, SessionCapacityError, se
32
 
33
  import user_quotas
34
 
 
35
  from agent.core.llm_params import _resolve_llm_params
36
 
37
  logger = logging.getLogger(__name__)
@@ -136,6 +137,105 @@ async def _enforce_claude_quota(
136
  agent_session.claude_counted = True
137
 
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  def _check_session_access(session_id: str, user: dict[str, Any]) -> None:
140
  """Verify the user has access to the given session. Raises 403 or 404."""
141
  info = session_manager.get_session_info(session_id)
@@ -442,6 +542,27 @@ async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
442
  }
443
 
444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  @router.get("/sessions", response_model=list[SessionInfo])
446
  async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]:
447
  """List sessions belonging to the authenticated user."""
@@ -482,15 +603,20 @@ async def submit_approval(
482
  ) -> dict:
483
  """Submit tool approvals to a session. Only accessible by the session owner."""
484
  _check_session_access(request.session_id, user)
 
 
 
485
  approvals = [
486
  {
487
  "tool_call_id": a.tool_call_id,
488
  "approved": a.approved,
489
  "feedback": a.feedback,
490
  "edited_script": a.edited_script,
 
491
  }
492
  for a in request.approvals
493
  ]
 
494
  success = await session_manager.submit_approval(request.session_id, approvals)
495
  if not success:
496
  raise HTTPException(status_code=404, detail="Session not found or inactive")
@@ -540,9 +666,11 @@ async def chat_sse(
540
  "approved": a["approved"],
541
  "feedback": a.get("feedback"),
542
  "edited_script": a.get("edited_script"),
 
543
  }
544
  for a in approvals
545
  ]
 
546
  success = await session_manager.submit_approval(session_id, formatted)
547
  elif text is not None:
548
  success = await session_manager.submit_user_input(session_id, text)
@@ -554,6 +682,7 @@ async def chat_sse(
554
  broadcaster.unsubscribe(sub_id)
555
  raise HTTPException(status_code=404, detail="Session not found or inactive")
556
  except HTTPException:
 
557
  raise
558
  except Exception:
559
  broadcaster.unsubscribe(sub_id)
@@ -562,6 +691,31 @@ async def chat_sse(
562
  return _sse_response(broadcaster, event_queue, sub_id)
563
 
564
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
  # ---------------------------------------------------------------------------
566
  # Shared SSE helpers
567
  # ---------------------------------------------------------------------------
@@ -729,5 +883,3 @@ async def submit_feedback(
729
  agent_session.session.config.session_dataset_repo
730
  )
731
  return {"status": "ok"}
732
-
733
-
 
32
 
33
  import user_quotas
34
 
35
+ from agent.core.hf_access import get_jobs_access
36
  from agent.core.llm_params import _resolve_llm_params
37
 
38
  logger = logging.getLogger(__name__)
 
137
  agent_session.claude_counted = True
138
 
139
 
140
+ async def _enforce_jobs_access_for_approvals(
141
+ user: dict[str, Any],
142
+ agent_session: AgentSession,
143
+ approvals: list[dict[str, Any]],
144
+ ) -> None:
145
+ """Block approved hf_jobs tool calls when the user has no eligible jobs namespace."""
146
+ pending = agent_session.session.pending_approval or {}
147
+ tool_calls = pending.get("tool_calls") or []
148
+ if not tool_calls:
149
+ return
150
+
151
+ approved_ids = {
152
+ a.get("tool_call_id")
153
+ for a in approvals
154
+ if a.get("approved")
155
+ }
156
+ if not approved_ids:
157
+ return
158
+
159
+ hf_job_ids = [
160
+ tc.id for tc in tool_calls
161
+ if tc.id in approved_ids and tc.function.name == "hf_jobs"
162
+ ]
163
+ if not hf_job_ids:
164
+ return
165
+
166
+ token = agent_session.hf_token or agent_session.session.hf_token
167
+ if not token:
168
+ return
169
+
170
+ access = await get_jobs_access(token)
171
+ if access is None:
172
+ return
173
+
174
+ approval_map = {a.get("tool_call_id"): a for a in approvals}
175
+ if access.personal_can_run_jobs:
176
+ return
177
+
178
+ if access.paid_org_names:
179
+ invalid_namespace = [
180
+ tool_call_id
181
+ for tool_call_id in hf_job_ids
182
+ if (
183
+ approval_map.get(tool_call_id, {}).get("namespace")
184
+ and approval_map.get(tool_call_id, {}).get("namespace") not in access.paid_org_names
185
+ )
186
+ ]
187
+ if invalid_namespace:
188
+ raise HTTPException(
189
+ status_code=400,
190
+ detail={
191
+ "error": "hf_jobs_invalid_namespace",
192
+ "message": (
193
+ "The selected jobs namespace is not one of your eligible paid organizations. "
194
+ f"Allowed namespaces: {', '.join(access.paid_org_names)}"
195
+ ),
196
+ },
197
+ )
198
+ missing_namespace = [
199
+ tool_call_id
200
+ for tool_call_id in hf_job_ids
201
+ if not approval_map.get(tool_call_id, {}).get("namespace")
202
+ ]
203
+ if missing_namespace:
204
+ raise HTTPException(
205
+ status_code=409,
206
+ detail={
207
+ "error": "hf_jobs_namespace_required",
208
+ "message": "Choose which paid organization should own this job run.",
209
+ "plan": user.get("plan", "free"),
210
+ "tool_call_ids": missing_namespace,
211
+ "eligible_namespaces": access.paid_org_names,
212
+ },
213
+ )
214
+ return
215
+
216
+ from agent.core import telemetry
217
+ await telemetry.record_jobs_access_blocked(
218
+ agent_session.session,
219
+ tool_call_ids=hf_job_ids,
220
+ plan=user.get("plan", "free"),
221
+ eligible_namespaces=access.eligible_namespaces,
222
+ )
223
+
224
+ raise HTTPException(
225
+ status_code=402,
226
+ detail={
227
+ "error": "hf_jobs_upgrade_required",
228
+ "message": (
229
+ "Hugging Face Jobs are available only to Pro users and Team or Enterprise organizations. "
230
+ "Upgrade to Pro, or decline the job tool call so the agent can choose another path."
231
+ ),
232
+ "plan": user.get("plan", "free"),
233
+ "tool_call_ids": hf_job_ids,
234
+ "eligible_namespaces": access.eligible_namespaces,
235
+ },
236
+ )
237
+
238
+
239
  def _check_session_access(session_id: str, user: dict[str, Any]) -> None:
240
  """Verify the user has access to the given session. Raises 403 or 404."""
241
  info = session_manager.get_session_info(session_id)
 
542
  }
543
 
544
 
545
+ @router.get("/user/jobs-access")
546
+ async def get_jobs_access_info(request: Request, user: dict = Depends(get_current_user)) -> dict:
547
+ """Return whether the current token can run HF Jobs and under which namespaces."""
548
+ token = None
549
+ auth_header = request.headers.get("Authorization", "")
550
+ if auth_header.startswith("Bearer "):
551
+ token = auth_header[7:]
552
+ if not token:
553
+ token = request.cookies.get("hf_access_token")
554
+ if not token:
555
+ token = os.environ.get("HF_TOKEN")
556
+
557
+ access = await get_jobs_access(token or "")
558
+ return {
559
+ "plan": user.get("plan", "free"),
560
+ "can_run_jobs": bool(access and (access.personal_can_run_jobs or access.paid_org_names)),
561
+ "eligible_namespaces": access.eligible_namespaces if access else [],
562
+ "default_namespace": access.default_namespace if access else None,
563
+ }
564
+
565
+
566
  @router.get("/sessions", response_model=list[SessionInfo])
567
  async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]:
568
  """List sessions belonging to the authenticated user."""
 
603
  ) -> dict:
604
  """Submit tool approvals to a session. Only accessible by the session owner."""
605
  _check_session_access(request.session_id, user)
606
+ agent_session = session_manager.sessions.get(request.session_id)
607
+ if agent_session is None:
608
+ raise HTTPException(status_code=404, detail="Session not found or inactive")
609
  approvals = [
610
  {
611
  "tool_call_id": a.tool_call_id,
612
  "approved": a.approved,
613
  "feedback": a.feedback,
614
  "edited_script": a.edited_script,
615
+ "namespace": a.namespace,
616
  }
617
  for a in request.approvals
618
  ]
619
+ await _enforce_jobs_access_for_approvals(user, agent_session, approvals)
620
  success = await session_manager.submit_approval(request.session_id, approvals)
621
  if not success:
622
  raise HTTPException(status_code=404, detail="Session not found or inactive")
 
666
  "approved": a["approved"],
667
  "feedback": a.get("feedback"),
668
  "edited_script": a.get("edited_script"),
669
+ "namespace": a.get("namespace"),
670
  }
671
  for a in approvals
672
  ]
673
+ await _enforce_jobs_access_for_approvals(user, agent_session, formatted)
674
  success = await session_manager.submit_approval(session_id, formatted)
675
  elif text is not None:
676
  success = await session_manager.submit_user_input(session_id, text)
 
682
  broadcaster.unsubscribe(sub_id)
683
  raise HTTPException(status_code=404, detail="Session not found or inactive")
684
  except HTTPException:
685
+ broadcaster.unsubscribe(sub_id)
686
  raise
687
  except Exception:
688
  broadcaster.unsubscribe(sub_id)
 
691
  return _sse_response(broadcaster, event_queue, sub_id)
692
 
693
 
694
+ @router.post("/pro-click/{session_id}")
695
+ async def record_pro_click(
696
+ session_id: str,
697
+ body: dict,
698
+ user: dict = Depends(get_current_user),
699
+ ) -> dict:
700
+ """Record a click on a Pro upgrade CTA shown from inside a session."""
701
+ _check_session_access(session_id, user)
702
+ agent_session = session_manager.sessions.get(session_id)
703
+ if not agent_session:
704
+ raise HTTPException(status_code=404, detail="Session not found")
705
+
706
+ from agent.core import telemetry
707
+ await telemetry.record_pro_cta_click(
708
+ agent_session.session,
709
+ source=str(body.get("source") or "unknown"),
710
+ target=str(body.get("target") or "pro_pricing"),
711
+ )
712
+ if agent_session.session.config.save_sessions:
713
+ agent_session.session.save_and_upload_detached(
714
+ agent_session.session.config.session_dataset_repo
715
+ )
716
+ return {"status": "ok"}
717
+
718
+
719
  # ---------------------------------------------------------------------------
720
  # Shared SSE helpers
721
  # ---------------------------------------------------------------------------
 
883
  agent_session.session.config.session_dataset_repo
884
  )
885
  return {"status": "ok"}
 
 
backend/session_manager.py CHANGED
@@ -15,7 +15,7 @@ from agent.core.tools import ToolRouter
15
 
16
  # Get project root (parent of backend directory)
17
  PROJECT_ROOT = Path(__file__).parent.parent
18
- DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "main_agent_config.json")
19
 
20
 
21
  # These dataclasses match agent/main.py structure
 
15
 
16
  # Get project root (parent of backend directory)
17
  PROJECT_ROOT = Path(__file__).parent.parent
18
+ DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "frontend_agent_config.json")
19
 
20
 
21
  # These dataclasses match agent/main.py structure
configs/cli_agent_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "anthropic/claude-opus-4-6",
3
+ "save_sessions": true,
4
+ "session_dataset_repo": "smolagents/ml-intern-sessions",
5
+ "yolo_mode": false,
6
+ "confirm_cpu_jobs": true,
7
+ "auto_file_upload": true,
8
+ "mcpServers": {
9
+ "hf-mcp-server": {
10
+ "transport": "http",
11
+ "url": "https://huggingface.co/mcp?login"
12
+ }
13
+ }
14
+ }
configs/{main_agent_config.json → frontend_agent_config.json} RENAMED
File without changes
frontend/src/components/Chat/ChatInput.tsx CHANGED
@@ -6,8 +6,9 @@ import StopIcon from '@mui/icons-material/Stop';
6
  import { apiFetch } from '@/utils/api';
7
  import { useUserQuota } from '@/hooks/useUserQuota';
8
  import ClaudeCapDialog from '@/components/ClaudeCapDialog';
 
9
  import { useAgentStore } from '@/store/agentStore';
10
- import { FIRST_FREE_MODEL_PATH } from '@/utils/model';
11
 
12
  // Model configuration
13
  interface ModelOption {
@@ -37,7 +38,7 @@ const MODEL_OPTIONS: ModelOption[] = [
37
  id: 'claude-opus',
38
  name: 'Claude Opus 4.6',
39
  description: 'Anthropic',
40
- modelPath: 'anthropic/claude-opus-4-6',
41
  avatarUrl: 'https://huggingface.co/api/avatars/Anthropic',
42
  recommended: true,
43
  },
@@ -65,15 +66,17 @@ interface ChatInputProps {
65
  sessionId?: string;
66
  onSend: (text: string) => void;
67
  onStop?: () => void;
 
 
68
  isProcessing?: boolean;
69
  disabled?: boolean;
70
  placeholder?: string;
71
  }
72
 
73
- const isClaudeModel = (m: ModelOption) => m.modelPath.startsWith('anthropic/');
74
  const firstFreeModel = () => MODEL_OPTIONS.find(m => !isClaudeModel(m)) ?? MODEL_OPTIONS[0];
75
 
76
- export default function ChatInput({ sessionId, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
77
  const [input, setInput] = useState('');
78
  const inputRef = useRef<HTMLTextAreaElement>(null);
79
  const [selectedModelId, setSelectedModelId] = useState<string>(MODEL_OPTIONS[0].id);
@@ -86,6 +89,8 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa
86
  // the hook layer can flip it without threading props through.
87
  const claudeQuotaExhausted = useAgentStore((s) => s.claudeQuotaExhausted);
88
  const setClaudeQuotaExhausted = useAgentStore((s) => s.setClaudeQuotaExhausted);
 
 
89
  const lastSentRef = useRef<string>('');
90
 
91
  // Model is per-session: fetch this tab's current model every time the
@@ -197,6 +202,44 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa
197
  } catch { /* ignore */ }
198
  }, [sessionId, onSend, setClaudeQuotaExhausted]);
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  // Hide the chip until the user has actually burned quota — an unused
201
  // Opus session shouldn't populate a counter.
202
  const claudeChip = (() => {
@@ -435,6 +478,17 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa
435
  cap={quota?.claudeDailyCap ?? 1}
436
  onClose={handleCapDialogClose}
437
  onUseFreeModel={handleUseFreeModel}
 
 
 
 
 
 
 
 
 
 
 
438
  />
439
  </Box>
440
  </Box>
 
6
  import { apiFetch } from '@/utils/api';
7
  import { useUserQuota } from '@/hooks/useUserQuota';
8
  import ClaudeCapDialog from '@/components/ClaudeCapDialog';
9
+ import JobsUpgradeDialog from '@/components/JobsUpgradeDialog';
10
  import { useAgentStore } from '@/store/agentStore';
11
+ import { CLAUDE_MODEL_PATH, FIRST_FREE_MODEL_PATH, isClaudePath } from '@/utils/model';
12
 
13
  // Model configuration
14
  interface ModelOption {
 
38
  id: 'claude-opus',
39
  name: 'Claude Opus 4.6',
40
  description: 'Anthropic',
41
+ modelPath: CLAUDE_MODEL_PATH,
42
  avatarUrl: 'https://huggingface.co/api/avatars/Anthropic',
43
  recommended: true,
44
  },
 
66
  sessionId?: string;
67
  onSend: (text: string) => void;
68
  onStop?: () => void;
69
+ onDeclineBlockedJobs?: () => Promise<boolean>;
70
+ onContinueBlockedJobsWithNamespace?: (namespace: string) => Promise<boolean>;
71
  isProcessing?: boolean;
72
  disabled?: boolean;
73
  placeholder?: string;
74
  }
75
 
76
+ const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath);
77
  const firstFreeModel = () => MODEL_OPTIONS.find(m => !isClaudeModel(m)) ?? MODEL_OPTIONS[0];
78
 
79
+ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJobs, onContinueBlockedJobsWithNamespace, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
80
  const [input, setInput] = useState('');
81
  const inputRef = useRef<HTMLTextAreaElement>(null);
82
  const [selectedModelId, setSelectedModelId] = useState<string>(MODEL_OPTIONS[0].id);
 
89
  // the hook layer can flip it without threading props through.
90
  const claudeQuotaExhausted = useAgentStore((s) => s.claudeQuotaExhausted);
91
  const setClaudeQuotaExhausted = useAgentStore((s) => s.setClaudeQuotaExhausted);
92
+ const jobsUpgradeRequired = useAgentStore((s) => s.jobsUpgradeRequired);
93
+ const setJobsUpgradeRequired = useAgentStore((s) => s.setJobsUpgradeRequired);
94
  const lastSentRef = useRef<string>('');
95
 
96
  // Model is per-session: fetch this tab's current model every time the
 
202
  } catch { /* ignore */ }
203
  }, [sessionId, onSend, setClaudeQuotaExhausted]);
204
 
205
+ const handleClaudeUpgradeClick = useCallback(async () => {
206
+ if (!sessionId) return;
207
+ try {
208
+ await apiFetch(`/api/pro-click/${sessionId}`, {
209
+ method: 'POST',
210
+ body: JSON.stringify({ source: 'claude_cap_dialog', target: 'pro_pricing' }),
211
+ });
212
+ } catch {
213
+ /* tracking is best-effort */
214
+ }
215
+ }, [sessionId]);
216
+
217
+ const handleJobsUpgradeClose = useCallback(() => {
218
+ setJobsUpgradeRequired(null);
219
+ }, [setJobsUpgradeRequired]);
220
+
221
+ const handleJobsUpgradeClick = useCallback(async () => {
222
+ if (!sessionId || !jobsUpgradeRequired) return;
223
+ try {
224
+ await apiFetch(`/api/pro-click/${sessionId}`, {
225
+ method: 'POST',
226
+ body: JSON.stringify({ source: 'hf_jobs_upgrade_dialog', target: 'pro_pricing' }),
227
+ });
228
+ } catch {
229
+ /* tracking is best-effort */
230
+ }
231
+ }, [sessionId, jobsUpgradeRequired]);
232
+
233
+ const handleDeclineBlockedJobs = useCallback(async () => {
234
+ if (!onDeclineBlockedJobs) return;
235
+ await onDeclineBlockedJobs();
236
+ }, [onDeclineBlockedJobs]);
237
+
238
+ const handleContinueBlockedJobsWithNamespace = useCallback(async (namespace: string) => {
239
+ if (!onContinueBlockedJobsWithNamespace) return;
240
+ await onContinueBlockedJobsWithNamespace(namespace);
241
+ }, [onContinueBlockedJobsWithNamespace]);
242
+
243
  // Hide the chip until the user has actually burned quota — an unused
244
  // Opus session shouldn't populate a counter.
245
  const claudeChip = (() => {
 
478
  cap={quota?.claudeDailyCap ?? 1}
479
  onClose={handleCapDialogClose}
480
  onUseFreeModel={handleUseFreeModel}
481
+ onUpgrade={handleClaudeUpgradeClick}
482
+ />
483
+ <JobsUpgradeDialog
484
+ open={!!jobsUpgradeRequired}
485
+ mode={jobsUpgradeRequired?.mode || 'upgrade'}
486
+ message={jobsUpgradeRequired?.message || ''}
487
+ eligibleNamespaces={jobsUpgradeRequired?.eligibleNamespaces || []}
488
+ onClose={handleJobsUpgradeClose}
489
+ onUpgrade={handleJobsUpgradeClick}
490
+ onDecline={handleDeclineBlockedJobs}
491
+ onContinueWithNamespace={handleContinueBlockedJobsWithNamespace}
492
  />
493
  </Box>
494
  </Box>
frontend/src/components/ClaudeCapDialog.tsx CHANGED
@@ -19,6 +19,7 @@ interface ClaudeCapDialogProps {
19
  cap: number;
20
  onClose: () => void;
21
  onUseFreeModel: () => void;
 
22
  }
23
 
24
  export default function ClaudeCapDialog({
@@ -27,6 +28,7 @@ export default function ClaudeCapDialog({
27
  cap,
28
  onClose,
29
  onUseFreeModel,
 
30
  }: ClaudeCapDialogProps) {
31
  // plan not surfaced in copy right now — Pro users see the same dialog and
32
  // can upgrade their org if they're also capped.
@@ -100,6 +102,7 @@ export default function ClaudeCapDialog({
100
  href={HF_PRICING_URL}
101
  target="_blank"
102
  rel="noopener noreferrer"
 
103
  variant="contained"
104
  size="small"
105
  sx={{
 
19
  cap: number;
20
  onClose: () => void;
21
  onUseFreeModel: () => void;
22
+ onUpgrade: () => void;
23
  }
24
 
25
  export default function ClaudeCapDialog({
 
28
  cap,
29
  onClose,
30
  onUseFreeModel,
31
+ onUpgrade,
32
  }: ClaudeCapDialogProps) {
33
  // plan not surfaced in copy right now — Pro users see the same dialog and
34
  // can upgrade their org if they're also capped.
 
102
  href={HF_PRICING_URL}
103
  target="_blank"
104
  rel="noopener noreferrer"
105
+ onClick={onUpgrade}
106
  variant="contained"
107
  size="small"
108
  sx={{
frontend/src/components/JobsUpgradeDialog.tsx ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect, useState } from 'react';
2
+ import {
3
+ Box,
4
+ Button,
5
+ Dialog,
6
+ DialogActions,
7
+ DialogContent,
8
+ DialogContentText,
9
+ DialogTitle,
10
+ FormControl,
11
+ InputLabel,
12
+ MenuItem,
13
+ Select,
14
+ Typography,
15
+ } from '@mui/material';
16
+
17
+ const HF_PRICING_URL = 'https://huggingface.co/pricing';
18
+
19
+ interface JobsUpgradeDialogProps {
20
+ open: boolean;
21
+ mode: 'upgrade' | 'namespace';
22
+ message: string;
23
+ eligibleNamespaces: string[];
24
+ onUpgrade: () => void;
25
+ onDecline: () => void;
26
+ onClose: () => void;
27
+ onContinueWithNamespace: (namespace: string) => void;
28
+ }
29
+
30
+ export default function JobsUpgradeDialog({
31
+ open,
32
+ mode,
33
+ message,
34
+ eligibleNamespaces,
35
+ onUpgrade,
36
+ onDecline,
37
+ onClose,
38
+ onContinueWithNamespace,
39
+ }: JobsUpgradeDialogProps) {
40
+ const [selectedNamespace, setSelectedNamespace] = useState('');
41
+
42
+ useEffect(() => {
43
+ if (!open) return;
44
+ setSelectedNamespace(eligibleNamespaces[0] || '');
45
+ }, [open, eligibleNamespaces]);
46
+
47
+ return (
48
+ <Dialog
49
+ open={open}
50
+ onClose={onClose}
51
+ slotProps={{
52
+ backdrop: { sx: { backgroundColor: 'rgba(0,0,0,0.5)', backdropFilter: 'blur(4px)' } },
53
+ }}
54
+ PaperProps={{
55
+ sx: {
56
+ bgcolor: 'var(--panel)',
57
+ border: '1px solid var(--border)',
58
+ borderRadius: 'var(--radius-md)',
59
+ boxShadow: 'var(--shadow-1)',
60
+ maxWidth: 500,
61
+ mx: 2,
62
+ },
63
+ }}
64
+ >
65
+ <DialogTitle
66
+ sx={{ color: 'var(--text)', fontWeight: 700, fontSize: '1rem', pt: 2.5, pb: 0, px: 3 }}
67
+ >
68
+ {mode === 'namespace' ? 'Choose the org for this job' : 'Jobs need Pro or a paid org'}
69
+ </DialogTitle>
70
+ <DialogContent sx={{ px: 3, pt: 1.25, pb: 0 }}>
71
+ <DialogContentText
72
+ sx={{ color: 'var(--muted-text)', fontSize: '0.85rem', lineHeight: 1.6 }}
73
+ >
74
+ {message}
75
+ </DialogContentText>
76
+ {eligibleNamespaces.length > 0 && (
77
+ <Box
78
+ sx={{
79
+ mt: 2,
80
+ p: 1.5,
81
+ borderRadius: '8px',
82
+ bgcolor: 'var(--accent-yellow-weak)',
83
+ border: '1px solid var(--border)',
84
+ }}
85
+ >
86
+ <Typography
87
+ variant="caption"
88
+ sx={{
89
+ display: 'block',
90
+ fontWeight: 700,
91
+ color: 'var(--text)',
92
+ fontSize: '0.78rem',
93
+ mb: 1,
94
+ letterSpacing: '0.02em',
95
+ }}
96
+ >
97
+ Eligible namespaces
98
+ </Typography>
99
+ {mode === 'namespace' ? (
100
+ <FormControl fullWidth size="small">
101
+ <InputLabel id="jobs-namespace-label">Organization</InputLabel>
102
+ <Select
103
+ labelId="jobs-namespace-label"
104
+ value={selectedNamespace}
105
+ label="Organization"
106
+ onChange={(e) => setSelectedNamespace(String(e.target.value))}
107
+ >
108
+ {eligibleNamespaces.map((namespace) => (
109
+ <MenuItem key={namespace} value={namespace}>
110
+ {namespace}
111
+ </MenuItem>
112
+ ))}
113
+ </Select>
114
+ </FormControl>
115
+ ) : (
116
+ <Typography
117
+ variant="caption"
118
+ sx={{ display: 'block', color: 'var(--muted-text)', fontSize: '0.78rem', lineHeight: 1.55 }}
119
+ >
120
+ {eligibleNamespaces.join(', ')}
121
+ </Typography>
122
+ )}
123
+ </Box>
124
+ )}
125
+ <Typography
126
+ variant="caption"
127
+ sx={{ display: 'block', mt: 2, color: 'var(--muted-text)', fontSize: '0.78rem', lineHeight: 1.55 }}
128
+ >
129
+ If you decline, the agent will have to find another way forward without `hf_jobs`.
130
+ </Typography>
131
+ </DialogContent>
132
+ <DialogActions sx={{ px: 3, pb: 2.5, pt: 2, gap: 1 }}>
133
+ {mode === 'namespace' ? (
134
+ <Button
135
+ onClick={() => onContinueWithNamespace(selectedNamespace)}
136
+ disabled={!selectedNamespace}
137
+ variant="contained"
138
+ size="small"
139
+ sx={{
140
+ fontSize: '0.82rem',
141
+ px: 2.5,
142
+ bgcolor: 'var(--accent-yellow)',
143
+ color: '#000',
144
+ textTransform: 'none',
145
+ fontWeight: 700,
146
+ boxShadow: 'none',
147
+ '&:hover': { bgcolor: '#FFB340', boxShadow: 'none' },
148
+ }}
149
+ >
150
+ Run under selected org
151
+ </Button>
152
+ ) : (
153
+ <Button
154
+ component="a"
155
+ href={HF_PRICING_URL}
156
+ target="_blank"
157
+ rel="noopener noreferrer"
158
+ onClick={onUpgrade}
159
+ variant="contained"
160
+ size="small"
161
+ sx={{
162
+ fontSize: '0.82rem',
163
+ px: 2.5,
164
+ bgcolor: 'var(--accent-yellow)',
165
+ color: '#000',
166
+ textTransform: 'none',
167
+ fontWeight: 700,
168
+ boxShadow: 'none',
169
+ '&:hover': { bgcolor: '#FFB340', boxShadow: 'none' },
170
+ }}
171
+ >
172
+ Upgrade to Pro
173
+ </Button>
174
+ )}
175
+ <Button
176
+ onClick={onDecline}
177
+ size="small"
178
+ sx={{
179
+ color: 'var(--muted-text)',
180
+ fontSize: '0.82rem',
181
+ px: 2,
182
+ textTransform: 'none',
183
+ '&:hover': { bgcolor: 'var(--hover-bg)' },
184
+ }}
185
+ >
186
+ Decline tool call
187
+ </Button>
188
+ </DialogActions>
189
+ </Dialog>
190
+ );
191
+ }
frontend/src/components/SessionChat.tsx CHANGED
@@ -26,7 +26,7 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
26
  const { updateSessionTitle, sessions } = useSessionStore();
27
  const isExpired = sessions.find((s) => s.id === sessionId)?.expired === true;
28
 
29
- const { messages, sendMessage, stop, status, undoLastTurn, editAndRegenerate, approveTools } = useAgentChat({
30
  sessionId,
31
  isActive,
32
  onReady: () => logger.log(`Session ${sessionId} ready`),
@@ -114,6 +114,8 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
114
  sessionId={sessionId}
115
  onSend={handleSendMessage}
116
  onStop={handleStop}
 
 
117
  isProcessing={busy}
118
  disabled={!isConnected || activityStatus.type === 'waiting-approval'}
119
  placeholder={
 
26
  const { updateSessionTitle, sessions } = useSessionStore();
27
  const isExpired = sessions.find((s) => s.id === sessionId)?.expired === true;
28
 
29
+ const { messages, sendMessage, stop, status, undoLastTurn, editAndRegenerate, approveTools, declineBlockedJobs, continueBlockedJobsWithNamespace } = useAgentChat({
30
  sessionId,
31
  isActive,
32
  onReady: () => logger.log(`Session ${sessionId} ready`),
 
114
  sessionId={sessionId}
115
  onSend={handleSendMessage}
116
  onStop={handleStop}
117
+ onDeclineBlockedJobs={declineBlockedJobs}
118
+ onContinueBlockedJobsWithNamespace={continueBlockedJobsWithNamespace}
119
  isProcessing={busy}
120
  disabled={!isConnected || activityStatus.type === 'waiting-approval'}
121
  placeholder={
frontend/src/hooks/useAgentChat.ts CHANGED
@@ -330,6 +330,49 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
330
  messages: UIMessage[];
331
  }>({ setMessages: null, messages: [] });
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  // -- useChat from Vercel AI SDK -----------------------------------------
334
  const chat = useChat({
335
  id: sessionId,
@@ -354,6 +397,56 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
354
  }
355
  return;
356
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  logger.error('useChat error:', error);
358
  if (isActiveRef.current) {
359
  useAgentStore.getState().setError(error.message);
@@ -672,12 +765,15 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
672
 
673
  // -- Approve tools ------------------------------------------------------
674
  const approveTools = useCallback(
675
- async (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null; edited_script?: string | null }>) => {
676
  // Store edited scripts so the transport can read them when sendMessages is called
677
  for (const a of approvals) {
678
  if (a.edited_script) {
679
  useAgentStore.getState().setEditedScript(a.tool_call_id, a.edited_script);
680
  }
 
 
 
681
  }
682
 
683
  // Update SDK tool state — this triggers sendMessages() via the transport
@@ -707,6 +803,37 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
707
  [sessionId, chat, updateSession, setNeedsAttention],
708
  );
709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
710
  // -- Stop (interrupt backend agent loop, keep SSE open for events) --------
711
  const stop = useCallback(() => {
712
  // Don't call chat.stop() — keep the SSE stream open so the backend's
@@ -763,5 +890,7 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
763
  undoLastTurn,
764
  editAndRegenerate,
765
  approveTools,
 
 
766
  };
767
  }
 
330
  messages: UIMessage[];
331
  }>({ setMessages: null, messages: [] });
332
 
333
+ const hydrateFromBackend = useCallback(async () => {
334
+ try {
335
+ const [msgsRes, infoRes] = await Promise.all([
336
+ apiFetch(`/api/session/${sessionId}/messages`),
337
+ apiFetch(`/api/session/${sessionId}`),
338
+ ]);
339
+ if (!msgsRes.ok) return null;
340
+ const data = await msgsRes.json();
341
+ if (!Array.isArray(data) || data.length === 0) return null;
342
+
343
+ saveBackendMessages(sessionId, data);
344
+
345
+ let pendingIds: Set<string> | undefined;
346
+ let info: Record<string, unknown> | null = null;
347
+ if (infoRes.ok) {
348
+ info = await infoRes.json();
349
+ const pendingApproval = info?.pending_approval;
350
+ if (pendingApproval && Array.isArray(pendingApproval)) {
351
+ pendingIds = new Set(
352
+ pendingApproval.map((t: { tool_call_id: string }) => t.tool_call_id),
353
+ );
354
+ if (pendingIds.size > 0) {
355
+ setNeedsAttention(sessionId, true);
356
+ }
357
+ }
358
+ }
359
+
360
+ const uiMsgs = llmMessagesToUIMessages(data, pendingIds, chatActionsRef.current.messages);
361
+ if (uiMsgs.length > 0) {
362
+ chatActionsRef.current.setMessages?.(uiMsgs);
363
+ saveMessages(sessionId, uiMsgs);
364
+ }
365
+
366
+ if (pendingIds && pendingIds.size > 0) {
367
+ updateSession(sessionId, { activityStatus: { type: 'waiting-approval' }, isProcessing: false });
368
+ }
369
+
370
+ return { data, pendingIds, info };
371
+ } catch {
372
+ return null;
373
+ }
374
+ }, [sessionId, setNeedsAttention]);
375
+
376
  // -- useChat from Vercel AI SDK -----------------------------------------
377
  const chat = useChat({
378
  id: sessionId,
 
397
  }
398
  return;
399
  }
400
+ if (error.message === 'HF_JOBS_UPGRADE_REQUIRED') {
401
+ const typed = error as Error & {
402
+ detail?: Record<string, unknown>;
403
+ approvals?: Array<{
404
+ tool_call_id: string;
405
+ approved: boolean;
406
+ feedback?: string | null;
407
+ edited_script?: string | null;
408
+ }>;
409
+ };
410
+ void hydrateFromBackend();
411
+ if (isActiveRef.current) {
412
+ useAgentStore.getState().setJobsUpgradeRequired({
413
+ approvals: typed.approvals || [],
414
+ toolCallIds: (typed.detail?.tool_call_ids as string[]) || [],
415
+ message: String(
416
+ typed.detail?.message
417
+ || 'Hugging Face Jobs are available only to Pro users and Team or Enterprise organizations.',
418
+ ),
419
+ eligibleNamespaces: (typed.detail?.eligible_namespaces as string[]) || [],
420
+ plan: ((typed.detail?.plan as 'free' | 'pro' | 'org') || 'free'),
421
+ mode: 'upgrade',
422
+ });
423
+ }
424
+ return;
425
+ }
426
+ if (error.message === 'HF_JOBS_NAMESPACE_REQUIRED') {
427
+ const typed = error as Error & {
428
+ detail?: Record<string, unknown>;
429
+ approvals?: Array<{
430
+ tool_call_id: string;
431
+ approved: boolean;
432
+ feedback?: string | null;
433
+ edited_script?: string | null;
434
+ namespace?: string | null;
435
+ }>;
436
+ };
437
+ void hydrateFromBackend();
438
+ if (isActiveRef.current) {
439
+ useAgentStore.getState().setJobsUpgradeRequired({
440
+ approvals: typed.approvals || [],
441
+ toolCallIds: (typed.detail?.tool_call_ids as string[]) || [],
442
+ message: String(typed.detail?.message || 'Choose which organization should own this job run.'),
443
+ eligibleNamespaces: (typed.detail?.eligible_namespaces as string[]) || [],
444
+ plan: ((typed.detail?.plan as 'free' | 'pro' | 'org') || 'free'),
445
+ mode: 'namespace',
446
+ });
447
+ }
448
+ return;
449
+ }
450
  logger.error('useChat error:', error);
451
  if (isActiveRef.current) {
452
  useAgentStore.getState().setError(error.message);
 
765
 
766
  // -- Approve tools ------------------------------------------------------
767
  const approveTools = useCallback(
768
+ async (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null; edited_script?: string | null; namespace?: string | null }>) => {
769
  // Store edited scripts so the transport can read them when sendMessages is called
770
  for (const a of approvals) {
771
  if (a.edited_script) {
772
  useAgentStore.getState().setEditedScript(a.tool_call_id, a.edited_script);
773
  }
774
+ if (a.namespace) {
775
+ useAgentStore.getState().setApprovalNamespace(a.tool_call_id, a.namespace);
776
+ }
777
  }
778
 
779
  // Update SDK tool state — this triggers sendMessages() via the transport
 
803
  [sessionId, chat, updateSession, setNeedsAttention],
804
  );
805
 
806
+ const declineBlockedJobs = useCallback(async () => {
807
+ const blocked = useAgentStore.getState().jobsUpgradeRequired;
808
+ if (!blocked) return false;
809
+
810
+ const approvals = blocked.approvals.map((approval) => ({
811
+ ...approval,
812
+ approved: blocked.toolCallIds.includes(approval.tool_call_id) ? false : approval.approved,
813
+ feedback: blocked.toolCallIds.includes(approval.tool_call_id)
814
+ ? 'Rejected because this account cannot launch Hugging Face Jobs.'
815
+ : approval.feedback,
816
+ }));
817
+
818
+ useAgentStore.getState().setJobsUpgradeRequired(null);
819
+ return approveTools(approvals);
820
+ }, [approveTools]);
821
+
822
+ const continueBlockedJobsWithNamespace = useCallback(async (namespace: string) => {
823
+ const blocked = useAgentStore.getState().jobsUpgradeRequired;
824
+ if (!blocked) return false;
825
+
826
+ const approvals = blocked.approvals.map((approval) => ({
827
+ ...approval,
828
+ namespace: blocked.toolCallIds.includes(approval.tool_call_id)
829
+ ? namespace
830
+ : approval.namespace,
831
+ }));
832
+
833
+ useAgentStore.getState().setJobsUpgradeRequired(null);
834
+ return approveTools(approvals);
835
+ }, [approveTools]);
836
+
837
  // -- Stop (interrupt backend agent loop, keep SSE open for events) --------
838
  const stop = useCallback(() => {
839
  // Don't call chat.stop() — keep the SSE stream open so the backend's
 
890
  undoLastTurn,
891
  editAndRegenerate,
892
  approveTools,
893
+ declineBlockedJobs,
894
+ continueBlockedJobsWithNamespace,
895
  };
896
  }
frontend/src/lib/sse-chat-transport.ts CHANGED
@@ -320,11 +320,13 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
320
  const approved = p.approval?.approved ?? true;
321
  // Get edited script from agentStore if available
322
  const editedScript = useAgentStore.getState().getEditedScript(p.toolCallId);
 
323
  return {
324
  tool_call_id: p.toolCallId,
325
  approved,
326
  feedback: approved ? null : (p.approval?.reason || 'Rejected by user'),
327
  edited_script: editedScript ?? null,
 
328
  };
329
  }).filter(Boolean);
330
  body = { approvals };
@@ -362,6 +364,30 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
362
  // instead of a generic error banner.
363
  throw new Error('CLAUDE_QUOTA_EXHAUSTED');
364
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  if (!response.ok) {
366
  const errorText = await response.text().catch(() => 'Request failed');
367
  throw new Error(`Chat request failed: ${response.status} ${errorText}`);
 
320
  const approved = p.approval?.approved ?? true;
321
  // Get edited script from agentStore if available
322
  const editedScript = useAgentStore.getState().getEditedScript(p.toolCallId);
323
+ const namespace = useAgentStore.getState().getApprovalNamespace(p.toolCallId);
324
  return {
325
  tool_call_id: p.toolCallId,
326
  approved,
327
  feedback: approved ? null : (p.approval?.reason || 'Rejected by user'),
328
  edited_script: editedScript ?? null,
329
+ namespace: namespace ?? null,
330
  };
331
  }).filter(Boolean);
332
  body = { approvals };
 
364
  // instead of a generic error banner.
365
  throw new Error('CLAUDE_QUOTA_EXHAUSTED');
366
  }
367
+ if (response.status === 402) {
368
+ const payload = await response.json().catch(() => null);
369
+ if (payload?.detail?.error === 'hf_jobs_upgrade_required') {
370
+ const err = new Error('HF_JOBS_UPGRADE_REQUIRED') as Error & {
371
+ detail?: Record<string, unknown>;
372
+ approvals?: Array<Record<string, unknown>>;
373
+ };
374
+ err.detail = payload.detail as Record<string, unknown>;
375
+ err.approvals = (body.approvals as Array<Record<string, unknown>> | undefined) || [];
376
+ throw err;
377
+ }
378
+ }
379
+ if (response.status === 409) {
380
+ const payload = await response.json().catch(() => null);
381
+ if (payload?.detail?.error === 'hf_jobs_namespace_required') {
382
+ const err = new Error('HF_JOBS_NAMESPACE_REQUIRED') as Error & {
383
+ detail?: Record<string, unknown>;
384
+ approvals?: Array<Record<string, unknown>>;
385
+ };
386
+ err.detail = payload.detail as Record<string, unknown>;
387
+ err.approvals = (body.approvals as Array<Record<string, unknown>> | undefined) || [];
388
+ throw err;
389
+ }
390
+ }
391
  if (!response.ok) {
392
  const errorText = await response.text().catch(() => 'Request failed');
393
  throw new Error(`Chat request failed: ${response.status} ${errorText}`);
frontend/src/store/agentStore.ts CHANGED
@@ -45,6 +45,21 @@ export interface LLMHealthError {
45
  model: string;
46
  }
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  export type ActivityStatus =
49
  | { type: 'idle' }
50
  | { type: 'thinking' }
@@ -110,6 +125,7 @@ interface AgentStore {
110
  llmHealthError: LLMHealthError | null;
111
  /** Set when a Claude-send hits the daily quota — ChatInput opens the cap dialog in response. */
112
  claudeQuotaExhausted: boolean;
 
113
 
114
  // Right panel (single-artifact pattern)
115
  panelData: PanelData | null;
@@ -122,6 +138,9 @@ interface AgentStore {
122
  // Edited scripts (tool_call_id -> edited content)
123
  editedScripts: Record<string, string>;
124
 
 
 
 
125
  // Job URLs (tool_call_id -> job URL) for HF jobs
126
  jobUrls: Record<string, string>;
127
 
@@ -156,6 +175,7 @@ interface AgentStore {
156
  setError: (error: string | null) => void;
157
  setLlmHealthError: (error: LLMHealthError | null) => void;
158
  setClaudeQuotaExhausted: (exhausted: boolean) => void;
 
159
 
160
  setPanel: (data: PanelData, view?: PanelView, editable?: boolean) => void;
161
  setPanelView: (view: PanelView) => void;
@@ -170,6 +190,10 @@ interface AgentStore {
170
  getEditedScript: (toolCallId: string) => string | undefined;
171
  clearEditedScripts: () => void;
172
 
 
 
 
 
173
  setJobUrl: (toolCallId: string, jobUrl: string) => void;
174
  getJobUrl: (toolCallId: string) => string | undefined;
175
 
@@ -251,6 +275,7 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
251
  error: null,
252
  llmHealthError: null,
253
  claudeQuotaExhausted: false,
 
254
 
255
  panelData: null,
256
  panelView: 'script',
@@ -259,6 +284,7 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
259
  plan: [],
260
 
261
  editedScripts: {},
 
262
  jobUrls: {},
263
  jobStatuses: {},
264
  toolErrors: loadToolErrors(),
@@ -363,6 +389,7 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
363
  setError: (error) => set({ error }),
364
  setLlmHealthError: (error) => set({ llmHealthError: error }),
365
  setClaudeQuotaExhausted: (exhausted) => set({ claudeQuotaExhausted: exhausted }),
 
366
 
367
  // ── Panel (single-artifact) ───────────────────────────────────────
368
  // Each setter also patches the active session's snapshot so that
@@ -428,6 +455,16 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
428
 
429
  clearEditedScripts: () => set({ editedScripts: {} }),
430
 
 
 
 
 
 
 
 
 
 
 
431
  // ── Job URLs ────────────────────────────────────────────────────────
432
 
433
  setJobUrl: (toolCallId, jobUrl) => {
 
45
  model: string;
46
  }
47
 
48
+ export interface JobsUpgradeState {
49
+ approvals: Array<{
50
+ tool_call_id: string;
51
+ approved: boolean;
52
+ feedback?: string | null;
53
+ edited_script?: string | null;
54
+ namespace?: string | null;
55
+ }>;
56
+ toolCallIds: string[];
57
+ message: string;
58
+ eligibleNamespaces: string[];
59
+ plan: 'free' | 'pro' | 'org';
60
+ mode: 'upgrade' | 'namespace';
61
+ }
62
+
63
  export type ActivityStatus =
64
  | { type: 'idle' }
65
  | { type: 'thinking' }
 
125
  llmHealthError: LLMHealthError | null;
126
  /** Set when a Claude-send hits the daily quota — ChatInput opens the cap dialog in response. */
127
  claudeQuotaExhausted: boolean;
128
+ jobsUpgradeRequired: JobsUpgradeState | null;
129
 
130
  // Right panel (single-artifact pattern)
131
  panelData: PanelData | null;
 
138
  // Edited scripts (tool_call_id -> edited content)
139
  editedScripts: Record<string, string>;
140
 
141
+ // Namespace overrides chosen for hf_jobs approvals (tool_call_id -> namespace)
142
+ approvalNamespaces: Record<string, string>;
143
+
144
  // Job URLs (tool_call_id -> job URL) for HF jobs
145
  jobUrls: Record<string, string>;
146
 
 
175
  setError: (error: string | null) => void;
176
  setLlmHealthError: (error: LLMHealthError | null) => void;
177
  setClaudeQuotaExhausted: (exhausted: boolean) => void;
178
+ setJobsUpgradeRequired: (state: JobsUpgradeState | null) => void;
179
 
180
  setPanel: (data: PanelData, view?: PanelView, editable?: boolean) => void;
181
  setPanelView: (view: PanelView) => void;
 
190
  getEditedScript: (toolCallId: string) => string | undefined;
191
  clearEditedScripts: () => void;
192
 
193
+ setApprovalNamespace: (toolCallId: string, namespace: string) => void;
194
+ getApprovalNamespace: (toolCallId: string) => string | undefined;
195
+ clearApprovalNamespaces: () => void;
196
+
197
  setJobUrl: (toolCallId: string, jobUrl: string) => void;
198
  getJobUrl: (toolCallId: string) => string | undefined;
199
 
 
275
  error: null,
276
  llmHealthError: null,
277
  claudeQuotaExhausted: false,
278
+ jobsUpgradeRequired: null,
279
 
280
  panelData: null,
281
  panelView: 'script',
 
284
  plan: [],
285
 
286
  editedScripts: {},
287
+ approvalNamespaces: {},
288
  jobUrls: {},
289
  jobStatuses: {},
290
  toolErrors: loadToolErrors(),
 
389
  setError: (error) => set({ error }),
390
  setLlmHealthError: (error) => set({ llmHealthError: error }),
391
  setClaudeQuotaExhausted: (exhausted) => set({ claudeQuotaExhausted: exhausted }),
392
+ setJobsUpgradeRequired: (state) => set({ jobsUpgradeRequired: state }),
393
 
394
  // ── Panel (single-artifact) ───────────────────────────────────────
395
  // Each setter also patches the active session's snapshot so that
 
455
 
456
  clearEditedScripts: () => set({ editedScripts: {} }),
457
 
458
+ setApprovalNamespace: (toolCallId, namespace) => {
459
+ set((state) => ({
460
+ approvalNamespaces: { ...state.approvalNamespaces, [toolCallId]: namespace },
461
+ }));
462
+ },
463
+
464
+ getApprovalNamespace: (toolCallId) => get().approvalNamespaces[toolCallId],
465
+
466
+ clearApprovalNamespaces: () => set({ approvalNamespaces: {} }),
467
+
468
  // ── Job URLs ────────────────────────────────────────────────────────
469
 
470
  setJobUrl: (toolCallId, jobUrl) => {
frontend/src/types/agent.ts CHANGED
@@ -27,6 +27,7 @@ export interface ToolApproval {
27
  tool_call_id: string;
28
  approved: boolean;
29
  feedback?: string | null;
 
30
  }
31
 
32
  export interface User {
 
27
  tool_call_id: string;
28
  approved: boolean;
29
  feedback?: string | null;
30
+ namespace?: string | null;
31
  }
32
 
33
  export interface User {
frontend/src/utils/model.ts CHANGED
@@ -3,13 +3,12 @@
3
  * ClaudeCapDialog "Use a free model" escape hatch.
4
  *
5
  * Keep in sync with MODEL_OPTIONS in components/Chat/ChatInput.tsx and
6
- * AVAILABLE_MODELS in backend/routes/agent.py. Bare HF ids (no
7
- * `huggingface/` prefix) — matches upstream's auto-router.
8
  */
9
 
10
- export const CLAUDE_MODEL_PATH = 'anthropic/claude-opus-4-6';
11
  export const FIRST_FREE_MODEL_PATH = 'moonshotai/Kimi-K2.6';
12
 
13
  export function isClaudePath(modelPath: string | undefined): boolean {
14
- return !!modelPath && modelPath.startsWith('anthropic/');
15
  }
 
3
  * ClaudeCapDialog "Use a free model" escape hatch.
4
  *
5
  * Keep in sync with MODEL_OPTIONS in components/Chat/ChatInput.tsx and
6
+ * AVAILABLE_MODELS in backend/routes/agent.py.
 
7
  */
8
 
9
+ export const CLAUDE_MODEL_PATH = 'bedrock/us.anthropic.claude-opus-4-6-v1';
10
  export const FIRST_FREE_MODEL_PATH = 'moonshotai/Kimi-K2.6';
11
 
12
  export function isClaudePath(modelPath: string | undefined): boolean {
13
+ return !!modelPath && modelPath.includes('anthropic');
14
  }
pyproject.toml CHANGED
@@ -1,5 +1,5 @@
1
  [project]
2
- name = "hf-agent"
3
  version = "0.1.0"
4
  description = "Add your description here"
5
  readme = "README.md"
@@ -42,11 +42,12 @@ eval = [
42
  # Development and testing dependencies
43
  dev = [
44
  "pytest>=9.0.2",
 
45
  ]
46
 
47
  # All dependencies (eval + dev)
48
  all = [
49
- "hf-agent[eval,dev]",
50
  ]
51
 
52
  [project.scripts]
@@ -61,3 +62,6 @@ include = ["agent*"]
61
 
62
  [tool.uv]
63
  package = true
 
 
 
 
1
  [project]
2
+ name = "ml-intern"
3
  version = "0.1.0"
4
  description = "Add your description here"
5
  readme = "README.md"
 
42
  # Development and testing dependencies
43
  dev = [
44
  "pytest>=9.0.2",
45
+ "pytest-asyncio>=0.26.0",
46
  ]
47
 
48
  # All dependencies (eval + dev)
49
  all = [
50
+ "ml-intern[eval,dev]",
51
  ]
52
 
53
  [project.scripts]
 
62
 
63
  [tool.uv]
64
  package = true
65
+
66
+ [tool.pytest.ini_options]
67
+ asyncio_mode = "auto"
scripts/build_kpis.py CHANGED
@@ -44,7 +44,8 @@ re-running the same hour overwrites.
44
  regenerate_rate — sessions with any `undo_complete` event / sessions
45
  time_to_first_action_s_p50 / _p95 — from session_start to first tool_call
46
  thumbs_up / thumbs_down
47
- hf_jobs_submitted / _succeeded
 
48
  gpu_hours_by_flavor_json — JSON-serialised {flavor: gpu-hours}
49
 
50
  ================================================================================
@@ -210,7 +211,8 @@ def _session_metrics(session: dict) -> dict:
210
  "tool_calls_total": 0, "tool_calls_success": 0,
211
  "failures": 0, "regenerate_sessions": 0,
212
  "thumbs_up": 0, "thumbs_down": 0,
213
- "hf_jobs_submitted": 0, "hf_jobs_succeeded": 0,
 
214
  "first_tool_s": -1,
215
  }
216
  events = session.get("events") or []
@@ -229,8 +231,11 @@ def _session_metrics(session: dict) -> dict:
229
  gpu_hours_by_flavor: dict[str, float] = defaultdict(float)
230
  jobs_submitted = 0
231
  jobs_succeeded = 0
 
232
  thumbs_up = 0
233
  thumbs_down = 0
 
 
234
 
235
  start_dt = _parse_ts(session_start)
236
 
@@ -283,6 +288,14 @@ def _session_metrics(session: dict) -> dict:
283
  if status in ("completed", "succeeded", "success"):
284
  jobs_succeeded += 1
285
 
 
 
 
 
 
 
 
 
286
  out["tool_calls_total"] = tool_total
287
  out["tool_calls_success"] = tool_success
288
  out["failures"] = 1 if had_error else 0
@@ -291,8 +304,11 @@ def _session_metrics(session: dict) -> dict:
291
  out["thumbs_down"] = thumbs_down
292
  out["hf_jobs_submitted"] = jobs_submitted
293
  out["hf_jobs_succeeded"] = jobs_succeeded
 
 
294
  out["first_tool_s"] = first_tool_ts if first_tool_ts is not None else -1
295
  out["_gpu_hours_by_flavor"] = dict(gpu_hours_by_flavor)
 
296
  out["_user"] = session.get("user_id") or session.get("session_id")
297
  return dict(out)
298
 
@@ -301,9 +317,12 @@ def _aggregate(per_session: list[dict]) -> dict:
301
  """Collapse a bucket's worth of session rollups into the final KPI row."""
302
  ttfa_values = [s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0]
303
  gpu_hours: dict[str, float] = defaultdict(float)
 
304
  for s in per_session:
305
  for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
306
  gpu_hours[f] += h
 
 
307
 
308
  total_sessions = sum(s["sessions"] for s in per_session)
309
  total_turns = sum(s["turns"] for s in per_session)
@@ -340,7 +359,10 @@ def _aggregate(per_session: list[dict]) -> dict:
340
  "thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
341
  "hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
342
  "hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
 
 
343
  "gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
 
344
  }
345
 
346
 
 
44
  regenerate_rate — sessions with any `undo_complete` event / sessions
45
  time_to_first_action_s_p50 / _p95 — from session_start to first tool_call
46
  thumbs_up / thumbs_down
47
+ hf_jobs_submitted / _succeeded / _blocked
48
+ pro_cta_clicks
49
  gpu_hours_by_flavor_json — JSON-serialised {flavor: gpu-hours}
50
 
51
  ================================================================================
 
211
  "tool_calls_total": 0, "tool_calls_success": 0,
212
  "failures": 0, "regenerate_sessions": 0,
213
  "thumbs_up": 0, "thumbs_down": 0,
214
+ "hf_jobs_submitted": 0, "hf_jobs_succeeded": 0, "hf_jobs_blocked": 0,
215
+ "pro_cta_clicks": 0,
216
  "first_tool_s": -1,
217
  }
218
  events = session.get("events") or []
 
231
  gpu_hours_by_flavor: dict[str, float] = defaultdict(float)
232
  jobs_submitted = 0
233
  jobs_succeeded = 0
234
+ jobs_blocked = 0
235
  thumbs_up = 0
236
  thumbs_down = 0
237
+ pro_cta_clicks = 0
238
+ pro_cta_by_source: dict[str, int] = defaultdict(int)
239
 
240
  start_dt = _parse_ts(session_start)
241
 
 
288
  if status in ("completed", "succeeded", "success"):
289
  jobs_succeeded += 1
290
 
291
+ elif et == "jobs_access_blocked":
292
+ jobs_blocked += 1
293
+
294
+ elif et == "pro_cta_click":
295
+ pro_cta_clicks += 1
296
+ source = str(data.get("source") or "unknown")
297
+ pro_cta_by_source[source] += 1
298
+
299
  out["tool_calls_total"] = tool_total
300
  out["tool_calls_success"] = tool_success
301
  out["failures"] = 1 if had_error else 0
 
304
  out["thumbs_down"] = thumbs_down
305
  out["hf_jobs_submitted"] = jobs_submitted
306
  out["hf_jobs_succeeded"] = jobs_succeeded
307
+ out["hf_jobs_blocked"] = jobs_blocked
308
+ out["pro_cta_clicks"] = pro_cta_clicks
309
  out["first_tool_s"] = first_tool_ts if first_tool_ts is not None else -1
310
  out["_gpu_hours_by_flavor"] = dict(gpu_hours_by_flavor)
311
+ out["_pro_cta_by_source"] = dict(pro_cta_by_source)
312
  out["_user"] = session.get("user_id") or session.get("session_id")
313
  return dict(out)
314
 
 
317
  """Collapse a bucket's worth of session rollups into the final KPI row."""
318
  ttfa_values = [s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0]
319
  gpu_hours: dict[str, float] = defaultdict(float)
320
+ pro_cta_by_source: dict[str, int] = defaultdict(int)
321
  for s in per_session:
322
  for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
323
  gpu_hours[f] += h
324
+ for source, count in (s.get("_pro_cta_by_source") or {}).items():
325
+ pro_cta_by_source[source] += int(count)
326
 
327
  total_sessions = sum(s["sessions"] for s in per_session)
328
  total_turns = sum(s["turns"] for s in per_session)
 
359
  "thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
360
  "hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
361
  "hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
362
+ "hf_jobs_blocked": int(sum(s["hf_jobs_blocked"] for s in per_session)),
363
+ "pro_cta_clicks": int(sum(s["pro_cta_clicks"] for s in per_session)),
364
  "gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
365
+ "pro_cta_by_source_json": json.dumps(dict(pro_cta_by_source), sort_keys=True),
366
  }
367
 
368
 
tests/unit/test_build_kpis.py CHANGED
@@ -88,6 +88,22 @@ def test_hf_job_gpu_hours():
88
  assert abs(m["_gpu_hours_by_flavor"]["a100-large"] - 1.0) < 1e-6
89
 
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def test_feedback_counts():
92
  mod = _load()
93
  events = [
@@ -120,6 +136,22 @@ def test_aggregate_day_cache_hit_and_users():
120
  assert abs(row["cost_usd"] - 1.5) < 1e-9
121
 
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  def test_failure_and_regenerate_rates():
124
  mod = _load()
125
  s1 = mod._session_metrics(_session([_ev("error", {"error": "boom"})], user_id="a"))
 
88
  assert abs(m["_gpu_hours_by_flavor"]["a100-large"] - 1.0) < 1e-6
89
 
90
 
91
def test_hf_job_blocked_and_pro_clicks_are_counted():
    """Blocked-job and Pro-CTA events roll up into per-session counters."""
    mod = _load()
    session = _session([
        _ev("jobs_access_blocked", {"tool_call_ids": ["tc1"], "plan": "free"}),
        _ev("pro_cta_click", {"source": "hf_jobs_upgrade_dialog"}),
        _ev("pro_cta_click", {"source": "claude_cap_dialog"}),
    ])
    metrics = mod._session_metrics(session)

    assert metrics["hf_jobs_blocked"] == 1
    assert metrics["pro_cta_clicks"] == 2
    expected_sources = {
        "hf_jobs_upgrade_dialog": 1,
        "claude_cap_dialog": 1,
    }
    assert metrics["_pro_cta_by_source"] == expected_sources
106
+
107
  def test_feedback_counts():
108
  mod = _load()
109
  events = [
 
136
  assert abs(row["cost_usd"] - 1.5) < 1e-9
137
 
138
 
139
+ def test_aggregate_day_sums_pro_click_sources():
140
+ mod = _load()
141
+ s1 = mod._session_metrics(_session([
142
+ _ev("pro_cta_click", {"source": "hf_jobs_upgrade_dialog"}),
143
+ _ev("pro_cta_click", {"source": "hf_jobs_upgrade_dialog"}),
144
+ ], user_id="u1"))
145
+ s2 = mod._session_metrics(_session([
146
+ _ev("pro_cta_click", {"source": "claude_cap_dialog"}),
147
+ ], user_id="u2"))
148
+ row = mod._aggregate_day([s1, s2])
149
+ assert row["pro_cta_clicks"] == 3
150
+ assert row["pro_cta_by_source_json"] == (
151
+ '{"claude_cap_dialog": 1, "hf_jobs_upgrade_dialog": 2}'
152
+ )
153
+
154
+
155
  def test_failure_and_regenerate_rates():
156
  mod = _load()
157
  s1 = mod._session_metrics(_session([_ev("error", {"error": "boom"})], user_id="a"))
tests/unit/test_cli_rendering.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Regression tests for interactive CLI rendering and research model routing."""
2
+
3
+ from io import StringIO
4
+ from types import SimpleNamespace
5
+
6
+ from agent.tools.research_tool import _get_research_model
7
+ from agent.utils import terminal_display
8
+
9
+
10
+ def test_direct_anthropic_research_model_stays_off_bedrock():
11
+ assert _get_research_model("anthropic/claude-opus-4-6") == "anthropic/claude-sonnet-4-6"
12
+
13
+
14
+ def test_bedrock_anthropic_research_model_stays_on_bedrock():
15
+ assert (
16
+ _get_research_model("bedrock/us.anthropic.claude-opus-4-6-v1")
17
+ == "bedrock/us.anthropic.claude-sonnet-4-6"
18
+ )
19
+
20
+
21
+ def test_non_anthropic_research_model_is_unchanged():
22
+ assert _get_research_model("openai/gpt-5.4") == "openai/gpt-5.4"
23
+
24
+
25
+ def test_subagent_display_does_not_spawn_background_redraw(monkeypatch):
26
+ calls: list[object] = []
27
+
28
+ def _unexpected_future(*args, **kwargs):
29
+ calls.append((args, kwargs))
30
+ raise AssertionError("background redraw task should not be created")
31
+
32
+ monkeypatch.setattr("asyncio.ensure_future", _unexpected_future)
33
+ monkeypatch.setattr(
34
+ terminal_display,
35
+ "_console",
36
+ SimpleNamespace(file=StringIO(), width=100),
37
+ )
38
+
39
+ mgr = terminal_display.SubAgentDisplayManager()
40
+ mgr.start("agent-1", "research")
41
+ mgr.add_call("agent-1", "▸ hf_papers {\"operation\": \"search\"}")
42
+ mgr.clear("agent-1")
43
+
44
+ assert calls == []
tests/unit/test_dangling_tool_calls.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Regression tests for `_patch_dangling_tool_calls`.
2
+
3
+ Reproduces the failure mode behind observatory sessions 8dd2ce30 and
4
+ 59c9e678 (2026-04-25): a tool call cancelled mid-execution leaves an
5
+ orphan ``tool_use`` in history; the user types a follow-up; Bedrock
6
+ rejects the next request with HTTP 400 ``messages.N: tool_use ids were
7
+ found without tool_result blocks immediately after``.
8
+ """
9
+
10
+ from litellm import ChatCompletionMessageToolCall, Message
11
+
12
+ from agent.context_manager.manager import ContextManager
13
+
14
+
15
+ def _tool_call(call_id: str, name: str = "research") -> ChatCompletionMessageToolCall:
16
+ return ChatCompletionMessageToolCall(
17
+ id=call_id,
18
+ type="function",
19
+ function={"name": name, "arguments": "{}"},
20
+ )
21
+
22
+
23
+ def _make_cm() -> ContextManager:
24
+ cm = ContextManager.__new__(ContextManager)
25
+ cm.system_prompt = "system"
26
+ cm.model_max_tokens = 100_000
27
+ cm.compact_size = 1_000
28
+ cm.running_context_usage = 0
29
+ cm.untouched_messages = 5
30
+ cm.items = [Message(role="system", content="system")]
31
+ return cm
32
+
33
+
34
+ def test_orphan_tool_use_followed_by_user_message_is_patched():
35
+ cm = _make_cm()
36
+ cm.items.extend([
37
+ Message(role="user", content="Research X"),
38
+ Message(
39
+ role="assistant",
40
+ content=None,
41
+ tool_calls=[_tool_call("call_abc", "research")],
42
+ ),
43
+ Message(role="user", content="??"),
44
+ ])
45
+ msgs = cm.get_messages()
46
+ tool_msgs = [m for m in msgs if getattr(m, "role", None) == "tool"]
47
+ assert len(tool_msgs) == 1
48
+ assert tool_msgs[0].tool_call_id == "call_abc"
49
+ assert "interrupted" in (tool_msgs[0].content or "").lower() or "not executed" in (tool_msgs[0].content or "").lower()
50
+
51
+
52
+ def test_no_orphan_means_no_stub():
53
+ cm = _make_cm()
54
+ cm.items.extend([
55
+ Message(role="user", content="Research X"),
56
+ Message(
57
+ role="assistant",
58
+ content=None,
59
+ tool_calls=[_tool_call("call_abc", "research")],
60
+ ),
61
+ Message(role="tool", content="ok", tool_call_id="call_abc", name="research"),
62
+ ])
63
+ cm.get_messages()
64
+ tool_msgs = [m for m in cm.items if getattr(m, "role", None) == "tool"]
65
+ assert len(tool_msgs) == 1
66
+ assert tool_msgs[0].content == "ok"
67
+
68
+
69
+ def test_multiple_dangling_tool_calls_in_one_assistant_message_are_all_patched():
70
+ cm = _make_cm()
71
+ cm.items.extend([
72
+ Message(role="user", content="do two things"),
73
+ Message(
74
+ role="assistant",
75
+ content=None,
76
+ tool_calls=[
77
+ _tool_call("call_1", "research"),
78
+ _tool_call("call_2", "bash"),
79
+ ],
80
+ ),
81
+ Message(role="user", content="follow up"),
82
+ ])
83
+ cm.get_messages()
84
+ tool_ids = {
85
+ getattr(m, "tool_call_id", None)
86
+ for m in cm.items
87
+ if getattr(m, "role", None) == "tool"
88
+ }
89
+ assert tool_ids == {"call_1", "call_2"}
90
+
91
+
92
+ def test_orphan_in_earlier_turn_still_gets_patched():
93
+ """Two-turn history where the FIRST turn was interrupted.
94
+
95
+ Old patcher stopped at the first user msg encountered while scanning
96
+ backwards, so this case never got fixed and Bedrock rejected.
97
+ """
98
+ cm = _make_cm()
99
+ cm.items.extend([
100
+ Message(role="user", content="turn 1"),
101
+ Message(
102
+ role="assistant",
103
+ content=None,
104
+ tool_calls=[_tool_call("call_old", "research")],
105
+ ),
106
+ Message(role="user", content="turn 2 — please retry"),
107
+ Message(
108
+ role="assistant",
109
+ content=None,
110
+ tool_calls=[_tool_call("call_new", "bash")],
111
+ ),
112
+ Message(role="tool", content="ok", tool_call_id="call_new", name="bash"),
113
+ ])
114
+ cm.get_messages()
115
+ tool_ids = {
116
+ getattr(m, "tool_call_id", None)
117
+ for m in cm.items
118
+ if getattr(m, "role", None) == "tool"
119
+ }
120
+ assert "call_old" in tool_ids
121
+ assert "call_new" in tool_ids
tests/unit/test_doom_loop_polling.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Regression test for doom-loop false-positive on legitimate polling.
2
+
3
+ Reproduces the failure mode in observatory sessions 40fcb414 ($32.59),
4
+ 8e90352e ($62.63), and 403178bf ($5.71) on 2026-04-25: the agent polled a
5
+ long-running job with `bash sleep 300 && wc -l output` four times in a
6
+ row. The arguments were byte-identical, but the results moved (27210 →
7
+ 36454 → 45770 → 55138 — actual progress). The detector hashed args only
8
+ and false-fired DOOM LOOP, which made the agent abandon perfectly valid
9
+ polling.
10
+
11
+ After the fix the signature includes the tool result hash, so identical
12
+ args + different results no longer trips the detector.
13
+ """
14
+
15
+ from litellm import ChatCompletionMessageToolCall, Message
16
+
17
+ from agent.core.doom_loop import check_for_doom_loop
18
+
19
+
20
+ def _assistant(call_id: str, name: str, args: str) -> Message:
21
+ return Message(
22
+ role="assistant",
23
+ content=None,
24
+ tool_calls=[
25
+ ChatCompletionMessageToolCall(
26
+ id=call_id,
27
+ type="function",
28
+ function={"name": name, "arguments": args},
29
+ )
30
+ ],
31
+ )
32
+
33
+
34
+ def _tool(call_id: str, name: str, content: str) -> Message:
35
+ return Message(role="tool", content=content, tool_call_id=call_id, name=name)
36
+
37
+
38
+ _POLL_ARGS = '{"command": "sleep 300 && ls /app/images/ | wc -l"}'
39
+
40
+
41
+ def test_polling_with_progressing_results_does_not_fire():
42
+ msgs = [
43
+ Message(role="user", content="run the job"),
44
+ _assistant("c1", "bash", _POLL_ARGS),
45
+ _tool("c1", "bash", "27210"),
46
+ _assistant("c2", "bash", _POLL_ARGS),
47
+ _tool("c2", "bash", "36454"),
48
+ _assistant("c3", "bash", _POLL_ARGS),
49
+ _tool("c3", "bash", "45770"),
50
+ _assistant("c4", "bash", _POLL_ARGS),
51
+ _tool("c4", "bash", "55138"),
52
+ ]
53
+ assert check_for_doom_loop(msgs) is None
54
+
55
+
56
+ def test_truly_stuck_polling_with_identical_results_still_fires():
57
+ """If the same poll returns the same number, the job is genuinely
58
+ stuck and the detector SHOULD fire."""
59
+ msgs = [
60
+ _assistant("c1", "bash", _POLL_ARGS),
61
+ _tool("c1", "bash", "55138"),
62
+ _assistant("c2", "bash", _POLL_ARGS),
63
+ _tool("c2", "bash", "55138"),
64
+ _assistant("c3", "bash", _POLL_ARGS),
65
+ _tool("c3", "bash", "55138"),
66
+ ]
67
+ prompt = check_for_doom_loop(msgs)
68
+ assert prompt is not None
69
+ assert "DOOM LOOP" in prompt
70
+ assert "bash" in prompt
71
+
72
+
73
+ def test_identical_calls_with_no_results_yet_still_fires():
74
+ """If three identical calls have no tool results (e.g. all cancelled
75
+ or errored before a result was recorded), treat as a real loop."""
76
+ msgs = [
77
+ _assistant("c1", "write", '{"path": "/tmp/x", "content": "..."}'),
78
+ _assistant("c2", "write", '{"path": "/tmp/x", "content": "..."}'),
79
+ _assistant("c3", "write", '{"path": "/tmp/x", "content": "..."}'),
80
+ ]
81
+ prompt = check_for_doom_loop(msgs)
82
+ assert prompt is not None
83
+ assert "DOOM LOOP" in prompt
84
+ assert "write" in prompt
85
+
86
+
87
+ def test_different_args_does_not_fire():
88
+ msgs = [
89
+ _assistant("c1", "bash", '{"command": "ls /a"}'),
90
+ _tool("c1", "bash", "ok"),
91
+ _assistant("c2", "bash", '{"command": "ls /b"}'),
92
+ _tool("c2", "bash", "ok"),
93
+ _assistant("c3", "bash", '{"command": "ls /c"}'),
94
+ _tool("c3", "bash", "ok"),
95
+ ]
96
+ assert check_for_doom_loop(msgs) is None
tests/unit/test_hf_access.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Unit tests for deriving HF Jobs access from the whoami payload."""

from agent.core.hf_access import jobs_access_from_whoami


def _whoami(name: str, plan: str, orgs: list[dict]) -> dict:
    # Minimal stand-in for the whoami response fields these tests rely on.
    return {"name": name, "plan": plan, "orgs": orgs}


def test_personal_pro_prefers_username_namespace():
    result = jobs_access_from_whoami(_whoami("alice", "pro", []))
    assert result.plan == "pro"
    assert result.eligible_namespaces == ["alice"]
    assert result.default_namespace == "alice"


def test_free_user_with_paid_org_uses_org_namespace():
    orgs = [
        {"name": "team-a", "plan": "team"},
        {"name": "oss-friends", "plan": "free"},
    ]
    result = jobs_access_from_whoami(_whoami("alice", "free", orgs))
    assert result.plan == "org"
    assert result.personal_can_run_jobs is False
    assert result.eligible_namespaces == ["team-a"]
    assert result.default_namespace is None


def test_free_user_without_paid_org_cannot_run_jobs():
    orgs = [{"name": "community", "plan": "free"}]
    result = jobs_access_from_whoami(_whoami("alice", "free", orgs))
    assert result.plan == "free"
    assert result.can_run_jobs is False
    assert result.eligible_namespaces == []
    assert result.default_namespace is None
tests/unit/test_llm_error_classification.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for LLM error classification helpers in agent.core.agent_loop.
2
+
3
+ Covers two regressions on 2026-04-25:
4
+
5
+ 1. Non-Anthropic context overflow (Kimi 365k > 262k) was not classified as
6
+ ``_is_context_overflow_error``, so the recovery path didn't fire and
7
+ session 62ccfdcb died with 68 wasted compaction events.
8
+
9
+ 2. Bedrock TPM rate limit (`Too many tokens, please wait before trying
10
+ again.`) needs the longer rate-limit retry schedule. The old schedule
11
+ ([5, 15, 30] = 50s) burned through 6 sessions costing >$2,400 combined
12
+ on the same day.
13
+ """
14
+
15
+ from agent.core.agent_loop import (
16
+ _MAX_LLM_RETRIES,
17
+ _LLM_RATE_LIMIT_RETRY_DELAYS,
18
+ _LLM_RETRY_DELAYS,
19
+ _is_context_overflow_error,
20
+ _is_rate_limit_error,
21
+ _is_transient_error,
22
+ _retry_delay_for,
23
+ )
24
+
25
+
26
+ # ── context overflow ────────────────────────────────────────────────────
27
+
28
+
29
+ def test_kimi_prompt_too_long_is_context_overflow():
30
+ # Verbatim error text from session 62ccfdcb (2026-04-25, Kimi K2.6).
31
+ err = Exception(
32
+ "litellm.BadRequestError: OpenAIException - The prompt is too long: "
33
+ "365407, model maximum context length: 262143"
34
+ )
35
+ assert _is_context_overflow_error(err)
36
+
37
+
38
+ def test_openai_context_length_exceeded_is_context_overflow():
39
+ err = Exception("Error: This model's maximum context length is 8192 tokens.")
40
+ assert _is_context_overflow_error(err)
41
+
42
+
43
+ def test_random_error_is_not_context_overflow():
44
+ err = Exception("connection reset by peer")
45
+ assert not _is_context_overflow_error(err)
46
+
47
+
48
+ # ── rate limit ──────────────────────────────────────────────────────────
49
+
50
+
51
+ def test_bedrock_too_many_tokens_is_rate_limit():
52
+ # Verbatim from sessions b37a3823, c4d7a831, b63c4933 (2026-04-25).
53
+ err = Exception(
54
+ 'litellm.RateLimitError: BedrockException - {"message":"Too many '
55
+ 'tokens, please wait before trying again."}'
56
+ )
57
+ assert _is_rate_limit_error(err)
58
+ # Rate-limit errors are also classified as transient.
59
+ assert _is_transient_error(err)
60
+
61
+
62
+ def test_429_is_rate_limit():
63
+ err = Exception("HTTP 429 Too Many Requests")
64
+ assert _is_rate_limit_error(err)
65
+
66
+
67
+ def test_timeout_is_transient_but_not_rate_limit():
68
+ err = Exception("Request timed out after 600s")
69
+ assert _is_transient_error(err)
70
+ assert not _is_rate_limit_error(err)
71
+
72
+
73
+ # ── retry schedule selection ────────────────────────────────────────────
74
+
75
+
76
+ def test_rate_limit_uses_longer_schedule():
77
+ err = Exception("Too many tokens, please wait before trying again.")
78
+ delays = [_retry_delay_for(err, i) for i in range(len(_LLM_RATE_LIMIT_RETRY_DELAYS))]
79
+ assert delays == _LLM_RATE_LIMIT_RETRY_DELAYS
80
+ # Just past the schedule → None (stop retrying).
81
+ assert _retry_delay_for(err, len(_LLM_RATE_LIMIT_RETRY_DELAYS)) is None
82
+
83
+
84
+ def test_other_transient_uses_short_schedule():
85
+ err = Exception("503 service unavailable")
86
+ delays = [_retry_delay_for(err, i) for i in range(len(_LLM_RETRY_DELAYS))]
87
+ assert delays == _LLM_RETRY_DELAYS
88
+ assert _retry_delay_for(err, len(_LLM_RETRY_DELAYS)) is None
89
+
90
+
91
+ def test_non_transient_returns_none():
92
+ err = Exception("invalid request: bad parameter")
93
+ assert _retry_delay_for(err, 0) is None
94
+
95
+
96
+ def test_rate_limit_total_budget_covers_bedrock_bucket_recovery():
97
+ """The whole point of the rate-limit schedule: total wait time should
98
+ exceed the ~60s Bedrock TPM bucket recovery window."""
99
+ assert len(_LLM_RATE_LIMIT_RETRY_DELAYS) == _MAX_LLM_RETRIES - 1
100
+ assert sum(_LLM_RATE_LIMIT_RETRY_DELAYS) > 60
tests/unit/test_llm_params.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Tests for reasoning-effort handling in `_resolve_llm_params`."""

from agent.core.llm_params import UnsupportedEffortError, _resolve_llm_params


def test_openai_xhigh_effort_is_forwarded():
    resolved = _resolve_llm_params(
        "openai/gpt-5.5",
        reasoning_effort="xhigh",
        strict=True,
    )
    assert resolved["model"] == "openai/gpt-5.5"
    assert resolved["reasoning_effort"] == "xhigh"


def test_openai_max_effort_is_still_rejected():
    caught: UnsupportedEffortError | None = None
    try:
        _resolve_llm_params(
            "openai/gpt-5.4",
            reasoning_effort="max",
            strict=True,
        )
    except UnsupportedEffortError as exc:
        caught = exc
    assert caught is not None, "Expected UnsupportedEffortError for max effort"
    assert "OpenAI doesn't accept effort='max'" in str(caught)
tests/unit/test_malformed_args_recovery.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Regression test for the malformed-JSON loop in observatory session
2
+ 7750e82f (2026-04-25): GLM-5.1 produced six consecutive ``write`` calls
3
+ whose ``arguments`` strings JSON-parse-failed (truncated mid-stream by
4
+ the provider). The soft retry hint didn't move the model. The detector
5
+ in ``_detect_repeated_malformed`` looks for the streak so the agent loop
6
+ can inject a hard system-prompt forcing a different strategy.
7
+ """
8
+
9
+ from litellm import Message
10
+
11
+ from agent.core.agent_loop import _detect_repeated_malformed
12
+
13
+
14
+ def _malformed_tool_msg(name: str, call_id: str) -> Message:
15
+ return Message(
16
+ role="tool",
17
+ content=(
18
+ f"ERROR: Tool call to '{name}' had malformed JSON arguments and "
19
+ f"was NOT executed. Retry with smaller content — for 'write', "
20
+ f"split into multiple smaller writes using 'edit'."
21
+ ),
22
+ tool_call_id=call_id,
23
+ name=name,
24
+ )
25
+
26
+
27
+ def test_two_consecutive_malformed_same_tool_triggers():
28
+ items = [
29
+ Message(role="user", content="write a big plan"),
30
+ Message(role="assistant", content=None),
31
+ _malformed_tool_msg("write", "1"),
32
+ Message(role="assistant", content=None),
33
+ _malformed_tool_msg("write", "2"),
34
+ ]
35
+ assert _detect_repeated_malformed(items, threshold=2) == "write"
36
+
37
+
38
+ def test_one_malformed_does_not_trigger():
39
+ items = [
40
+ Message(role="user", content="write a plan"),
41
+ Message(role="assistant", content=None),
42
+ _malformed_tool_msg("write", "1"),
43
+ ]
44
+ assert _detect_repeated_malformed(items, threshold=2) is None
45
+
46
+
47
+ def test_two_malformed_different_tools_does_not_trigger():
48
+ items = [
49
+ Message(role="assistant", content=None),
50
+ _malformed_tool_msg("write", "1"),
51
+ Message(role="assistant", content=None),
52
+ _malformed_tool_msg("bash", "2"),
53
+ ]
54
+ assert _detect_repeated_malformed(items, threshold=2) is None
55
+
56
+
57
+ def test_streak_broken_by_successful_tool_call_does_not_trigger():
58
+ items = [
59
+ Message(role="assistant", content=None),
60
+ _malformed_tool_msg("write", "1"),
61
+ Message(role="assistant", content=None),
62
+ Message(role="tool", content="ok", tool_call_id="2", name="write"),
63
+ Message(role="assistant", content=None),
64
+ _malformed_tool_msg("write", "3"),
65
+ ]
66
+ assert _detect_repeated_malformed(items, threshold=2) is None
tests/unit/test_sandbox_already_active_message.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Regression test for sandbox_create not surfacing the hardware lockout.
2
+
3
+ In observatory session d6f8454c (2026-04-25) the agent called
4
+ sandbox_create 18 times across 11 distinct hardware tiers (a10g-large,
5
+ a100-large, t4-small, cpu-upgrade, cpu-basic, zero-a10g, l4x1, t4-medium,
6
+ a10g-small, l40sx1, …). Every call returned 'Sandbox already active' for
7
+ the same sandbox, but the message did not say that hardware can't be
8
+ changed by re-calling, so the agent thought "still pending, retry with a
9
+ different flavor" and burned 17 useless turns.
10
+
11
+ The fix makes the response explicit when the requested hardware differs
12
+ from what's already active.
13
+ """
14
+
15
+ import asyncio
16
+ from types import SimpleNamespace
17
+
18
+ from agent.tools.sandbox_tool import sandbox_create_handler
19
+
20
+
21
+ def _session_with_sandbox():
22
+ sb = SimpleNamespace(
23
+ space_id="user/sandbox-abc123",
24
+ url="https://huggingface.co/spaces/user/sandbox-abc123",
25
+ )
26
+ return SimpleNamespace(sandbox=sb)
27
+
28
+
29
+ def test_already_active_with_different_hw_warns_about_lockout():
30
+ session = _session_with_sandbox()
31
+ out, ok = asyncio.run(
32
+ sandbox_create_handler({"hardware": "a100-large"}, session=session)
33
+ )
34
+ assert ok is True
35
+ # The message should mention the lockout AND the requested flavor.
36
+ assert "cannot be changed" in out.lower()
37
+ assert "a100-large" in out
38
+ assert "delete" in out.lower()
39
+
40
+
41
+ def test_already_active_no_hw_request_just_returns_handle():
42
+ session = _session_with_sandbox()
43
+ out, ok = asyncio.run(sandbox_create_handler({}, session=session))
44
+ assert ok is True
45
+ assert "user/sandbox-abc123" in out
46
+ # No spurious lockout note when the agent didn't request a flavor.
47
+ assert "cannot be changed" not in out.lower()
uv.lock CHANGED
@@ -228,6 +228,18 @@ wheels = [
228
  { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" },
229
  ]
230
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  [[package]]
232
  name = "attrs"
233
  version = "25.4.0"
@@ -992,78 +1004,6 @@ wheels = [
992
  { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
993
  ]
994
 
995
- [[package]]
996
- name = "hf-agent"
997
- version = "0.1.0"
998
- source = { editable = "." }
999
- dependencies = [
1000
- { name = "boto3" },
1001
- { name = "datasets" },
1002
- { name = "fastapi" },
1003
- { name = "fastmcp" },
1004
- { name = "httpx" },
1005
- { name = "huggingface-hub" },
1006
- { name = "litellm" },
1007
- { name = "nbconvert" },
1008
- { name = "nbformat" },
1009
- { name = "prompt-toolkit" },
1010
- { name = "pydantic" },
1011
- { name = "python-dotenv" },
1012
- { name = "requests" },
1013
- { name = "rich" },
1014
- { name = "thefuzz" },
1015
- { name = "uvicorn", extra = ["standard"] },
1016
- { name = "websockets" },
1017
- { name = "whoosh" },
1018
- ]
1019
-
1020
- [package.optional-dependencies]
1021
- all = [
1022
- { name = "datasets" },
1023
- { name = "inspect-ai" },
1024
- { name = "pandas" },
1025
- { name = "pytest" },
1026
- { name = "tenacity" },
1027
- ]
1028
- dev = [
1029
- { name = "pytest" },
1030
- ]
1031
- eval = [
1032
- { name = "datasets" },
1033
- { name = "inspect-ai" },
1034
- { name = "pandas" },
1035
- { name = "tenacity" },
1036
- ]
1037
-
1038
- [package.metadata]
1039
- requires-dist = [
1040
- { name = "boto3", specifier = ">=1.35.0" },
1041
- { name = "datasets", specifier = ">=4.4.1" },
1042
- { name = "datasets", marker = "extra == 'eval'", specifier = ">=4.3.0" },
1043
- { name = "fastapi", specifier = ">=0.115.0" },
1044
- { name = "fastmcp", specifier = ">=3.2.0" },
1045
- { name = "hf-agent", extras = ["eval", "dev"], marker = "extra == 'all'" },
1046
- { name = "httpx", specifier = ">=0.27.0" },
1047
- { name = "huggingface-hub", specifier = ">=1.0.1" },
1048
- { name = "inspect-ai", marker = "extra == 'eval'", specifier = ">=0.3.149" },
1049
- { name = "litellm", specifier = ">=1.83.0" },
1050
- { name = "nbconvert", specifier = ">=7.16.6" },
1051
- { name = "nbformat", specifier = ">=5.10.4" },
1052
- { name = "pandas", marker = "extra == 'eval'", specifier = ">=2.3.3" },
1053
- { name = "prompt-toolkit", specifier = ">=3.0.0" },
1054
- { name = "pydantic", specifier = ">=2.12.3" },
1055
- { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" },
1056
- { name = "python-dotenv", specifier = ">=1.2.1" },
1057
- { name = "requests", specifier = ">=2.33.0" },
1058
- { name = "rich", specifier = ">=13.0.0" },
1059
- { name = "tenacity", marker = "extra == 'eval'", specifier = ">=8.0.0" },
1060
- { name = "thefuzz", specifier = ">=0.22.1" },
1061
- { name = "uvicorn", extras = ["standard"], specifier = ">=0.32.0" },
1062
- { name = "websockets", specifier = ">=13.0" },
1063
- { name = "whoosh", specifier = ">=2.7.4" },
1064
- ]
1065
- provides-extras = ["eval", "dev", "all"]
1066
-
1067
  [[package]]
1068
  name = "hf-xet"
1069
  version = "1.2.0"
@@ -1827,6 +1767,83 @@ wheels = [
1827
  { url = "https://files.pythonhosted.org/packages/9b/f7/4a5e785ec9fbd65146a27b6b70b6cdc161a66f2024e4b04ac06a67f5578b/mistune-3.2.0-py3-none-any.whl", hash = "sha256:febdc629a3c78616b94393c6580551e0e34cc289987ec6c35ed3f4be42d0eee1", size = 53598, upload-time = "2025-12-23T11:36:33.211Z" },
1828
  ]
1829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1830
  [[package]]
1831
  name = "mmh3"
1832
  version = "5.2.0"
@@ -2775,6 +2792,19 @@ wheels = [
2775
  { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
2776
  ]
2777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2778
  [[package]]
2779
  name = "python-dateutil"
2780
  version = "2.9.0.post0"
@@ -3619,6 +3649,18 @@ wheels = [
3619
  { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" },
3620
  ]
3621
 
 
 
 
 
 
 
 
 
 
 
 
 
3622
  [[package]]
3623
  name = "uc-micro-py"
3624
  version = "1.0.3"
 
228
  { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" },
229
  ]
230
 
231
+ [[package]]
232
+ name = "apscheduler"
233
+ version = "3.11.2"
234
+ source = { registry = "https://pypi.org/simple" }
235
+ dependencies = [
236
+ { name = "tzlocal" },
237
+ ]
238
+ sdist = { url = "https://files.pythonhosted.org/packages/07/12/3e4389e5920b4c1763390c6d371162f3784f86f85cd6d6c1bfe68eef14e2/apscheduler-3.11.2.tar.gz", hash = "sha256:2a9966b052ec805f020c8c4c3ae6e6a06e24b1bf19f2e11d91d8cca0473eef41", size = 108683, upload-time = "2025-12-22T00:39:34.884Z" }
239
+ wheels = [
240
+ { url = "https://files.pythonhosted.org/packages/9f/64/2e54428beba8d9992aa478bb8f6de9e4ecaa5f8f513bcfd567ed7fb0262d/apscheduler-3.11.2-py3-none-any.whl", hash = "sha256:ce005177f741409db4e4dd40a7431b76feb856b9dd69d57e0da49d6715bfd26d", size = 64439, upload-time = "2025-12-22T00:39:33.303Z" },
241
+ ]
242
+
243
  [[package]]
244
  name = "attrs"
245
  version = "25.4.0"
 
1004
  { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
1005
  ]
1006
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1007
  [[package]]
1008
  name = "hf-xet"
1009
  version = "1.2.0"
 
1767
  { url = "https://files.pythonhosted.org/packages/9b/f7/4a5e785ec9fbd65146a27b6b70b6cdc161a66f2024e4b04ac06a67f5578b/mistune-3.2.0-py3-none-any.whl", hash = "sha256:febdc629a3c78616b94393c6580551e0e34cc289987ec6c35ed3f4be42d0eee1", size = 53598, upload-time = "2025-12-23T11:36:33.211Z" },
1768
  ]
1769
 
1770
+ [[package]]
1771
+ name = "ml-intern"
1772
+ version = "0.1.0"
1773
+ source = { editable = "." }
1774
+ dependencies = [
1775
+ { name = "apscheduler" },
1776
+ { name = "boto3" },
1777
+ { name = "datasets" },
1778
+ { name = "fastapi" },
1779
+ { name = "fastmcp" },
1780
+ { name = "httpx" },
1781
+ { name = "huggingface-hub" },
1782
+ { name = "litellm" },
1783
+ { name = "nbconvert" },
1784
+ { name = "nbformat" },
1785
+ { name = "prompt-toolkit" },
1786
+ { name = "pydantic" },
1787
+ { name = "python-dotenv" },
1788
+ { name = "requests" },
1789
+ { name = "rich" },
1790
+ { name = "thefuzz" },
1791
+ { name = "uvicorn", extra = ["standard"] },
1792
+ { name = "websockets" },
1793
+ { name = "whoosh" },
1794
+ ]
1795
+
1796
+ [package.optional-dependencies]
1797
+ all = [
1798
+ { name = "datasets" },
1799
+ { name = "inspect-ai" },
1800
+ { name = "pandas" },
1801
+ { name = "pytest" },
1802
+ { name = "pytest-asyncio" },
1803
+ { name = "tenacity" },
1804
+ ]
1805
+ dev = [
1806
+ { name = "pytest" },
1807
+ { name = "pytest-asyncio" },
1808
+ ]
1809
+ eval = [
1810
+ { name = "datasets" },
1811
+ { name = "inspect-ai" },
1812
+ { name = "pandas" },
1813
+ { name = "tenacity" },
1814
+ ]
1815
+
1816
+ [package.metadata]
1817
+ requires-dist = [
1818
+ { name = "apscheduler", specifier = ">=3.10,<4" },
1819
+ { name = "boto3", specifier = ">=1.35.0" },
1820
+ { name = "datasets", specifier = ">=4.4.1" },
1821
+ { name = "datasets", marker = "extra == 'eval'", specifier = ">=4.3.0" },
1822
+ { name = "fastapi", specifier = ">=0.115.0" },
1823
+ { name = "fastmcp", specifier = ">=3.2.0" },
1824
+ { name = "httpx", specifier = ">=0.27.0" },
1825
+ { name = "huggingface-hub", specifier = ">=1.0.1" },
1826
+ { name = "inspect-ai", marker = "extra == 'eval'", specifier = ">=0.3.149" },
1827
+ { name = "litellm", specifier = ">=1.83.0" },
1828
+ { name = "ml-intern", extras = ["eval", "dev"], marker = "extra == 'all'" },
1829
+ { name = "nbconvert", specifier = ">=7.16.6" },
1830
+ { name = "nbformat", specifier = ">=5.10.4" },
1831
+ { name = "pandas", marker = "extra == 'eval'", specifier = ">=2.3.3" },
1832
+ { name = "prompt-toolkit", specifier = ">=3.0.0" },
1833
+ { name = "pydantic", specifier = ">=2.12.3" },
1834
+ { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" },
1835
+ { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.26.0" },
1836
+ { name = "python-dotenv", specifier = ">=1.2.1" },
1837
+ { name = "requests", specifier = ">=2.33.0" },
1838
+ { name = "rich", specifier = ">=13.0.0" },
1839
+ { name = "tenacity", marker = "extra == 'eval'", specifier = ">=8.0.0" },
1840
+ { name = "thefuzz", specifier = ">=0.22.1" },
1841
+ { name = "uvicorn", extras = ["standard"], specifier = ">=0.32.0" },
1842
+ { name = "websockets", specifier = ">=13.0" },
1843
+ { name = "whoosh", specifier = ">=2.7.4" },
1844
+ ]
1845
+ provides-extras = ["eval", "dev", "all"]
1846
+
1847
  [[package]]
1848
  name = "mmh3"
1849
  version = "5.2.0"
 
2792
  { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
2793
  ]
2794
 
2795
+ [[package]]
2796
+ name = "pytest-asyncio"
2797
+ version = "1.3.0"
2798
+ source = { registry = "https://pypi.org/simple" }
2799
+ dependencies = [
2800
+ { name = "pytest" },
2801
+ { name = "typing-extensions", marker = "python_full_version < '3.13'" },
2802
+ ]
2803
+ sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" }
2804
+ wheels = [
2805
+ { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" },
2806
+ ]
2807
+
2808
  [[package]]
2809
  name = "python-dateutil"
2810
  version = "2.9.0.post0"
 
3649
  { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" },
3650
  ]
3651
 
3652
+ [[package]]
3653
+ name = "tzlocal"
3654
+ version = "5.3.1"
3655
+ source = { registry = "https://pypi.org/simple" }
3656
+ dependencies = [
3657
+ { name = "tzdata", marker = "sys_platform == 'win32'" },
3658
+ ]
3659
+ sdist = { url = "https://files.pythonhosted.org/packages/8b/2e/c14812d3d4d9cd1773c6be938f89e5735a1f11a9f184ac3639b93cef35d5/tzlocal-5.3.1.tar.gz", hash = "sha256:cceffc7edecefea1f595541dbd6e990cb1ea3d19bf01b2809f362a03dd7921fd", size = 30761, upload-time = "2025-03-05T21:17:41.549Z" }
3660
+ wheels = [
3661
+ { url = "https://files.pythonhosted.org/packages/c2/14/e2a54fabd4f08cd7af1c07030603c3356b74da07f7cc056e600436edfa17/tzlocal-5.3.1-py3-none-any.whl", hash = "sha256:eb1a66c3ef5847adf7a834f1be0800581b683b5608e74f86ecbcef8ab91bb85d", size = 18026, upload-time = "2025-03-05T21:17:39.857Z" },
3662
+ ]
3663
+
3664
  [[package]]
3665
  name = "uc-micro-py"
3666
  version = "1.0.3"