Aksel Joonas Reedi committed
Commit 2a2e170 · unverified · 1 Parent(s): b292d83

feat(observability)

agent/config.py CHANGED
@@ -24,8 +24,13 @@ class Config(BaseModel):
     model_name: str
     mcpServers: dict[str, MCPServerConfig] = {}
     save_sessions: bool = True
-    session_dataset_repo: str = "akseljoonas/hf-agent-sessions"
-    auto_save_interval: int = 3  # Save every N user turns (0 = disabled)
+    session_dataset_repo: str = "smolagents/ml-intern-sessions"
+    auto_save_interval: int = 1  # Save every N user turns (0 = disabled)
+    # Mid-turn heartbeat: save + upload every N seconds while events are being
+    # emitted. Guards against losing trace data on long-running turns that
+    # crash before turn_complete (e.g. a multi-hour hf_jobs wait that OOMs).
+    # 0 = disabled. Consumed by agent.core.telemetry.HeartbeatSaver.
+    heartbeat_interval_s: int = 60
     yolo_mode: bool = False  # Auto-approve all tool calls without confirmation
     max_iterations: int = 300  # Max LLM calls per agent turn (-1 = unlimited)
 
agent/core/agent_loop.py CHANGED
@@ -6,12 +6,14 @@ import asyncio
 import json
 import logging
 import os
-from dataclasses import dataclass
+import time
+from dataclasses import dataclass, field

 from litellm import ChatCompletionMessageToolCall, Message, acompletion
 from litellm.exceptions import ContextWindowExceededError

 from agent.config import Config
+from agent.core import telemetry
 from agent.core.doom_loop import check_for_doom_loop
 from agent.core.llm_params import _resolve_llm_params
 from agent.core.prompt_caching import with_prompt_caching
@@ -291,6 +293,7 @@ class LLMResult:
     tool_calls_acc: dict[int, dict]
     token_count: int
     finish_reason: str | None
+    usage: dict = field(default_factory=dict)


 async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
@@ -298,6 +301,7 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
     response = None
     _healed_effort = False  # one-shot safety net per call
     messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
+    t_start = time.monotonic()
     for _llm_attempt in range(_MAX_LLM_RETRIES):
         try:
             response = await acompletion(
@@ -339,6 +343,7 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
     tool_calls_acc: dict[int, dict] = {}
     token_count = 0
     finish_reason = None
+    final_usage_chunk = None

     async for chunk in response:
         if session.is_cancelled:
@@ -349,6 +354,7 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
         if not choice:
             if hasattr(chunk, "usage") and chunk.usage:
                 token_count = chunk.usage.total_tokens
+                final_usage_chunk = chunk
             continue

         delta = choice.delta
@@ -379,12 +385,22 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->

         if hasattr(chunk, "usage") and chunk.usage:
             token_count = chunk.usage.total_tokens
+            final_usage_chunk = chunk
+
+    usage = await telemetry.record_llm_call(
+        session,
+        model=llm_params.get("model", session.config.model_name),
+        response=final_usage_chunk,
+        latency_ms=int((time.monotonic() - t_start) * 1000),
+        finish_reason=finish_reason,
+    )

     return LLMResult(
         content=full_content or None,
         tool_calls_acc=tool_calls_acc,
         token_count=token_count,
         finish_reason=finish_reason,
+        usage=usage,
     )


@@ -393,6 +409,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
     response = None
     _healed_effort = False
     messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
+    t_start = time.monotonic()
     for _llm_attempt in range(_MAX_LLM_RETRIES):
         try:
             response = await acompletion(
@@ -454,11 +471,20 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
             Event(event_type="assistant_message", data={"content": content})
         )

+    usage = await telemetry.record_llm_call(
+        session,
+        model=llm_params.get("model", session.config.model_name),
+        response=response,
+        latency_ms=int((time.monotonic() - t_start) * 1000),
+        finish_reason=finish_reason,
+    )
+
    return LLMResult(
        content=content,
        tool_calls_acc=tool_calls_acc,
        token_count=token_count,
        finish_reason=finish_reason,
+        usage=usage,
    )

agent/core/redact.py ADDED
@@ -0,0 +1,68 @@
+"""Secret scrubbing for session trajectories before upload.
+
+Users frequently paste HF / API / GitHub tokens into the chat, or scripts echo
+them via env dumps. This module applies regex-based redaction to any string
+value found recursively in a trajectory payload. The goal is best-effort —
+strict formats are matched; we won't catch free-form leaks like "my password
+is hunter2".
+"""
+
+from __future__ import annotations
+
+import re
+from typing import Any
+
+# Each entry: (compiled regex, replacement placeholder).
+# Patterns are conservative: they only match tokens with the canonical prefix
+# and a minimum body length so we don't paint over normal text.
+_PATTERNS: list[tuple[re.Pattern, str]] = [
+    # Hugging Face tokens: hf_[A-Za-z0-9]{30,}
+    (re.compile(r"hf_[A-Za-z0-9]{30,}"), "[REDACTED_HF_TOKEN]"),
+    # Anthropic: sk-ant-[A-Za-z0-9_\-]{20,}
+    (re.compile(r"sk-ant-[A-Za-z0-9_\-]{20,}"), "[REDACTED_ANTHROPIC_KEY]"),
+    # OpenAI: sk-[A-Za-z0-9]{40,} (legacy + proj keys)
+    (re.compile(r"sk-(?!ant-)[A-Za-z0-9_\-]{40,}"), "[REDACTED_OPENAI_KEY]"),
+    # GitHub classic PATs: ghp_, gho_, ghu_, ghs_, ghr_ followed by 36+ chars
+    (re.compile(r"gh[pousr]_[A-Za-z0-9]{36,}"), "[REDACTED_GITHUB_TOKEN]"),
+    # GitHub fine-grained PATs: github_pat_<alphanumeric_underscore>
+    (re.compile(r"github_pat_[A-Za-z0-9_]{36,}"), "[REDACTED_GITHUB_TOKEN]"),
+    # AWS access key IDs: AKIA / ASIA + 16 uppercase alnum
+    (re.compile(r"\b(?:AKIA|ASIA)[A-Z0-9]{16}\b"), "[REDACTED_AWS_KEY_ID]"),
+    # Generic 'Bearer <token>' header values
+    (re.compile(r"(?i)bearer\s+[A-Za-z0-9_\-\.=]{20,}"), "Bearer [REDACTED]"),
+]
+
+# Env-var-like exports: we scrub the value but keep the name so callers can
+# still see which secret was referenced. Covers `KEY=value` and `KEY: value`
+# when the key looks secret-y.
+_SECRETY_NAMES = re.compile(
+    r"(?i)\b(HF_TOKEN|HUGGINGFACEHUB_API_TOKEN|ANTHROPIC_API_KEY|OPENAI_API_KEY|"
+    r"GITHUB_TOKEN|AWS_SECRET_ACCESS_KEY|AWS_ACCESS_KEY_ID|PASSWORD|SECRET|API_KEY)"
+    r"\s*[:=]\s*([^\s\"']+)"
+)
+
+
+def scrub_string(s: str) -> str:
+    """Apply all redaction patterns to a single string. Safe on non-strings."""
+    if not isinstance(s, str) or not s:
+        return s
+    out = s
+    for pat, repl in _PATTERNS:
+        out = pat.sub(repl, out)
+    out = _SECRETY_NAMES.sub(lambda m: f"{m.group(1)}=[REDACTED]", out)
+    return out
+
+
+def scrub(obj: Any) -> Any:
+    """Recursively scrub every string value in a nested dict/list structure.
+
+    Returns a new object — inputs are not mutated."""
+    if isinstance(obj, str):
+        return scrub_string(obj)
+    if isinstance(obj, dict):
+        return {k: scrub(v) for k, v in obj.items()}
+    if isinstance(obj, list):
+        return [scrub(v) for v in obj]
+    if isinstance(obj, tuple):
+        return tuple(scrub(v) for v in obj)
+    return obj
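
The new module is import-light, so it can be exercised directly. A minimal usage sketch (illustrative only, not part of the commit; the sample payload is invented):

    from agent.core.redact import scrub, scrub_string

    # Single strings: matched tokens are replaced with a placeholder.
    print(scrub_string("my key is hf_" + "A" * 32))
    # -> "my key is [REDACTED_HF_TOKEN]"

    # Nested payloads: scrub() walks dicts/lists/tuples and returns a new object.
    payload = {
        "messages": [{"role": "user", "content": "export OPENAI_API_KEY=sk-test123"}],
        "events": [("note", "Authorization: Bearer " + "t" * 24)],
    }
    clean = scrub(payload)
    # clean["messages"][0]["content"] == "export OPENAI_API_KEY=[REDACTED]"
    # clean["events"][0][1]           == "Authorization: Bearer [REDACTED]"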
agent/core/session.py CHANGED
@@ -108,6 +108,11 @@ class Session:
         self.session_start_time = datetime.now().isoformat()
         self.turn_count: int = 0
         self.last_auto_save_turn: int = 0
+        # Stable local save path so heartbeat saves overwrite one file instead
+        # of spamming session_logs/. ``_last_heartbeat_ts`` is owned by
+        # ``agent.core.telemetry.HeartbeatSaver`` and lazily initialised there.
+        self._local_save_path: Optional[str] = None
+        self._last_heartbeat_ts: Optional[float] = None

         # Per-model probed reasoning-effort cache. Populated by the probe
         # on /model switch, read by ``effective_effort_for`` below. Keys are
@@ -132,6 +137,10 @@
             }
         )

+        # Mid-turn heartbeat flush (owned by telemetry module).
+        from agent.core.telemetry import HeartbeatSaver
+        HeartbeatSaver.maybe_fire(self)
+
     def cancel(self) -> None:
         """Signal cancellation to the running agent loop."""
         self._cancelled.set()
@@ -184,6 +193,12 @@

     def get_trajectory(self) -> dict:
         """Serialize complete session trajectory for logging"""
+        tools: list = []
+        if self.tool_router is not None:
+            try:
+                tools = self.tool_router.get_tool_specs_for_llm() or []
+            except Exception:
+                tools = []
         return {
             "session_id": self.session_id,
             "session_start_time": self.session_start_time,
@@ -191,6 +206,7 @@
             "model_name": self.config.model_name,
             "messages": [msg.model_dump() for msg in self.context_manager.items],
             "events": self.logged_events,
+            "tools": tools,
         }

     def save_trajectory_local(
@@ -216,16 +232,42 @@

             trajectory = self.get_trajectory()

+            # Scrub secrets at save time so session_logs/ never holds raw
+            # tokens on disk — a log aggregator, crash dump, or filesystem
+            # snapshot between heartbeats would otherwise leak them.
+            try:
+                from agent.core.redact import scrub
+                for key in ("messages", "events", "tools"):
+                    if key in trajectory:
+                        trajectory[key] = scrub(trajectory[key])
+            except Exception as _e:
+                logger.debug("Redact-on-save failed (non-fatal): %s", _e)
+
             # Add upload metadata
             trajectory["upload_status"] = upload_status
             trajectory["upload_url"] = dataset_url
             trajectory["last_save_time"] = datetime.now().isoformat()

-            filename = f"session_{self.session_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
-            filepath = log_dir / filename
-
-            with open(filepath, "w") as f:
+            # Reuse one stable path per session so heartbeat saves overwrite
+            # the same file instead of creating a new timestamped file every
+            # minute. The timestamp in the filename is kept for first-save
+            # ordering; subsequent saves just rewrite that file.
+            if self._local_save_path and Path(self._local_save_path).parent == log_dir:
+                filepath = Path(self._local_save_path)
+            else:
+                filename = (
+                    f"session_{self.session_id}_"
+                    f"{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+                )
+                filepath = log_dir / filename
+                self._local_save_path = str(filepath)
+
+            # Atomic-ish write: stage to .tmp then rename so a crash mid-write
+            # doesn't leave a truncated JSON that breaks the retry scanner.
+            tmp_path = filepath.with_suffix(filepath.suffix + ".tmp")
+            with open(tmp_path, "w") as f:
                 json.dump(trajectory, f, indent=2)
+            tmp_path.replace(filepath)

             return str(filepath)
         except Exception as e:
agent/core/session_uploader.py CHANGED
@@ -15,8 +15,15 @@ from dotenv import load_dotenv

 load_dotenv()

-# Token for session uploads loaded from env var (never hardcode tokens in source)
-_SESSION_TOKEN = os.environ.get("HF_SESSION_UPLOAD_TOKEN", "")
+# Token for session uploads. Fallback chain (least-privilege first) matches
+# backend/kpis_scheduler.py so one write-scoped token on the Space covers every
+# telemetry dataset. Never hardcode tokens in source.
+_SESSION_TOKEN = (
+    os.environ.get("HF_SESSION_UPLOAD_TOKEN")
+    or os.environ.get("HF_TOKEN")
+    or os.environ.get("HF_ADMIN_TOKEN")
+    or ""
+)


 def upload_session_as_file(
@@ -58,15 +65,37 @@
            json.dump(data, f, indent=2)
        return False

+    # Scrub secrets (HF tokens, API keys, etc.) from messages + events
+    # before they leave the local disk. Best-effort regex-based redaction —
+    # see agent/core/redact.py for the patterns covered.
+    try:
+        from agent.core.redact import scrub  # type: ignore
+    except Exception:
+        # Fallback for environments where the agent package isn't importable
+        # (shouldn't happen in our subprocess, but be defensive).
+        import importlib.util
+        _spec = importlib.util.spec_from_file_location(
+            "_redact",
+            Path(__file__).parent / "redact.py",
+        )
+        _mod = importlib.util.module_from_spec(_spec)
+        _spec.loader.exec_module(_mod)  # type: ignore
+        scrub = _mod.scrub
+    scrubbed_messages = scrub(data["messages"])
+    scrubbed_events = scrub(data["events"])
+    scrubbed_tools = scrub(data.get("tools") or [])
+
     # Prepare JSONL content (single line)
-    # Store messages and events as JSON strings to avoid schema conflicts
+    # Store messages/events/tools as JSON strings to avoid schema conflicts
+    # across sessions with different tool rosters.
     session_row = {
         "session_id": data["session_id"],
         "session_start_time": data["session_start_time"],
         "session_end_time": data["session_end_time"],
         "model_name": data["model_name"],
-        "messages": json.dumps(data["messages"]),
-        "events": json.dumps(data["events"]),
+        "messages": json.dumps(scrubbed_messages),
+        "events": json.dumps(scrubbed_events),
+        "tools": json.dumps(scrubbed_tools),
     }

     # Create temporary JSONL file
agent/core/telemetry.py ADDED
@@ -0,0 +1,289 @@
1
+ """All agent observability in one module.
2
+
3
+ Every telemetry signal the agent emits — LLM-call usage / cost, hf_jobs
4
+ lifecycle, sandbox lifecycle, user feedback, mid-turn heartbeat saves — is
5
+ defined here so business-logic files stay free of instrumentation noise.
6
+
7
+ Callsites are one-liners::
8
+
9
+ await telemetry.record_llm_call(session, model=..., response=r, ...)
10
+ await telemetry.record_hf_job_submit(session, job, args, image=..., job_type="Python")
11
+ HeartbeatSaver.maybe_fire(session)
12
+
13
+ All ``record_*`` functions emit a single ``Event`` via ``session.send_event``
14
+ and never raise — telemetry is best-effort and must not break the agent.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import asyncio
20
+ import logging
21
+ import time
22
+ from typing import Any
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ # ── usage extraction ────────────────────────────────────────────────────────
28
+
29
+ def extract_usage(response_or_chunk: Any) -> dict:
30
+ """Flat usage dict from a litellm response or final-chunk usage object.
31
+
32
+ Normalizes across providers: Anthropic exposes cache tokens as
33
+ ``cache_read_input_tokens`` / ``cache_creation_input_tokens``; OpenAI uses
34
+ ``prompt_tokens_details.cached_tokens``. Exposed under the stable keys
35
+ ``cache_read_tokens`` / ``cache_creation_tokens``.
36
+ """
37
+ u = getattr(response_or_chunk, "usage", None)
38
+ if u is None and isinstance(response_or_chunk, dict):
39
+ u = response_or_chunk.get("usage")
40
+ if u is None:
41
+ return {}
42
+
43
+ def _g(name, default=0):
44
+ if isinstance(u, dict):
45
+ return u.get(name, default) or default
46
+ return getattr(u, name, default) or default
47
+
48
+ prompt = _g("prompt_tokens")
49
+ completion = _g("completion_tokens")
50
+ total = _g("total_tokens") or (prompt + completion)
51
+
52
+ cache_read = _g("cache_read_input_tokens")
53
+ cache_creation = _g("cache_creation_input_tokens")
54
+
55
+ if not cache_read:
56
+ details = _g("prompt_tokens_details", None)
57
+ if details is not None:
58
+ if isinstance(details, dict):
59
+ cache_read = details.get("cached_tokens", 0) or 0
60
+ else:
61
+ cache_read = getattr(details, "cached_tokens", 0) or 0
62
+
63
+ return {
64
+ "prompt_tokens": int(prompt),
65
+ "completion_tokens": int(completion),
66
+ "total_tokens": int(total),
67
+ "cache_read_tokens": int(cache_read),
68
+ "cache_creation_tokens": int(cache_creation),
69
+ }
70
+
71
+
72
+ # ── llm_call ────────────────────────────────────────────────────────────────
73
+
74
+ async def record_llm_call(
75
+ session: Any,
76
+ *,
77
+ model: str,
78
+ response: Any = None,
79
+ latency_ms: int,
80
+ finish_reason: str | None,
81
+ ) -> dict:
82
+ """Emit an ``llm_call`` event and return the extracted usage dict so
83
+ callers can stash it on their result object if they want."""
84
+ usage = extract_usage(response) if response is not None else {}
85
+ cost_usd = 0.0
86
+ if response is not None:
87
+ try:
88
+ from litellm import completion_cost
89
+ cost_usd = float(completion_cost(completion_response=response) or 0.0)
90
+ except Exception:
91
+ cost_usd = 0.0
92
+ from agent.core.session import Event # local import to avoid cycle
93
+ try:
94
+ await session.send_event(Event(
95
+ event_type="llm_call",
96
+ data={
97
+ "model": model,
98
+ "latency_ms": latency_ms,
99
+ "finish_reason": finish_reason,
100
+ "cost_usd": cost_usd,
101
+ **usage,
102
+ },
103
+ ))
104
+ except Exception as e:
105
+ logger.debug("record_llm_call failed (non-fatal): %s", e)
106
+ return usage
107
+
108
+
109
+ # ── hf_jobs ────────────────────────────────────────────────────────────────
110
+
111
+ def _infer_push_to_hub(script_or_cmd: Any) -> bool:
112
+ if not isinstance(script_or_cmd, str):
113
+ return False
114
+ return (
115
+ "push_to_hub=True" in script_or_cmd
116
+ or "push_to_hub=true" in script_or_cmd
117
+ or "hub_model_id" in script_or_cmd
118
+ )
119
+
120
+
121
+ async def record_hf_job_submit(
122
+ session: Any,
123
+ job: Any,
124
+ args: dict,
125
+ *,
126
+ image: str,
127
+ job_type: str,
128
+ ) -> float:
129
+ """Emit ``hf_job_submit``. Returns the monotonic start timestamp so the
130
+ caller can pass it back into :func:`record_hf_job_complete`."""
131
+ from agent.core.session import Event
132
+ t_start = time.monotonic()
133
+ try:
134
+ script_text = args.get("script") or args.get("command") or ""
135
+ await session.send_event(Event(
136
+ event_type="hf_job_submit",
137
+ data={
138
+ "job_id": getattr(job, "id", None),
139
+ "job_url": getattr(job, "url", None),
140
+ "flavor": args.get("hardware_flavor", "cpu-basic"),
141
+ "timeout": args.get("timeout", "30m"),
142
+ "job_type": job_type,
143
+ "image": image,
144
+ "push_to_hub": _infer_push_to_hub(script_text),
145
+ },
146
+ ))
147
+ except Exception as e:
148
+ logger.debug("record_hf_job_submit failed (non-fatal): %s", e)
149
+ return t_start
150
+
151
+
152
+ async def record_hf_job_complete(
153
+ session: Any,
154
+ job: Any,
155
+ *,
156
+ flavor: str,
157
+ final_status: str,
158
+ submit_ts: float,
159
+ ) -> None:
160
+ from agent.core.session import Event
161
+ try:
162
+ wall_time_s = int(time.monotonic() - submit_ts)
163
+ await session.send_event(Event(
164
+ event_type="hf_job_complete",
165
+ data={
166
+ "job_id": getattr(job, "id", None),
167
+ "flavor": flavor,
168
+ "final_status": final_status,
169
+ "wall_time_s": wall_time_s,
170
+ },
171
+ ))
172
+ except Exception as e:
173
+ logger.debug("record_hf_job_complete failed (non-fatal): %s", e)
174
+
175
+
176
+ # ── sandbox ─────────────────────────────────────────────────────────────────
177
+
178
+ async def record_sandbox_create(
179
+ session: Any,
180
+ sandbox: Any,
181
+ *,
182
+ hardware: str,
183
+ create_latency_s: int,
184
+ ) -> None:
185
+ from agent.core.session import Event
186
+ try:
187
+ # Pin created-at on the session so record_sandbox_destroy can diff.
188
+ session._sandbox_created_at = time.monotonic() - create_latency_s
189
+ await session.send_event(Event(
190
+ event_type="sandbox_create",
191
+ data={
192
+ "sandbox_id": getattr(sandbox, "space_id", None),
193
+ "hardware": hardware,
194
+ "create_latency_s": int(create_latency_s),
195
+ },
196
+ ))
197
+ except Exception as e:
198
+ logger.debug("record_sandbox_create failed (non-fatal): %s", e)
199
+
200
+
201
+ async def record_sandbox_destroy(session: Any, sandbox: Any) -> None:
202
+ from agent.core.session import Event
203
+ try:
204
+ created = getattr(session, "_sandbox_created_at", None)
205
+ lifetime_s = int(time.monotonic() - created) if created else None
206
+ await session.send_event(Event(
207
+ event_type="sandbox_destroy",
208
+ data={
209
+ "sandbox_id": getattr(sandbox, "space_id", None),
210
+ "lifetime_s": lifetime_s,
211
+ },
212
+ ))
213
+ except Exception as e:
214
+ logger.debug("record_sandbox_destroy failed (non-fatal): %s", e)
215
+
216
+
217
+ # ── feedback ───────────────────────────────────────────────────────────────
218
+
219
+ async def record_feedback(
220
+ session: Any,
221
+ *,
222
+ rating: str,
223
+ turn_index: int | None = None,
224
+ message_id: str | None = None,
225
+ comment: str | None = None,
226
+ ) -> None:
227
+ from agent.core.session import Event
228
+ try:
229
+ await session.send_event(Event(
230
+ event_type="feedback",
231
+ data={
232
+ "rating": rating,
233
+ "turn_index": turn_index,
234
+ "message_id": message_id,
235
+ "comment": (comment or "")[:500],
236
+ },
237
+ ))
238
+ except Exception as e:
239
+ logger.debug("record_feedback failed (non-fatal): %s", e)
240
+
241
+
242
+ # ── heartbeat ──────────────────────────────────────────────────────────────
243
+
244
+ # Module-level reference set for fire-and-forget heartbeat tasks. asyncio only
245
+ # keeps *weak* references to tasks, so the returned Task would otherwise be
246
+ # eligible for GC before running — the task gets discarded and the upload
247
+ # silently never happens. Hold strong refs until the task completes.
248
+ _heartbeat_tasks: set[asyncio.Task] = set()
249
+
250
+
251
+ class HeartbeatSaver:
252
+ """Time-gated mid-turn flush.
253
+
254
+ Called from ``Session.send_event`` after every event. Fires
255
+ ``save_and_upload_detached`` in a worker thread at most once per
256
+ ``heartbeat_interval_s`` (default 60s). Guards against losing trace data
257
+ on long-running turns that crash before ``turn_complete``.
258
+ """
259
+
260
+ @staticmethod
261
+ def maybe_fire(session: Any) -> None:
262
+ if not getattr(session.config, "save_sessions", False):
263
+ return
264
+ interval = getattr(session.config, "heartbeat_interval_s", 0) or 0
265
+ if interval <= 0:
266
+ return
267
+ now = time.monotonic()
268
+ last = getattr(session, "_last_heartbeat_ts", None)
269
+ if last is None:
270
+ # Initialise on first event; no save yet.
271
+ session._last_heartbeat_ts = now
272
+ return
273
+ if now - last < interval:
274
+ return
275
+ session._last_heartbeat_ts = now
276
+ repo_id = session.config.session_dataset_repo
277
+ try:
278
+ task = asyncio.get_running_loop().create_task(
279
+ asyncio.to_thread(session.save_and_upload_detached, repo_id)
280
+ )
281
+ # Hold a strong reference until the task finishes so asyncio can't
282
+ # GC it. ``set.discard`` is a no-op on missing keys → safe callback.
283
+ _heartbeat_tasks.add(task)
284
+ task.add_done_callback(_heartbeat_tasks.discard)
285
+ except RuntimeError:
286
+ try:
287
+ session.save_and_upload_detached(repo_id)
288
+ except Exception as e:
289
+ logger.debug("Heartbeat save failed (non-fatal): %s", e)
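
The usage-normalization helper is a pure function, so it can be sanity-checked without a live LLM call. A small sketch (the response dict is a stand-in for the litellm response or final stream chunk that real callsites pass):

    from agent.core import telemetry

    fake_response = {"usage": {"prompt_tokens": 1200, "completion_tokens": 80,
                               "total_tokens": 1280, "cache_read_input_tokens": 900}}
    print(telemetry.extract_usage(fake_response))
    # {'prompt_tokens': 1200, 'completion_tokens': 80, 'total_tokens': 1280,
    #  'cache_read_tokens': 900, 'cache_creation_tokens': 0}

record_llm_call wraps this extraction, adds cost via litellm's completion_cost, and emits a single llm_call event through session.send_event; failures are logged at debug level and never raised.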
agent/sft/__init__.py ADDED
File without changes
agent/sft/tagger.py ADDED
@@ -0,0 +1,324 @@
1
+ """Derive tags for a session trajectory.
2
+
3
+ ``tag_session(trajectory)`` → ``list[str]``. Pure function. No filtering, no
4
+ mutation — tags are purely metadata so downstream pipelines can slice the raw
5
+ SFT dataset (``where 'hf_job:succeeded' in tags``) without re-reading trajectories.
6
+
7
+ Tag namespaces (all tags are ``"<namespace>:<value>"`` strings):
8
+
9
+ * ``tool:<name>`` — every tool called at least once (``tool:hf_jobs``, …)
10
+ * ``outcome:<end>`` — ``completed`` / ``errored`` / ``interrupted`` /
11
+ ``ongoing`` / ``doom_loop`` / ``context_exceeded``
12
+ * ``hf_job:<facet>`` — ``submitted``, ``succeeded``, ``failed``,
13
+ ``multi`` (>1), ``oom``, ``push_to_hub``
14
+ * ``gpu:<kind>`` — ``none``, ``t4``, ``a10g``, ``a100``, ``l40s``,
15
+ ``h100``, plus ``gpu:multi`` for x2/x4/x8 flavors
16
+ * ``sandbox:<facet>`` — ``created``, ``gpu``, ``cpu``, ``long_lived`` (>30 min)
17
+ * ``feedback:<kind>`` — ``up``, ``down``, ``mixed``, ``none``
18
+ * ``model:<family>`` — ``opus`` / ``sonnet`` / ``haiku`` / ``kimi`` /
19
+ ``gpt`` / ``deepseek`` / ``qwen`` / ``other``
20
+ * ``turns:<bucket>`` — ``short`` (<5) / ``medium`` (5–20) / ``long`` (>20)
21
+ * ``cost:<bucket>`` — ``low`` (<$0.10) / ``med`` (<$1) / ``high``
22
+ * ``task:<kind>`` — ``training`` / ``inference`` / ``data_prep`` /
23
+ ``research_only`` (heuristic on tools + scripts)
24
+
25
+ Tags are deduplicated before returning.
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ from typing import Any, Iterable
31
+
32
+ # Flavor → GPU-family mapping. Keep conservative; unknown flavors → "none".
33
+ _GPU_FAMILY = {
34
+ "cpu-basic": "none", "cpu-upgrade": "none",
35
+ "t4-small": "t4", "t4-medium": "t4",
36
+ "l4x1": "l40s", "l4x4": "l40s",
37
+ "l40sx1": "l40s", "l40sx4": "l40s", "l40sx8": "l40s",
38
+ "a10g-small": "a10g", "a10g-large": "a10g",
39
+ "a10g-largex2": "a10g", "a10g-largex4": "a10g",
40
+ "a100-large": "a100", "a100x2": "a100",
41
+ "a100x4": "a100", "a100x8": "a100",
42
+ "h100": "h100", "h100x8": "h100",
43
+ }
44
+
45
+ # Substrings that count a flavor as multi-GPU.
46
+ _MULTI_GPU_MARKERS = ("x2", "x4", "x8")
47
+
48
+ # Tool names that don't touch training/inference or sandbox/jobs. If a session
49
+ # only used these, we tag it research_only.
50
+ _RESEARCH_ONLY_TOOLS = {
51
+ "research", "github_find_examples", "github_read_file", "github_list_repos",
52
+ "hf_papers", "explore_hf_docs", "fetch_hf_docs", "hub_repo_details",
53
+ "plan", "hf_inspect_dataset", "web_search",
54
+ }
55
+
56
+ # Tool names that signal data manipulation workflows.
57
+ _DATA_PREP_TOOLS = {"hf_inspect_dataset", "dataset_tools", "hub_repo_details"}
58
+
59
+
60
+ def _model_family(model_name: str | None) -> str:
61
+ if not model_name:
62
+ return "other"
63
+ n = model_name.lower()
64
+ if "opus" in n:
65
+ return "opus"
66
+ if "sonnet" in n:
67
+ return "sonnet"
68
+ if "haiku" in n:
69
+ return "haiku"
70
+ if "kimi" in n:
71
+ return "kimi"
72
+ if "gpt" in n:
73
+ return "gpt"
74
+ if "deepseek" in n:
75
+ return "deepseek"
76
+ if "qwen" in n:
77
+ return "qwen"
78
+ if "llama" in n:
79
+ return "llama"
80
+ return "other"
81
+
82
+
83
+ def _turns_bucket(n: int) -> str:
84
+ if n < 5:
85
+ return "short"
86
+ if n <= 20:
87
+ return "medium"
88
+ return "long"
89
+
90
+
91
+ def _cost_bucket(cost_usd: float) -> str:
92
+ if cost_usd < 0.10:
93
+ return "low"
94
+ if cost_usd < 1.0:
95
+ return "med"
96
+ return "high"
97
+
98
+
99
+ def _flavor_to_gpu_tags(flavor: str) -> list[str]:
100
+ family = _GPU_FAMILY.get(flavor, "none")
101
+ tags = [f"gpu:{family}"]
102
+ if any(m in flavor for m in _MULTI_GPU_MARKERS):
103
+ tags.append("gpu:multi")
104
+ return tags
105
+
106
+
107
+ def _has_oom_signal(tool_outputs: Iterable[str]) -> bool:
108
+ for out in tool_outputs:
109
+ if not isinstance(out, str):
110
+ continue
111
+ low = out.lower()
112
+ if "outofmemoryerror" in low or "cuda out of memory" in low or "oom" in low:
113
+ return True
114
+ return False
115
+
116
+
117
+ def _infer_task_tag(
118
+ tool_names: set[str],
119
+ hf_job_submit_scripts: list[str],
120
+ ) -> str | None:
121
+ """Return a ``task:*`` tag or None if we can't tell.
122
+
123
+ Heuristic order: training > inference > data_prep > research_only.
124
+ """
125
+ # training: any hf_jobs script with a Trainer/SFT/training keyword, OR uses
126
+ # hf_jobs at all and a script mentions training APIs.
127
+ for script in hf_job_submit_scripts:
128
+ low = script.lower()
129
+ if any(k in low for k in (
130
+ "sftconfig", "sfttrainer", "trainer(", "trainingarguments",
131
+ "grpo", "dpo", ".train(", "transformers import",
132
+ "trainer import", "fine-tune", "finetune",
133
+ )):
134
+ return "training"
135
+
136
+ # inference: sessions that use inference tools but never hf_jobs/sandbox
137
+ uses_compute = bool(tool_names & {"hf_jobs", "sandbox_create", "sandbox_exec"})
138
+ if not uses_compute and tool_names & {"inference", "generate", "run_inference"}:
139
+ return "inference"
140
+
141
+ # data_prep: primarily dataset tools and no training/inference
142
+ if tool_names & _DATA_PREP_TOOLS and not uses_compute:
143
+ return "data_prep"
144
+
145
+ # research_only: every tool used is in the research allow-list
146
+ if tool_names and tool_names <= _RESEARCH_ONLY_TOOLS:
147
+ return "research_only"
148
+
149
+ return None
150
+
151
+
152
+ def tag_session(trajectory: dict) -> list[str]:
153
+ """Derive tags from a session trajectory. Pure function."""
154
+ tags: set[str] = set()
155
+
156
+ events: list[dict] = trajectory.get("events") or []
157
+ messages: list[dict] = trajectory.get("messages") or []
158
+ model_name: str | None = trajectory.get("model_name")
159
+
160
+ # model
161
+ tags.add(f"model:{_model_family(model_name)}")
162
+
163
+ # turns
164
+ user_turns = sum(1 for m in messages if m.get("role") == "user")
165
+ tags.add(f"turns:{_turns_bucket(user_turns)}")
166
+
167
+ # cost + tool-name enumeration + outcome detection
168
+ cost_usd = 0.0
169
+ tool_names: set[str] = set()
170
+ tool_outputs: list[str] = []
171
+ hf_job_submit_count = 0
172
+ hf_job_submit_scripts: list[str] = []
173
+ hf_job_success_count = 0
174
+ hf_job_fail_count = 0
175
+ hf_job_push_to_hub = False
176
+ gpu_tags_seen: set[str] = set()
177
+
178
+ # Outcome is the *last* terminal signal. Seed with "ongoing" — overridden
179
+ # if we see a terminal event.
180
+ outcome = "ongoing"
181
+ had_error = False
182
+ had_doom_loop = False
183
+ had_compact = False
184
+
185
+ feedback_up = 0
186
+ feedback_down = 0
187
+
188
+ sandbox_created = False
189
+ sandbox_hardware: str | None = None
190
+ sandbox_lifetime_s: int | None = None
191
+
192
+ for ev in events:
193
+ et = ev.get("event_type")
194
+ data = ev.get("data") or {}
195
+
196
+ if et == "llm_call":
197
+ cost_usd += float(data.get("cost_usd") or 0.0)
198
+
199
+ elif et == "tool_call":
200
+ name = data.get("tool")
201
+ if name:
202
+ tool_names.add(name)
203
+
204
+ elif et == "tool_output":
205
+ out = data.get("output")
206
+ if isinstance(out, str):
207
+ tool_outputs.append(out)
208
+
209
+ elif et == "hf_job_submit":
210
+ hf_job_submit_count += 1
211
+ if data.get("push_to_hub"):
212
+ hf_job_push_to_hub = True
213
+ flavor = data.get("flavor") or "cpu-basic"
214
+ for t in _flavor_to_gpu_tags(flavor):
215
+ gpu_tags_seen.add(t)
216
+
217
+ elif et == "hf_job_complete":
218
+ final = (data.get("final_status") or "").lower()
219
+ if final in ("completed", "succeeded", "success"):
220
+ hf_job_success_count += 1
221
+ elif final in ("failed", "error", "timeout", "cancelled"):
222
+ hf_job_fail_count += 1
223
+
224
+ elif et == "sandbox_create":
225
+ sandbox_created = True
226
+ sandbox_hardware = data.get("hardware")
227
+
228
+ elif et == "sandbox_destroy":
229
+ lt = data.get("lifetime_s")
230
+ if isinstance(lt, (int, float)):
231
+ sandbox_lifetime_s = int(lt)
232
+
233
+ elif et == "feedback":
234
+ rating = data.get("rating")
235
+ if rating == "up":
236
+ feedback_up += 1
237
+ elif rating == "down":
238
+ feedback_down += 1
239
+
240
+ elif et == "error":
241
+ had_error = True
242
+ elif et == "turn_complete":
243
+ if not had_error:
244
+ outcome = "completed"
245
+ elif et == "interrupted":
246
+ outcome = "interrupted"
247
+ elif et == "compacted":
248
+ had_compact = True
249
+ elif et == "tool_log":
250
+ log_text = (data.get("log") or "").lower()
251
+ if "doom loop" in log_text:
252
+ had_doom_loop = True
253
+
254
+ if had_error and outcome not in ("completed", "interrupted"):
255
+ outcome = "errored"
256
+
257
+ tags.add(f"outcome:{outcome}")
258
+ if had_doom_loop:
259
+ tags.add("outcome:doom_loop")
260
+ if had_compact:
261
+ tags.add("outcome:context_exceeded")
262
+
263
+ # tools
264
+ for name in tool_names:
265
+ tags.add(f"tool:{name}")
266
+
267
+ # hf_jobs facets
268
+ if hf_job_submit_count >= 1:
269
+ tags.add("hf_job:submitted")
270
+ if hf_job_submit_count > 1:
271
+ tags.add("hf_job:multi")
272
+ if hf_job_success_count > 0:
273
+ tags.add("hf_job:succeeded")
274
+ if hf_job_fail_count > 0:
275
+ tags.add("hf_job:failed")
276
+ if hf_job_push_to_hub:
277
+ tags.add("hf_job:push_to_hub")
278
+ if _has_oom_signal(tool_outputs):
279
+ tags.add("hf_job:oom")
280
+
281
+ # gpu tags (from all submitted jobs)
282
+ tags.update(gpu_tags_seen)
283
+ if "gpu:none" in tags and len(gpu_tags_seen) > 1:
284
+ # If any GPU flavor was used, drop the "none" tag for clarity.
285
+ tags.discard("gpu:none")
286
+
287
+ # sandbox facets
288
+ if sandbox_created:
289
+ tags.add("sandbox:created")
290
+ if sandbox_hardware:
291
+ fam = _GPU_FAMILY.get(sandbox_hardware, "none")
292
+ tags.add("sandbox:cpu" if fam == "none" else "sandbox:gpu")
293
+ if sandbox_lifetime_s is not None and sandbox_lifetime_s > 1800:
294
+ tags.add("sandbox:long_lived")
295
+
296
+ # feedback
297
+ if feedback_up and feedback_down:
298
+ tags.add("feedback:mixed")
299
+ elif feedback_up:
300
+ tags.add("feedback:up")
301
+ elif feedback_down:
302
+ tags.add("feedback:down")
303
+ else:
304
+ tags.add("feedback:none")
305
+
306
+ # cost bucket
307
+ tags.add(f"cost:{_cost_bucket(cost_usd)}")
308
+
309
+ # task heuristic (needs scripts — pull from the hf_job_submit events'
310
+ # matching tool_call arguments in the event list).
311
+ for ev in events:
312
+ if ev.get("event_type") == "tool_call":
313
+ data = ev.get("data") or {}
314
+ if data.get("tool") == "hf_jobs":
315
+ args = data.get("arguments") or {}
316
+ script = args.get("script") or args.get("command") or ""
317
+ if isinstance(script, str):
318
+ hf_job_submit_scripts.append(script)
319
+
320
+ task_tag = _infer_task_tag(tool_names, hf_job_submit_scripts)
321
+ if task_tag:
322
+ tags.add(f"task:{task_tag}")
323
+
324
+ return sorted(tags)
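
Since tag_session is a pure function over the trajectory dict, it can be smoke-tested with a hand-written trajectory (the field values below are invented for illustration):

    from agent.sft.tagger import tag_session

    trajectory = {
        "model_name": "bedrock/us.anthropic.claude-opus-4-6-v1",
        "messages": [{"role": "user", "content": "train a model"}],
        "events": [
            {"event_type": "llm_call", "data": {"cost_usd": 0.42}},
            {"event_type": "tool_call", "data": {"tool": "hf_jobs",
                                                 "arguments": {"script": "SFTTrainer(...)"}}},
            {"event_type": "hf_job_submit", "data": {"flavor": "a10g-small"}},
            {"event_type": "hf_job_complete", "data": {"final_status": "completed"}},
            {"event_type": "turn_complete", "data": {}},
        ],
    }
    print(tag_session(trajectory))
    # ['cost:med', 'feedback:none', 'gpu:a10g', 'hf_job:submitted', 'hf_job:succeeded',
    #  'model:opus', 'outcome:completed', 'task:training', 'tool:hf_jobs', 'turns:short']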
agent/tools/jobs_tool.py CHANGED
@@ -528,14 +528,16 @@ class HfJobsTool:
            job_type = "Docker"

        # Run the job
+        flavor = args.get("hardware_flavor", "cpu-basic")
+        timeout_str = args.get("timeout", "30m")
        job = await _async_call(
            self.api.run_job,
            image=image,
            command=command,
            env=_add_default_env(args.get("env")),
            secrets=_add_environment_variables(args.get("secrets"), self.hf_token),
-            flavor=args.get("hardware_flavor", "cpu-basic"),
-            timeout=args.get("timeout", "30m"),
+            flavor=flavor,
+            timeout=timeout_str,
            namespace=self.namespace,
        )

@@ -557,6 +559,16 @@
            )
        )

+        # Telemetry: job submission + completion (infra consumption signal).
+        submit_ts = None
+        if self.session:
+            from agent.core import telemetry
+            submit_ts = await telemetry.record_hf_job_submit(
+                self.session, job,
+                {**args, "hardware_flavor": flavor, "timeout": timeout_str},
+                image=image, job_type=job_type,
+            )
+
        # Wait for completion and stream logs
        logger.info(f"{job_type} job started: {job.url}")
        logger.info("Streaming logs...")
@@ -566,6 +578,13 @@
            namespace=self.namespace,
        )

+        if self.session and submit_ts is not None:
+            from agent.core import telemetry
+            await telemetry.record_hf_job_complete(
+                self.session, job,
+                flavor=flavor, final_status=final_status, submit_ts=submit_ts,
+            )
+
        # Untrack job ID (completed or failed, no longer needs cancellation)
        if self.session:
            self.session._running_job_ids.discard(job.id)
agent/tools/sandbox_tool.py CHANGED
@@ -131,6 +131,8 @@ async def _ensure_sandbox(
    }
    if hardware != "cpu-basic":
        kwargs["sleep_time"] = 2700
+    import time as _t
+    _t_start = _t.monotonic()
    try:
        sb = await asyncio.to_thread(Sandbox.create, **kwargs)
    except Sandbox.Cancelled:
@@ -139,6 +141,13 @@
    watcher_task.cancel()
    session.sandbox = sb

+    # Telemetry: sandbox creation (infra consumption signal)
+    from agent.core import telemetry
+    await telemetry.record_sandbox_create(
+        session, sb, hardware=hardware,
+        create_latency_s=int(_t.monotonic() - _t_start),
+    )
+
    # Set a descriptive title (template title is inherited on duplicate)
    from huggingface_hub import metadata_update

backend/kpis_scheduler.py ADDED
@@ -0,0 +1,146 @@
1
+ """In-process hourly KPI rollup, owned by the backend Space lifespan.
2
+
3
+ Replaces an external GitHub Actions cron so the rollup lives next to the data
4
+ and reuses the Space's existing HF token — no production secrets on the
5
+ public source repo. See ``scripts/build_kpis.py`` for the data-flow diagram
6
+ and metric definitions.
7
+
8
+ Behaviour::
9
+
10
+ lifespan startup → start APScheduler with cron("5 * * * *", UTC)
11
+ → fire a best-effort 6-hour backfill (fire-and-forget)
12
+ each :05 → run ``build_kpis.run_for_hour`` for the just-completed hour
13
+ lifespan shutdown → scheduler.shutdown(wait=False)
14
+
15
+ Environment::
16
+
17
+ HF_KPI_WRITE_TOKEN | HF_SESSION_UPLOAD_TOKEN | HF_TOKEN | HF_ADMIN_TOKEN
18
+ First one found is used. Least-privilege first.
19
+ KPI_SOURCE_REPO default smolagents/ml-intern-sessions
20
+ KPI_TARGET_REPO default smolagents/ml-intern-kpis
21
+ ML_INTERN_KPIS_DISABLED if truthy, the scheduler is not started
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import asyncio
27
+ import importlib.util
28
+ import logging
29
+ import os
30
+ from datetime import datetime, timedelta, timezone
31
+ from pathlib import Path
32
+ from typing import Optional
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ _PROJECT_ROOT = Path(__file__).resolve().parent.parent
37
+
38
+ # Hold strong refs to backfill tasks so asyncio doesn't GC them mid-run.
39
+ _background_tasks: set[asyncio.Task] = set()
40
+
41
+ _scheduler = None # AsyncIOScheduler instance (lazy import)
42
+
43
+
44
+ def _resolve_token() -> Optional[str]:
45
+ """Pick the first available HF token. Least-privilege first."""
46
+ for var in (
47
+ "HF_KPI_WRITE_TOKEN",
48
+ "HF_SESSION_UPLOAD_TOKEN",
49
+ "HF_TOKEN",
50
+ "HF_ADMIN_TOKEN",
51
+ ):
52
+ val = os.environ.get(var)
53
+ if val:
54
+ return val
55
+ return None
56
+
57
+
58
+ def _load_build_kpis():
59
+ """Import ``scripts/build_kpis.py`` without putting ``scripts/`` on sys.path."""
60
+ spec = importlib.util.spec_from_file_location(
61
+ "build_kpis", _PROJECT_ROOT / "scripts" / "build_kpis.py",
62
+ )
63
+ mod = importlib.util.module_from_spec(spec)
64
+ assert spec.loader is not None
65
+ spec.loader.exec_module(mod)
66
+ return mod
67
+
68
+
69
+ async def _run_hour(hour_dt: datetime) -> None:
70
+ """Run one hourly rollup off the event loop. Best-effort, never raises."""
71
+ token = _resolve_token()
72
+ if not token:
73
+ logger.warning("kpis_scheduler: no HF token available, skipping %s", hour_dt)
74
+ return
75
+ try:
76
+ mod = _load_build_kpis()
77
+ from huggingface_hub import HfApi
78
+ api = HfApi()
79
+ source = os.environ.get("KPI_SOURCE_REPO", "smolagents/ml-intern-sessions")
80
+ target = os.environ.get("KPI_TARGET_REPO", "smolagents/ml-intern-kpis")
81
+ await asyncio.to_thread(mod.run_for_hour, api, source, target, hour_dt, token)
82
+ except Exception as e:
83
+ logger.warning("kpis_scheduler: rollup for %s failed: %s", hour_dt, e)
84
+
85
+
86
+ async def run_last_completed_hour() -> None:
87
+ """The scheduled-at-:05 job. Rolls up the previous whole hour."""
88
+ now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0)
89
+ await _run_hour(now - timedelta(hours=1))
90
+
91
+
92
+ async def backfill(hours: int = 6) -> None:
93
+ """Catch-up pass for hours the Space was down. Idempotent (overwrites)."""
94
+ now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0)
95
+ for i in range(1, hours + 1):
96
+ await _run_hour(now - timedelta(hours=i))
97
+
98
+
99
+ def start(backfill_hours: int = 6) -> None:
100
+ """Called from FastAPI lifespan startup."""
101
+ global _scheduler
102
+ if os.environ.get("ML_INTERN_KPIS_DISABLED"):
103
+ logger.info("kpis_scheduler: disabled via ML_INTERN_KPIS_DISABLED")
104
+ return
105
+ if _scheduler is not None:
106
+ return
107
+
108
+ try:
109
+ from apscheduler.schedulers.asyncio import AsyncIOScheduler
110
+ from apscheduler.triggers.cron import CronTrigger
111
+ except ImportError:
112
+ logger.warning("kpis_scheduler: apscheduler not installed, skipping")
113
+ return
114
+
115
+ _scheduler = AsyncIOScheduler(timezone="UTC")
116
+ _scheduler.add_job(
117
+ run_last_completed_hour,
118
+ CronTrigger(minute=5),
119
+ id="kpis_hourly",
120
+ misfire_grace_time=600, # tolerate a 10-min misfire window
121
+ coalesce=True, # collapse multiple missed fires into one
122
+ max_instances=1,
123
+ replace_existing=True,
124
+ )
125
+ _scheduler.start()
126
+ logger.info("kpis_scheduler: started (cron '5 * * * *' UTC)")
127
+
128
+ # Non-blocking backfill. Hold a strong ref until done so asyncio doesn't
129
+ # GC the task before it finishes.
130
+ try:
131
+ task = asyncio.get_running_loop().create_task(backfill(backfill_hours))
132
+ _background_tasks.add(task)
133
+ task.add_done_callback(_background_tasks.discard)
134
+ except RuntimeError:
135
+ # Not in an event loop (tests); skip backfill.
136
+ pass
137
+
138
+
139
+ async def shutdown() -> None:
140
+ """Called from FastAPI lifespan shutdown."""
141
+ global _scheduler
142
+ if _scheduler is None:
143
+ return
144
+ _scheduler.shutdown(wait=False)
145
+ _scheduler = None
146
+ logger.info("kpis_scheduler: stopped")
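
For a manual catch-up outside the scheduler, the same entry points can be awaited directly, assuming backend/ is on sys.path and one of the listed HF token variables is set (sketch, not part of the commit):

    import asyncio

    import kpis_scheduler

    # Roll up the last fully completed UTC hour, exactly as the :05 cron job does.
    asyncio.run(kpis_scheduler.run_last_completed_hour())

    # Or re-run the last six hours after downtime (idempotent, overwrites).
    asyncio.run(kpis_scheduler.backfill(hours=6))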
backend/main.py CHANGED
@@ -27,8 +27,37 @@ logger = logging.getLogger(__name__)
 async def lifespan(app: FastAPI):
     """Application lifespan handler."""
     logger.info("Starting HF Agent backend...")
+    # Start in-process hourly KPI rollup. Replaces an external cron so the
+    # rollup lives next to the data and reuses the Space's HF token.
+    try:
+        import kpis_scheduler
+        kpis_scheduler.start()
+    except Exception as e:
+        logger.warning("KPI scheduler failed to start: %s", e)
+
     yield
+
     logger.info("Shutting down HF Agent backend...")
+    try:
+        import kpis_scheduler
+        await kpis_scheduler.shutdown()
+    except Exception as e:
+        logger.warning("KPI scheduler shutdown failed: %s", e)
+
+    # Final-flush: save every still-active session so we don't lose traces on
+    # server restart. Uploads are detached subprocesses — this is fast.
+    try:
+        from session_manager import session_manager
+        for sid, agent_session in list(session_manager.sessions.items()):
+            sess = agent_session.session
+            if sess.config.save_sessions:
+                try:
+                    sess.save_and_upload_detached(sess.config.session_dataset_repo)
+                    logger.info("Flushed session %s on shutdown", sid)
+                except Exception as e:
+                    logger.warning("Failed to flush session %s: %s", sid, e)
+    except Exception as e:
+        logger.warning("Lifespan final-flush skipped: %s", e)


 app = FastAPI(
backend/routes/agent.py CHANGED
@@ -693,3 +693,41 @@ async def shutdown_session(
     return {"status": "shutdown_requested", "session_id": session_id}


+@router.post("/feedback/{session_id}")
+async def submit_feedback(
+    session_id: str,
+    body: dict,
+    user: dict = Depends(get_current_user),
+) -> dict:
+    """Attach a user feedback signal to a session's event log.
+
+    Body: {rating: "up"|"down"|"outcome_success"|"outcome_fail",
+           turn_index?: int, comment?: str, message_id?: str}
+    Appended as a `feedback` event and saved with the session trajectory.
+    """
+    _check_session_access(session_id, user)
+    agent_session = session_manager.sessions.get(session_id)
+    if not agent_session:
+        raise HTTPException(status_code=404, detail="Session not found")
+
+    rating = body.get("rating")
+    if rating not in {"up", "down", "outcome_success", "outcome_fail"}:
+        raise HTTPException(status_code=400, detail="invalid rating")
+
+    from agent.core import telemetry
+    await telemetry.record_feedback(
+        agent_session.session,
+        rating=rating,
+        turn_index=body.get("turn_index"),
+        message_id=body.get("message_id"),
+        comment=body.get("comment"),
+    )
+    # Fire-and-forget save so feedback reaches the dataset even if the user
+    # closes the tab right after clicking.
+    if agent_session.session.config.save_sessions:
+        agent_session.session.save_and_upload_detached(
+            agent_session.session.config.session_dataset_repo
+        )
+    return {"status": "ok"}
+
backend/session_manager.py CHANGED
@@ -290,11 +290,14 @@ class SessionManager:
         """Delete the sandbox Space if one was created for this session."""
         sandbox = getattr(session, "sandbox", None)
         if sandbox and getattr(sandbox, "_owns_space", False):
+            space_id = getattr(sandbox, "space_id", None)
             try:
-                logger.info(f"Deleting sandbox {sandbox.space_id}...")
+                logger.info(f"Deleting sandbox {space_id}...")
                 await asyncio.to_thread(sandbox.delete)
+                from agent.core import telemetry
+                await telemetry.record_sandbox_destroy(session, sandbox)
             except Exception as e:
-                logger.warning(f"Failed to delete sandbox {sandbox.space_id}: {e}")
+                logger.warning(f"Failed to delete sandbox {space_id}: {e}")

     async def _run_session(
         self,
@@ -356,6 +359,15 @@

         await self._cleanup_sandbox(session)

+        # Final-flush: always save on session death so we capture ended
+        # sessions even if the client disconnects without /shutdown.
+        # Idempotent via session_id key; detached subprocess.
+        if session.config.save_sessions:
+            try:
+                session.save_and_upload_detached(session.config.session_dataset_repo)
+            except Exception as e:
+                logger.warning(f"Final-flush failed for {session_id}: {e}")
+
         async with self._lock:
             if session_id in self.sessions:
                 self.sessions[session_id].is_active = False
configs/main_agent_config.json CHANGED
@@ -1,7 +1,7 @@
 {
   "model_name": "bedrock/us.anthropic.claude-opus-4-6-v1",
   "save_sessions": true,
-  "session_dataset_repo": "akseljoonas/hf-agent-sessions",
+  "session_dataset_repo": "smolagents/ml-intern-sessions",
   "yolo_mode": false,
   "confirm_cpu_jobs": true,
   "auto_file_upload": true,
frontend/src/components/Chat/AssistantMessage.tsx CHANGED
@@ -1,13 +1,19 @@
1
- import { useMemo } from 'react';
2
- import { Box, Stack, Typography } from '@mui/material';
 
 
 
 
3
  import MarkdownContent from './MarkdownContent';
4
  import ToolCallGroup from './ToolCallGroup';
 
5
  import type { UIMessage } from 'ai';
6
  import type { MessageMeta } from '@/types/agent';
7
 
8
  interface AssistantMessageProps {
9
  message: UIMessage;
10
  isStreaming?: boolean;
 
11
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
12
  }
13
 
@@ -43,8 +49,27 @@ function groupParts(parts: UIMessage['parts']) {
43
  return groups;
44
  }
45
 
46
- export default function AssistantMessage({ message, isStreaming = false, approveTools }: AssistantMessageProps) {
47
  const groups = useMemo(() => groupParts(message.parts), [message.parts]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  // Find the last text group index for streaming cursor
50
  let lastTextIdx = -1;
@@ -114,6 +139,24 @@ export default function AssistantMessage({ message, isStreaming = false, approve
114
  return null;
115
  })}
116
  </Box>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  </Box>
118
  );
119
  }
 
1
+ import { useMemo, useState } from 'react';
2
+ import { Box, IconButton, Stack, Tooltip, Typography } from '@mui/material';
3
+ import ThumbUpOutlined from '@mui/icons-material/ThumbUpOutlined';
4
+ import ThumbUp from '@mui/icons-material/ThumbUp';
5
+ import ThumbDownOutlined from '@mui/icons-material/ThumbDownOutlined';
6
+ import ThumbDown from '@mui/icons-material/ThumbDown';
7
  import MarkdownContent from './MarkdownContent';
8
  import ToolCallGroup from './ToolCallGroup';
9
+ import { apiFetch } from '@/utils/api';
10
  import type { UIMessage } from 'ai';
11
  import type { MessageMeta } from '@/types/agent';
12
 
13
  interface AssistantMessageProps {
14
  message: UIMessage;
15
  isStreaming?: boolean;
16
+ sessionId?: string | null;
17
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
18
  }
19
 
 
49
  return groups;
50
  }
51
 
52
+ export default function AssistantMessage({ message, isStreaming = false, sessionId, approveTools }: AssistantMessageProps) {
53
  const groups = useMemo(() => groupParts(message.parts), [message.parts]);
54
+ const [feedback, setFeedback] = useState<'up' | 'down' | null>(null);
55
+ const [feedbackBusy, setFeedbackBusy] = useState(false);
56
+
57
+ const sendFeedback = async (rating: 'up' | 'down') => {
58
+ if (!sessionId || feedbackBusy) return;
59
+ setFeedbackBusy(true);
60
+ // Optimistic toggle — feedback is observability, not a hard requirement.
61
+ setFeedback(rating);
62
+ try {
63
+ await apiFetch(`/api/feedback/${sessionId}`, {
64
+ method: 'POST',
65
+ body: JSON.stringify({ rating, message_id: message.id }),
66
+ });
67
+ } catch {
68
+ // Silently swallow — don't block chat UX on a telemetry write.
69
+ } finally {
70
+ setFeedbackBusy(false);
71
+ }
72
+ };
73
 
74
  // Find the last text group index for streaming cursor
75
  let lastTextIdx = -1;
 
139
  return null;
140
  })}
141
  </Box>
142
+ {!isStreaming && sessionId && (
143
+ <Stack
144
+ direction="row"
145
+ spacing={0.5}
146
+ sx={{ mt: 0.5, ml: 0.5, opacity: feedback ? 1 : 0.5, '&:hover': { opacity: 1 } }}
147
+ >
148
+ <Tooltip title="Helpful">
149
+ <IconButton size="small" disabled={feedbackBusy} onClick={() => sendFeedback('up')}>
150
+ {feedback === 'up' ? <ThumbUp fontSize="inherit" /> : <ThumbUpOutlined fontSize="inherit" />}
151
+ </IconButton>
152
+ </Tooltip>
153
+ <Tooltip title="Not helpful">
154
+ <IconButton size="small" disabled={feedbackBusy} onClick={() => sendFeedback('down')}>
155
+ {feedback === 'down' ? <ThumbDown fontSize="inherit" /> : <ThumbDownOutlined fontSize="inherit" />}
156
+ </IconButton>
157
+ </Tooltip>
158
+ </Stack>
159
+ )}
160
  </Box>
161
  );
162
  }
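The thumbs buttons POST `{rating, message_id}` to `/api/feedback/<session_id>`; the matching backend route is not part of this excerpt. A sketch of what it could look like, assuming a FastAPI backend and an in-memory `SESSIONS` registry (both assumptions), emitting the same `feedback` event that `scripts/build_kpis.py` later counts as thumbs_up / thumbs_down:

```python
# Sketch under stated assumptions: the FastAPI router and SESSIONS registry are hypothetical.
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from agent.core.session import Event  # same Event class the unit tests below construct

router = APIRouter()
SESSIONS: dict = {}  # hypothetical session_id -> Session registry

class FeedbackBody(BaseModel):
    rating: str                    # "up" | "down"
    message_id: str | None = None

@router.post("/api/feedback/{session_id}")
async def record_feedback(session_id: str, body: FeedbackBody) -> dict:
    session = SESSIONS.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="unknown session")
    # Lands in the trajectory as a `feedback` event the KPI rollup reads
    # (the data kwarg on Event is assumed).
    await session.send_event(
        Event(event_type="feedback", data={"rating": body.rating, "message_id": body.message_id})
    )
    return {"ok": True}
```
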
frontend/src/components/Chat/MessageBubble.tsx CHANGED
@@ -9,6 +9,7 @@ interface MessageBubbleProps {
9
  onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
10
  isProcessing?: boolean;
11
  isStreaming?: boolean;
 
12
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
13
  }
14
 
@@ -19,6 +20,7 @@ export default function MessageBubble({
19
  onEditAndRegenerate,
20
  isProcessing = false,
21
  isStreaming = false,
 
22
  approveTools,
23
  }: MessageBubbleProps) {
24
  if (message.role === 'user') {
@@ -38,6 +40,7 @@ export default function MessageBubble({
38
  <AssistantMessage
39
  message={message}
40
  isStreaming={isStreaming}
 
41
  approveTools={approveTools}
42
  />
43
  );
 
9
  onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
10
  isProcessing?: boolean;
11
  isStreaming?: boolean;
12
+ sessionId?: string | null;
13
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
14
  }
15
 
 
20
  onEditAndRegenerate,
21
  isProcessing = false,
22
  isStreaming = false,
23
+ sessionId,
24
  approveTools,
25
  }: MessageBubbleProps) {
26
  if (message.role === 'user') {
 
40
  <AssistantMessage
41
  message={message}
42
  isStreaming={isStreaming}
43
+ sessionId={sessionId}
44
  approveTools={approveTools}
45
  />
46
  );
frontend/src/components/Chat/MessageList.tsx CHANGED
@@ -8,6 +8,7 @@ import type { UIMessage } from 'ai';
8
  interface MessageListProps {
9
  messages: UIMessage[];
10
  isProcessing: boolean;
 
11
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
12
  onUndoLastTurn: () => void | Promise<void>;
13
  onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
@@ -57,7 +58,7 @@ function WelcomeGreeting() {
57
  );
58
  }
59
 
60
- export default function MessageList({ messages, isProcessing, approveTools, onUndoLastTurn, onEditAndRegenerate }: MessageListProps) {
61
  const scrollContainerRef = useRef<HTMLDivElement>(null);
62
  const stickToBottom = useRef(true);
63
 
@@ -139,6 +140,7 @@ export default function MessageList({ messages, isProcessing, approveTools, onUn
139
  onEditAndRegenerate={onEditAndRegenerate}
140
  isProcessing={isProcessing}
141
  isStreaming={isProcessing && msg.id === lastAssistantId}
 
142
  approveTools={approveTools}
143
  />
144
  ))
 
8
  interface MessageListProps {
9
  messages: UIMessage[];
10
  isProcessing: boolean;
11
+ sessionId?: string | null;
12
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
13
  onUndoLastTurn: () => void | Promise<void>;
14
  onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
 
58
  );
59
  }
60
 
61
+ export default function MessageList({ messages, isProcessing, sessionId, approveTools, onUndoLastTurn, onEditAndRegenerate }: MessageListProps) {
62
  const scrollContainerRef = useRef<HTMLDivElement>(null);
63
  const stickToBottom = useRef(true);
64
 
 
140
  onEditAndRegenerate={onEditAndRegenerate}
141
  isProcessing={isProcessing}
142
  isStreaming={isProcessing && msg.id === lastAssistantId}
143
+ sessionId={sessionId}
144
  approveTools={approveTools}
145
  />
146
  ))
frontend/src/components/SessionChat.tsx CHANGED
@@ -102,6 +102,7 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
102
  <MessageList
103
  messages={messages}
104
  isProcessing={busy}
 
105
  approveTools={approveTools}
106
  onUndoLastTurn={undoLastTurn}
107
  onEditAndRegenerate={editAndRegenerate}
 
102
  <MessageList
103
  messages={messages}
104
  isProcessing={busy}
105
+ sessionId={sessionId}
106
  approveTools={approveTools}
107
  onUndoLastTurn={undoLastTurn}
108
  onEditAndRegenerate={editAndRegenerate}
pyproject.toml CHANGED
@@ -26,6 +26,7 @@ dependencies = [
26
  "uvicorn[standard]>=0.32.0",
27
  "httpx>=0.27.0",
28
  "websockets>=13.0",
 
29
  ]
30
 
31
  [project.optional-dependencies]
 
26
  "uvicorn[standard]>=0.32.0",
27
  "httpx>=0.27.0",
28
  "websockets>=13.0",
29
+ "apscheduler>=3.10,<4",
30
  ]
31
 
32
  [project.optional-dependencies]
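`apscheduler` backs the in-process hourly KPI job that `tests/unit/test_kpis_scheduler.py` below exercises; the actual `backend/kpis_scheduler.py` is not shown in this excerpt. A rough sketch of the wiring those tests imply, using APScheduler 3.x (`AsyncIOScheduler` plus a cron trigger at minute 5):

```python
# Sketch only: names beyond start()/shutdown() mirror the tests, not the real module.
import datetime as dt
import os

from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger

_scheduler: AsyncIOScheduler | None = None

async def _run_last_completed_hour() -> None:
    hour = dt.datetime.now(dt.timezone.utc).replace(
        minute=0, second=0, microsecond=0
    ) - dt.timedelta(hours=1)
    # Here the real module would call scripts/build_kpis.run_for_hour(...) for `hour`.
    print(f"would roll up {hour.isoformat()}")

def start() -> None:
    global _scheduler
    if os.environ.get("ML_INTERN_KPIS_DISABLED") or _scheduler is not None:
        return
    _scheduler = AsyncIOScheduler(timezone="UTC")
    _scheduler.add_job(_run_last_completed_hour, CronTrigger(minute=5))
    _scheduler.start()

async def shutdown() -> None:
    global _scheduler
    if _scheduler is not None:
        _scheduler.shutdown(wait=False)
        _scheduler = None
```
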
scripts/build_kpis.py ADDED
@@ -0,0 +1,517 @@
1
+ #!/usr/bin/env python3
2
+ """Hourly KPI rollup for the session-trajectory dataset.
3
+
4
+ ================================================================================
5
+ Data flow
6
+ ================================================================================
7
+
8
+ ┌─────────────────────┐   heartbeat    ┌─────────────────────────────────┐
9
+ │ agent (CLI/web)     │ ─────────────▶ │ ml-intern-sessions (dataset)    │
10
+ │ Session.send_event  │                │ sessions/YYYY-MM-DD/<id>.jsonl  │
11
+ └─────────────────────┘                └────────────────┬────────────────┘
12
+                                                         │ cron @:05 each hour
13
+                                                         ▼
14
+                                        ┌─────────────────────────────────┐
15
+                                        │ scripts/build_kpis.py           │
16
+                                        │ (GitHub Actions)                │
17
+                                        └────────────────┬────────────────┘
18
+                                                         │ upload CSV
19
+                                                         ▼
20
+                                        ┌─────────────────────────────────┐
21
+                                        │ ml-intern-kpis (dataset)        │
22
+                                        │ hourly/YYYY-MM-DD/HH.csv        │
23
+                                        └─────────────────────────────────┘
24
+
25
+ Each hourly run reads today's + yesterday's session folders (to cover sessions
26
+ that crossed midnight), filters events into the target hour window
27
+ ``[hour, hour+1h)``, computes aggregates, and writes one CSV at
28
+ ``hourly/<date>/<HH>.csv`` in the target dataset. Uploads are idempotent —
29
+ re-running the same hour overwrites.
30
+
31
+ ================================================================================
32
+ Metrics (one row per hour)
33
+ ================================================================================
34
+
35
+ sessions — distinct session_ids with ≥1 event in window
36
+ users — distinct user ids (when present on session rows)
37
+ turns — sum of user-message counts across active sessions
38
+ llm_calls — count of llm_call events
39
+ tokens_prompt / _completion / _cache_read / _cache_creation
40
+ cost_usd — sum of llm_call.cost_usd
41
+ cache_hit_ratio — cache_read / (cache_read + prompt)
42
+ tool_success_rate — tool_output success=True / total tool_output
43
+ failure_rate — sessions that ended with an `error` event / sessions
44
+ regenerate_rate — sessions with any `undo_complete` event / sessions
45
+ time_to_first_action_s_p50 / _p95 — from session_start to first tool_call
46
+ thumbs_up / thumbs_down
47
+ hf_jobs_submitted / _succeeded
48
+ gpu_hours_by_flavor_json — JSON-serialised {flavor: gpu-hours}
49
+
50
+ ================================================================================
51
+ Usage
52
+ ================================================================================
53
+
54
+ # Run for the most recently completed hour (default — the cron path):
55
+ python scripts/build_kpis.py
56
+
57
+ # Backfill last 24 hours:
58
+ python scripts/build_kpis.py --hours 24
59
+
60
+ # Explicit hour (UTC):
61
+ python scripts/build_kpis.py --datetime 2026-04-24T14
62
+
63
+ Env:
64
+ HF_TOKEN (or HF_KPI_WRITE_TOKEN) — write access to the target dataset.
65
+
66
+ ================================================================================
67
+ Deploy
68
+ ================================================================================
69
+
70
+ See ``.github/workflows/build-kpis.yml`` — runs every hour at :05. To provision:
71
+
72
+ 1. Create the target dataset (once):
73
+ huggingface-cli repo create hf-agent-kpis --type dataset
74
+ 2. Put ``HF_KPI_WRITE_TOKEN`` (or ``HF_TOKEN``) into repo Actions secrets.
75
+ 3. Merge this file; the first scheduled run fires within the hour.
76
+ """
77
+
78
+ from __future__ import annotations
79
+
80
+ import argparse
81
+ import io
82
+ import json
83
+ import logging
84
+ import os
85
+ import sys
86
+ import tempfile
87
+ from collections import defaultdict
88
+ from datetime import date, datetime, timedelta, timezone
89
+ from pathlib import Path
90
+ from typing import Any, Iterable
91
+
92
+ logger = logging.getLogger("build_kpis")
93
+
94
+ # GPU counts per hf_jobs flavor string. Used only to compute gpu-hours (not
95
+ # dollars): wall_time_s * flavor_gpu_count / 3600. Unknown flavors count as 0.
96
+ _FLAVOR_GPU_COUNT = {
97
+ "cpu-basic": 0, "cpu-upgrade": 0,
98
+ "t4-small": 1, "t4-medium": 1,
99
+ "l4x1": 1, "l4x4": 4,
100
+ "l40sx1": 1, "l40sx4": 4, "l40sx8": 8,
101
+ "a10g-small": 1, "a10g-large": 1, "a10g-largex2": 2, "a10g-largex4": 4,
102
+ "a100-large": 1, "a100x2": 2, "a100x4": 4, "a100x8": 8,
103
+ "h100": 1, "h100x8": 8,
104
+ }
105
+
106
+
107
+ def _percentile(values: list[float], p: float) -> float:
108
+ if not values:
109
+ return 0.0
110
+ values = sorted(values)
111
+ k = (len(values) - 1) * p
112
+ f = int(k)
113
+ c = min(f + 1, len(values) - 1)
114
+ if f == c:
115
+ return float(values[f])
116
+ return float(values[f] + (values[c] - values[f]) * (k - f))
117
+
118
+
119
+ def _parse_ts(s: Any) -> datetime | None:
120
+ if not s or not isinstance(s, str):
121
+ return None
122
+ try:
123
+ dt = datetime.fromisoformat(s)
124
+ except Exception:
125
+ return None
126
+ # Normalise to aware UTC so comparisons work against window bounds.
127
+ if dt.tzinfo is None:
128
+ dt = dt.replace(tzinfo=timezone.utc)
129
+ return dt
130
+
131
+
132
+ def _iter_session_files(api, repo_id: str, day: date, token: str) -> Iterable[str]:
133
+ """Yield repo-relative paths for all sessions under ``sessions/YYYY-MM-DD/``."""
134
+ prefix = f"sessions/{day.isoformat()}/"
135
+ try:
136
+ files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", token=token)
137
+ except Exception as e:
138
+ logger.warning("list_repo_files(%s) failed: %s", repo_id, e)
139
+ return []
140
+ return [f for f in files if f.startswith(prefix) and f.endswith(".jsonl")]
141
+
142
+
143
+ def _download_session(repo_id: str, path: str, token: str) -> dict | None:
144
+ """Fetch one session JSONL and decode its single row.
145
+
146
+ ``hf_hub_download`` caches; second run within the same process / runner
147
+ directory is near-free.
148
+ """
149
+ from huggingface_hub import hf_hub_download
150
+ try:
151
+ local = hf_hub_download(
152
+ repo_id=repo_id, filename=path, repo_type="dataset", token=token,
153
+ )
154
+ except Exception as e:
155
+ logger.warning("hf_hub_download(%s) failed: %s", path, e)
156
+ return None
157
+ try:
158
+ with open(local, "r") as f:
159
+ line = f.readline().strip()
160
+ if not line:
161
+ return None
162
+ row = json.loads(line)
163
+ # Session uploader stores messages/events as JSON strings — unpack.
164
+ for key in ("messages", "events", "tools"):
165
+ v = row.get(key)
166
+ if isinstance(v, str):
167
+ try:
168
+ row[key] = json.loads(v)
169
+ except Exception:
170
+ row[key] = []
171
+ return row
172
+ except Exception as e:
173
+ logger.warning("parse(%s) failed: %s", path, e)
174
+ return None
175
+
176
+
177
+ def _filter_session_to_window(
178
+ session: dict, start: datetime, end: datetime,
179
+ ) -> dict | None:
180
+ """Return a copy of ``session`` whose events are only those in ``[start, end)``.
181
+
182
+ ``None`` if no event falls in the window — the caller drops the session
183
+ from this hour's aggregate.
184
+ """
185
+ events = session.get("events") or []
186
+ in_window = []
187
+ for ev in events:
188
+ ts = _parse_ts(ev.get("timestamp"))
189
+ if ts is None:
190
+ continue
191
+ if start <= ts < end:
192
+ in_window.append(ev)
193
+ if not in_window:
194
+ return None
195
+ return {**session, "events": in_window}
196
+
197
+
198
+ def _session_metrics(session: dict) -> dict:
199
+ """Reduce a single session trajectory to its KPI contributions.
200
+
201
+ Assumes ``events`` are already filtered to the target window by the caller.
202
+ """
203
+ # Pre-seed every numeric key so downstream aggregation can sum without
204
+ # having to special-case empty sessions.
205
+ out: dict = {
206
+ "sessions": 0, "turns": 0, "llm_calls": 0,
207
+ "tokens_prompt": 0, "tokens_completion": 0,
208
+ "tokens_cache_read": 0, "tokens_cache_creation": 0,
209
+ "cost_usd": 0.0,
210
+ "tool_calls_total": 0, "tool_calls_success": 0,
211
+ "failures": 0, "regenerate_sessions": 0,
212
+ "thumbs_up": 0, "thumbs_down": 0,
213
+ "hf_jobs_submitted": 0, "hf_jobs_succeeded": 0,
214
+ "first_tool_s": -1,
215
+ }
216
+ events = session.get("events") or []
217
+ messages = session.get("messages") or []
218
+
219
+ turn_count = sum(1 for m in messages if m.get("role") == "user")
220
+ out["turns"] = turn_count
221
+ out["sessions"] = 1
222
+
223
+ tool_success = 0
224
+ tool_total = 0
225
+ had_error = False
226
+ had_undo = False
227
+ first_tool_ts = None
228
+ session_start = session.get("session_start_time")
229
+ gpu_hours_by_flavor: dict[str, float] = defaultdict(float)
230
+ jobs_submitted = 0
231
+ jobs_succeeded = 0
232
+ thumbs_up = 0
233
+ thumbs_down = 0
234
+
235
+ start_dt = _parse_ts(session_start)
236
+
237
+ for ev in events:
238
+ et = ev.get("event_type")
239
+ data = ev.get("data") or {}
240
+ ts = _parse_ts(ev.get("timestamp"))
241
+
242
+ if et == "llm_call":
243
+ out["llm_calls"] += 1
244
+ out["tokens_prompt"] += int(data.get("prompt_tokens") or 0)
245
+ out["tokens_completion"] += int(data.get("completion_tokens") or 0)
246
+ out["tokens_cache_read"] += int(data.get("cache_read_tokens") or 0)
247
+ out["tokens_cache_creation"] += int(data.get("cache_creation_tokens") or 0)
248
+ out["cost_usd"] += float(data.get("cost_usd") or 0.0)
249
+
250
+ elif et == "tool_output":
251
+ tool_total += 1
252
+ if data.get("success"):
253
+ tool_success += 1
254
+ if first_tool_ts is None and ts is not None and start_dt is not None:
255
+ first_tool_ts = (ts - start_dt).total_seconds()
256
+
257
+ elif et == "tool_call":
258
+ if first_tool_ts is None and ts is not None and start_dt is not None:
259
+ first_tool_ts = (ts - start_dt).total_seconds()
260
+
261
+ elif et == "error":
262
+ had_error = True
263
+
264
+ elif et == "undo_complete":
265
+ had_undo = True
266
+
267
+ elif et == "feedback":
268
+ rating = data.get("rating")
269
+ if rating == "up":
270
+ thumbs_up += 1
271
+ elif rating == "down":
272
+ thumbs_down += 1
273
+
274
+ elif et == "hf_job_submit":
275
+ jobs_submitted += 1
276
+
277
+ elif et == "hf_job_complete":
278
+ flavor = data.get("flavor") or "unknown"
279
+ status = (data.get("final_status") or "").lower()
280
+ wall = float(data.get("wall_time_s") or 0.0)
281
+ gpus = _FLAVOR_GPU_COUNT.get(flavor, 0)
282
+ gpu_hours_by_flavor[flavor] += wall * gpus / 3600.0
283
+ if status in ("completed", "succeeded", "success"):
284
+ jobs_succeeded += 1
285
+
286
+ out["tool_calls_total"] = tool_total
287
+ out["tool_calls_success"] = tool_success
288
+ out["failures"] = 1 if had_error else 0
289
+ out["regenerate_sessions"] = 1 if had_undo else 0
290
+ out["thumbs_up"] = thumbs_up
291
+ out["thumbs_down"] = thumbs_down
292
+ out["hf_jobs_submitted"] = jobs_submitted
293
+ out["hf_jobs_succeeded"] = jobs_succeeded
294
+ out["first_tool_s"] = first_tool_ts if first_tool_ts is not None else -1
295
+ out["_gpu_hours_by_flavor"] = dict(gpu_hours_by_flavor)
296
+ out["_user"] = session.get("user_id") or session.get("session_id")
297
+ return dict(out)
298
+
299
+
300
+ def _aggregate(per_session: list[dict]) -> dict:
301
+ """Collapse a bucket's worth of session rollups into the final KPI row."""
302
+ ttfa_values = [s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0]
303
+ gpu_hours: dict[str, float] = defaultdict(float)
304
+ for s in per_session:
305
+ for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
306
+ gpu_hours[f] += h
307
+
308
+ total_sessions = sum(s["sessions"] for s in per_session)
309
+ total_turns = sum(s["turns"] for s in per_session)
310
+ tokens_prompt = sum(s["tokens_prompt"] for s in per_session)
311
+ tokens_cache_read = sum(s["tokens_cache_read"] for s in per_session)
312
+ tool_total = sum(s["tool_calls_total"] for s in per_session)
313
+ tool_success = sum(s["tool_calls_success"] for s in per_session)
314
+
315
+ unique_users = {s.get("_user") for s in per_session if s.get("_user")}
316
+
317
+ return {
318
+ "sessions": total_sessions,
319
+ "users": len(unique_users),
320
+ "turns": total_turns,
321
+ "llm_calls": int(sum(s["llm_calls"] for s in per_session)),
322
+ "tokens_prompt": int(tokens_prompt),
323
+ "tokens_completion": int(sum(s["tokens_completion"] for s in per_session)),
324
+ "tokens_cache_read": int(tokens_cache_read),
325
+ "tokens_cache_creation": int(sum(s["tokens_cache_creation"] for s in per_session)),
326
+ "cost_usd": round(sum(s["cost_usd"] for s in per_session), 4),
327
+ "cache_hit_ratio": round(
328
+ tokens_cache_read / (tokens_cache_read + tokens_prompt), 4
329
+ ) if (tokens_cache_read + tokens_prompt) > 0 else 0.0,
330
+ "tool_success_rate": round(tool_success / tool_total, 4) if tool_total > 0 else 0.0,
331
+ "failure_rate": round(
332
+ sum(s["failures"] for s in per_session) / total_sessions, 4
333
+ ) if total_sessions > 0 else 0.0,
334
+ "regenerate_rate": round(
335
+ sum(s["regenerate_sessions"] for s in per_session) / total_sessions, 4
336
+ ) if total_sessions > 0 else 0.0,
337
+ "time_to_first_action_s_p50": round(_percentile(ttfa_values, 0.5), 2),
338
+ "time_to_first_action_s_p95": round(_percentile(ttfa_values, 0.95), 2),
339
+ "thumbs_up": int(sum(s["thumbs_up"] for s in per_session)),
340
+ "thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
341
+ "hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
342
+ "hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
343
+ "gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
344
+ }
345
+
346
+
347
+ # Back-compat alias: older tests call _aggregate_day.
348
+ _aggregate_day = _aggregate
349
+
350
+
351
+ def _csv_cell(v: Any) -> str:
352
+ s = str(v)
353
+ if "," in s or '"' in s or "\n" in s:
354
+ return '"' + s.replace('"', '""') + '"'
355
+ return s
356
+
357
+
358
+ def _write_csv(
359
+ api, row: dict, bucket_key: str, path_in_repo: str, target_repo: str, token: str,
360
+ ) -> None:
361
+ """Render ``row`` to CSV with a leading ``bucket`` column and upload.
362
+
363
+ ``bucket_key`` is the hour string (ISO ``YYYY-MM-DDTHH``) or date string;
364
+ written as the ``bucket`` column so downstream consumers can union all
365
+ CSVs without date-parsing paths. ``api`` is the caller's ``HfApi``
366
+ instance — reused so we don't spin up a fresh one per CSV.
367
+ """
368
+ columns = list(row.keys())
369
+ buf = io.StringIO()
370
+ buf.write(",".join(["bucket", *columns]) + "\n")
371
+ buf.write(",".join([bucket_key, *[_csv_cell(row[c]) for c in columns]]) + "\n")
372
+
373
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as tmp:
374
+ tmp.write(buf.getvalue())
375
+ tmp_path = tmp.name
376
+
377
+ try:
378
+ api.create_repo(
379
+ repo_id=target_repo, repo_type="dataset", exist_ok=True, token=token,
380
+ )
381
+ api.upload_file(
382
+ path_or_fileobj=tmp_path,
383
+ path_in_repo=path_in_repo,
384
+ repo_id=target_repo,
385
+ repo_type="dataset",
386
+ token=token,
387
+ commit_message=f"KPIs for {bucket_key}",
388
+ )
389
+ finally:
390
+ try:
391
+ os.unlink(tmp_path)
392
+ except Exception:
393
+ pass
394
+
395
+
396
+ def run_for_hour(
397
+ api, source_repo: str, target_repo: str, hour_dt: datetime, token: str,
398
+ ) -> dict:
399
+ """Roll up one UTC hour [hour_dt, hour_dt+1h).
400
+
401
+ Reads today's + yesterday's session folders so sessions that crossed
402
+ midnight land in the right hourly bucket.
403
+ """
404
+ if hour_dt.tzinfo is None:
405
+ hour_dt = hour_dt.replace(tzinfo=timezone.utc)
406
+ window_start = hour_dt.replace(minute=0, second=0, microsecond=0)
407
+ window_end = window_start + timedelta(hours=1)
408
+
409
+ # Sessions partition by session_start_time date. A session that started
410
+ # at 23:50 yesterday can still emit events in today's first hours, so we
411
+ # look at both folders.
412
+ candidate_dates = {window_start.date(), (window_start - timedelta(days=1)).date()}
413
+
414
+ per_session: list[dict] = []
415
+ for d in sorted(candidate_dates):
416
+ for path in _iter_session_files(api, source_repo, d, token):
417
+ sess = _download_session(source_repo, path, token)
418
+ if not sess:
419
+ continue
420
+ windowed = _filter_session_to_window(sess, window_start, window_end)
421
+ if windowed is None:
422
+ continue
423
+ per_session.append(_session_metrics(windowed))
424
+
425
+ if not per_session:
426
+ logger.info("No sessions in window %s — skipping", window_start.isoformat())
427
+ return {}
428
+
429
+ row = _aggregate(per_session)
430
+ bucket_key = window_start.strftime("%Y-%m-%dT%H")
431
+ path_in_repo = f"hourly/{window_start.strftime('%Y-%m-%d')}/{window_start.strftime('%H')}.csv"
432
+ _write_csv(api, row, bucket_key, path_in_repo, target_repo, token)
433
+ logger.info("Wrote KPIs for %s (%d sessions): %s",
434
+ bucket_key, len(per_session), row)
435
+ return row
436
+
437
+
438
+ # Back-compat for daily backfills — unchanged behaviour.
439
+ def run_for_day(api, source_repo: str, target_repo: str, day: date, token: str) -> dict:
440
+ paths = _iter_session_files(api, source_repo, day, token)
441
+ per_session: list[dict] = []
442
+ for path in paths:
443
+ sess = _download_session(source_repo, path, token)
444
+ if not sess:
445
+ continue
446
+ per_session.append(_session_metrics(sess))
447
+ if not per_session:
448
+ logger.info("No sessions found for %s — skipping", day)
449
+ return {}
450
+ row = _aggregate(per_session)
451
+ path_in_repo = f"daily/{day.isoformat()}.csv"
452
+ _write_csv(api, row, day.isoformat(), path_in_repo, target_repo, token)
453
+ return row
454
+
455
+
456
+ def _parse_hour_arg(s: str) -> datetime:
457
+ """Accept ``YYYY-MM-DDTHH`` or full ISO — always pinned to the start of the hour, UTC."""
458
+ dt = datetime.fromisoformat(s)
459
+ if dt.tzinfo is None:
460
+ dt = dt.replace(tzinfo=timezone.utc)
461
+ return dt.replace(minute=0, second=0, microsecond=0)
462
+
463
+
464
+ def main(argv: list[str] | None = None) -> int:
465
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
466
+ ap = argparse.ArgumentParser()
467
+ ap.add_argument("--source", default="smolagents/ml-intern-sessions")
468
+ ap.add_argument("--target", default="smolagents/ml-intern-kpis")
469
+ ap.add_argument(
470
+ "--hours", type=int, default=1,
471
+ help="Number of trailing hours to roll up (default: 1 = last completed hour).",
472
+ )
473
+ ap.add_argument(
474
+ "--datetime", type=str, default=None,
475
+ help="Single hour, ISO ``YYYY-MM-DDTHH`` (UTC); overrides --hours.",
476
+ )
477
+ ap.add_argument(
478
+ "--daily-backfill", type=str, default=None,
479
+ help="Escape hatch: aggregate a whole day at once (YYYY-MM-DD). "
480
+ "Writes to daily/<date>.csv. Use for historical backfill only.",
481
+ )
482
+ args = ap.parse_args(argv)
483
+
484
+ token = (
485
+ os.environ.get("HF_KPI_WRITE_TOKEN")
486
+ or os.environ.get("HF_SESSION_UPLOAD_TOKEN")
487
+ or os.environ.get("HF_TOKEN")
488
+ or os.environ.get("HF_ADMIN_TOKEN")
489
+ )
490
+ if not token:
491
+ logger.error(
492
+ "No HF token found. Set one of: HF_KPI_WRITE_TOKEN, "
493
+ "HF_SESSION_UPLOAD_TOKEN, HF_TOKEN, HF_ADMIN_TOKEN."
494
+ )
495
+ return 1
496
+
497
+ from huggingface_hub import HfApi
498
+ api = HfApi()
499
+
500
+ if args.daily_backfill:
501
+ run_for_day(api, args.source, args.target, date.fromisoformat(args.daily_backfill), token)
502
+ return 0
503
+
504
+ if args.datetime:
505
+ target_hours = [_parse_hour_arg(args.datetime)]
506
+ else:
507
+ now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0)
508
+ # Roll up *completed* hours: start from the hour before ``now``.
509
+ target_hours = [now - timedelta(hours=i) for i in range(1, args.hours + 1)]
510
+
511
+ for hour in target_hours:
512
+ run_for_hour(api, args.source, args.target, hour, token)
513
+ return 0
514
+
515
+
516
+ if __name__ == "__main__":
517
+ sys.exit(main())
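The leading `bucket` column is what makes the hourly CSVs trivially unionable. A small consumer sketch (pandas is an assumption here, not a project dependency) that stitches every hourly file from the target dataset into one frame:

```python
# Consumer sketch; the repo id matches this script's --target default.
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download

REPO = "smolagents/ml-intern-kpis"

api = HfApi()
csv_paths = [
    p for p in api.list_repo_files(REPO, repo_type="dataset")
    if p.startswith("hourly/") and p.endswith(".csv")
]
frames = [
    pd.read_csv(hf_hub_download(REPO, p, repo_type="dataset"))
    for p in csv_paths
]
kpis = pd.concat(frames, ignore_index=True).sort_values("bucket")
print(kpis[["bucket", "sessions", "llm_calls", "cost_usd", "cache_hit_ratio"]].tail())
```
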
scripts/build_sft.py ADDED
@@ -0,0 +1,204 @@
1
+ #!/usr/bin/env python3
2
+ """Export session trajectories as raw multi-turn tool-calling SFT data.
3
+
4
+ Reads the source sessions dataset (JSONL, one file per session at
5
+ ``sessions/YYYY-MM-DD/<session_id>.jsonl``) and writes a re-shaped row to a
6
+ target dataset at ``sft/YYYY-MM-DD/<session_id>.jsonl``.
7
+
8
+ **No filtering, no cleaning, no dedup.** Raw passthrough of messages + tools,
9
+ with session-level metadata and derived tags (see ``agent/sft/tagger.py``)
10
+ attached for downstream slicing.
11
+
12
+ Output row schema::
13
+
14
+ {
15
+ "session_id": "...",
16
+ "model": "claude-opus-4-6",
17
+ "timestamp": "2026-04-24T...",
18
+ "tags": ["tool:hf_jobs", "gpu:a100", "hf_job:succeeded", ...],
19
+ "messages": [...], # OpenAI / TRL SFTTrainer format
20
+ "tools": [...] # OpenAI tool schemas the session had access to
21
+ }
22
+
23
+ Usage::
24
+
25
+ python scripts/build_sft.py \\
26
+ --source smolagents/ml-intern-sessions \\
27
+ --target smolagents/ml-intern-sft \\
28
+ --days 7
29
+
30
+ Env:
31
+ HF_TOKEN (or HF_SFT_WRITE_TOKEN) — write access to target dataset.
32
+ """
33
+
34
+ from __future__ import annotations
35
+
36
+ import argparse
37
+ import json
38
+ import logging
39
+ import os
40
+ import sys
41
+ import tempfile
42
+ from datetime import date, datetime, timedelta, timezone
43
+ from typing import Iterable
44
+
45
+ # Make ``agent`` importable when this script is run outside the project venv.
46
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
47
+
48
+ from agent.sft.tagger import tag_session # noqa: E402
49
+
50
+ logger = logging.getLogger("build_sft")
51
+
52
+
53
+ def _iter_session_files(api, repo_id: str, day: date, token: str) -> Iterable[str]:
54
+ prefix = f"sessions/{day.isoformat()}/"
55
+ try:
56
+ files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", token=token)
57
+ except Exception as e:
58
+ logger.warning("list_repo_files(%s) failed: %s", repo_id, e)
59
+ return []
60
+ return [f for f in files if f.startswith(prefix) and f.endswith(".jsonl")]
61
+
62
+
63
+ def _download_and_parse(repo_id: str, path: str, token: str) -> dict | None:
64
+ from huggingface_hub import hf_hub_download
65
+ try:
66
+ local = hf_hub_download(
67
+ repo_id=repo_id, filename=path, repo_type="dataset", token=token,
68
+ )
69
+ except Exception as e:
70
+ logger.warning("hf_hub_download(%s) failed: %s", path, e)
71
+ return None
72
+ try:
73
+ with open(local, "r") as f:
74
+ line = f.readline().strip()
75
+ if not line:
76
+ return None
77
+ row = json.loads(line)
78
+ # Session uploader stores messages/events/tools as JSON strings.
79
+ for key in ("messages", "events", "tools"):
80
+ v = row.get(key)
81
+ if isinstance(v, str):
82
+ try:
83
+ row[key] = json.loads(v)
84
+ except Exception:
85
+ row[key] = []
86
+ return row
87
+ except Exception as e:
88
+ logger.warning("parse(%s) failed: %s", path, e)
89
+ return None
90
+
91
+
92
+ def _reshape_to_sft(row: dict) -> dict:
93
+ """Raw passthrough: reshape one session row into SFT format + tags.
94
+
95
+ Trajectories predating the ``tools`` addition to ``get_trajectory`` will
96
+ have an empty tools list — still valid, just less useful downstream.
97
+ """
98
+ trajectory = {
99
+ "events": row.get("events") or [],
100
+ "messages": row.get("messages") or [],
101
+ "model_name": row.get("model_name"),
102
+ }
103
+ return {
104
+ "session_id": row.get("session_id"),
105
+ "model": row.get("model_name"),
106
+ "timestamp": row.get("session_start_time"),
107
+ "tags": tag_session(trajectory),
108
+ "messages": row.get("messages") or [],
109
+ "tools": row.get("tools") or [],
110
+ }
111
+
112
+
113
+ def _upload_row(api, row: dict, day: date, target_repo: str, token: str) -> None:
114
+ session_id = row["session_id"]
115
+ path_in_repo = f"sft/{day.isoformat()}/{session_id}.jsonl"
116
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as tmp:
117
+ json.dump(row, tmp, ensure_ascii=False)
118
+ tmp_path = tmp.name
119
+ try:
120
+ api.create_repo(
121
+ repo_id=target_repo, repo_type="dataset", exist_ok=True, token=token,
122
+ )
123
+ api.upload_file(
124
+ path_or_fileobj=tmp_path,
125
+ path_in_repo=path_in_repo,
126
+ repo_id=target_repo,
127
+ repo_type="dataset",
128
+ token=token,
129
+ commit_message=f"Add SFT row {session_id}",
130
+ )
131
+ finally:
132
+ try:
133
+ os.unlink(tmp_path)
134
+ except Exception:
135
+ pass
136
+
137
+
138
+ def run_for_day(
139
+ api, source_repo: str, target_repo: str, day: date, token: str,
140
+ ) -> int:
141
+ paths = _iter_session_files(api, source_repo, day, token)
142
+ n = 0
143
+ for path in paths:
144
+ sess = _download_and_parse(source_repo, path, token)
145
+ if not sess:
146
+ continue
147
+ sft_row = _reshape_to_sft(sess)
148
+ if not sft_row.get("session_id"):
149
+ continue
150
+ try:
151
+ _upload_row(api, sft_row, day, target_repo, token)
152
+ n += 1
153
+ except Exception as e:
154
+ logger.warning("upload failed for %s: %s", sft_row["session_id"], e)
155
+ logger.info("Exported %d sessions for %s", n, day)
156
+ return n
157
+
158
+
159
+ def main(argv: list[str] | None = None) -> int:
160
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
161
+ ap = argparse.ArgumentParser()
162
+ ap.add_argument("--source", default="smolagents/ml-intern-sessions")
163
+ ap.add_argument("--target", default="smolagents/ml-intern-sft")
164
+ ap.add_argument(
165
+ "--days", type=int, default=1,
166
+ help="Number of trailing days to export (default: 1 = yesterday).",
167
+ )
168
+ ap.add_argument(
169
+ "--date", type=str, default=None,
170
+ help="Single YYYY-MM-DD to export; overrides --days.",
171
+ )
172
+ args = ap.parse_args(argv)
173
+
174
+ token = (
175
+ os.environ.get("HF_SFT_WRITE_TOKEN")
176
+ or os.environ.get("HF_SESSION_UPLOAD_TOKEN")
177
+ or os.environ.get("HF_TOKEN")
178
+ or os.environ.get("HF_ADMIN_TOKEN")
179
+ )
180
+ if not token:
181
+ logger.error(
182
+ "No HF token found. Set one of: HF_SFT_WRITE_TOKEN, "
183
+ "HF_SESSION_UPLOAD_TOKEN, HF_TOKEN, HF_ADMIN_TOKEN."
184
+ )
185
+ return 1
186
+
187
+ from huggingface_hub import HfApi
188
+ api = HfApi()
189
+
190
+ if args.date:
191
+ target_days = [date.fromisoformat(args.date)]
192
+ else:
193
+ today = datetime.now(timezone.utc).date()
194
+ target_days = [today - timedelta(days=i) for i in range(1, args.days + 1)]
195
+
196
+ total = 0
197
+ for day in target_days:
198
+ total += run_for_day(api, args.source, args.target, day, token)
199
+ logger.info("Total exported: %d sessions", total)
200
+ return 0
201
+
202
+
203
+ if __name__ == "__main__":
204
+ sys.exit(main())
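Because the export is a raw passthrough, slicing happens downstream via the attached tags. One way to consume the rows, assuming a recent `datasets` release (the tag names come from the tagger tests below):

```python
# Consumer sketch: keep only completed sessions whose hf_jobs run actually succeeded.
from datasets import load_dataset

ds = load_dataset(
    "smolagents/ml-intern-sft",
    data_files="sft/*/*.jsonl",
    split="train",
)
keep = ds.filter(
    lambda row: "hf_job:succeeded" in row["tags"] and "outcome:completed" in row["tags"]
)
print(f"{len(keep)} / {len(ds)} sessions kept")
# keep[i]["messages"] / keep[i]["tools"] are already in the multi-turn tool-calling
# format that TRL's SFTTrainer accepts.
```
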
tests/unit/test_build_kpis.py ADDED
@@ -0,0 +1,164 @@
1
+ """Unit tests for the KPI rollup math.
2
+
3
+ We exercise the pure functions (``_session_metrics`` and ``_aggregate_day``)
4
+ on hand-crafted session trajectories — no network, no HF Hub.
5
+ """
6
+
7
+ import importlib.util
8
+ import sys
9
+ from pathlib import Path
10
+
11
+
12
+ def _load():
13
+ """Load ``scripts/build_kpis.py`` without treating ``scripts`` as a package."""
14
+ path = Path(__file__).parent.parent.parent / "scripts" / "build_kpis.py"
15
+ spec = importlib.util.spec_from_file_location("build_kpis", path)
16
+ mod = importlib.util.module_from_spec(spec)
17
+ sys.modules["build_kpis"] = mod
18
+ spec.loader.exec_module(mod) # type: ignore
19
+ return mod
20
+
21
+
22
+ def _ev(event_type, data=None, ts="2026-04-24T10:00:00"):
23
+ return {"timestamp": ts, "event_type": event_type, "data": data or {}}
24
+
25
+
26
+ def _session(events, user_id="u1", start="2026-04-24T09:59:00"):
27
+ return {
28
+ "session_id": "sess-" + user_id,
29
+ "session_start_time": start,
30
+ "session_end_time": "2026-04-24T10:05:00",
31
+ "model_name": "claude-opus-4-6",
32
+ "messages": [{"role": "user", "content": "hi"}],
33
+ "events": events,
34
+ "user_id": user_id,
35
+ }
36
+
37
+
38
+ def test_llm_call_accumulates_tokens_and_cost():
39
+ mod = _load()
40
+ events = [
41
+ _ev("llm_call", {
42
+ "prompt_tokens": 100, "completion_tokens": 50,
43
+ "cache_read_tokens": 40, "cache_creation_tokens": 10,
44
+ "cost_usd": 0.01,
45
+ }),
46
+ _ev("llm_call", {
47
+ "prompt_tokens": 200, "completion_tokens": 100,
48
+ "cache_read_tokens": 80, "cost_usd": 0.02,
49
+ }),
50
+ ]
51
+ m = mod._session_metrics(_session(events))
52
+ assert m["llm_calls"] == 2
53
+ assert m["tokens_prompt"] == 300
54
+ assert m["tokens_completion"] == 150
55
+ assert m["tokens_cache_read"] == 120
56
+ assert m["tokens_cache_creation"] == 10
57
+ assert abs(m["cost_usd"] - 0.03) < 1e-9
58
+
59
+
60
+ def test_tool_success_rate_and_first_action():
61
+ mod = _load()
62
+ events = [
63
+ _ev("tool_call", {"tool": "bash"}, ts="2026-04-24T10:00:05"),
64
+ _ev("tool_output", {"success": True}),
65
+ _ev("tool_output", {"success": False}),
66
+ ]
67
+ m = mod._session_metrics(_session(events))
68
+ assert m["tool_calls_total"] == 2
69
+ assert m["tool_calls_success"] == 1
70
+ # 65s from start to first action
71
+ assert m["first_tool_s"] == 65
72
+
73
+
74
+ def test_hf_job_gpu_hours():
75
+ mod = _load()
76
+ events = [
77
+ _ev("hf_job_submit", {"flavor": "a100-large", "job_id": "j1"}),
78
+ _ev("hf_job_complete", {
79
+ "flavor": "a100-large",
80
+ "final_status": "COMPLETED",
81
+ "wall_time_s": 3600,
82
+ }),
83
+ ]
84
+ m = mod._session_metrics(_session(events))
85
+ assert m["hf_jobs_submitted"] == 1
86
+ assert m["hf_jobs_succeeded"] == 1
87
+ # a100-large = 1 gpu * 1 hour = 1 gpu-hour
88
+ assert abs(m["_gpu_hours_by_flavor"]["a100-large"] - 1.0) < 1e-6
89
+
90
+
91
+ def test_feedback_counts():
92
+ mod = _load()
93
+ events = [
94
+ _ev("feedback", {"rating": "up"}),
95
+ _ev("feedback", {"rating": "up"}),
96
+ _ev("feedback", {"rating": "down"}),
97
+ ]
98
+ m = mod._session_metrics(_session(events))
99
+ assert m["thumbs_up"] == 2
100
+ assert m["thumbs_down"] == 1
101
+
102
+
103
+ def test_aggregate_day_cache_hit_and_users():
104
+ mod = _load()
105
+ s1 = mod._session_metrics(_session(
106
+ [_ev("llm_call", {"prompt_tokens": 100, "cache_read_tokens": 400, "cost_usd": 0.5})],
107
+ user_id="u1",
108
+ ))
109
+ s2 = mod._session_metrics(_session(
110
+ [_ev("llm_call", {"prompt_tokens": 200, "cache_read_tokens": 100, "cost_usd": 1.0})],
111
+ user_id="u2",
112
+ ))
113
+ row = mod._aggregate_day([s1, s2])
114
+ assert row["sessions"] == 2
115
+ assert row["users"] == 2
116
+ assert row["tokens_prompt"] == 300
117
+ assert row["tokens_cache_read"] == 500
118
+ # 500 / (500 + 300) = 0.625
119
+ assert abs(row["cache_hit_ratio"] - 0.625) < 1e-9
120
+ assert abs(row["cost_usd"] - 1.5) < 1e-9
121
+
122
+
123
+ def test_failure_and_regenerate_rates():
124
+ mod = _load()
125
+ s1 = mod._session_metrics(_session([_ev("error", {"error": "boom"})], user_id="a"))
126
+ s2 = mod._session_metrics(_session([_ev("undo_complete")], user_id="b"))
127
+ s3 = mod._session_metrics(_session([], user_id="c"))
128
+ row = mod._aggregate_day([s1, s2, s3])
129
+ assert row["failure_rate"] == round(1 / 3, 4)
130
+ assert row["regenerate_rate"] == round(1 / 3, 4)
131
+
132
+
133
+ def test_window_filter_keeps_only_events_in_range():
134
+ from datetime import datetime, timezone
135
+ mod = _load()
136
+ events = [
137
+ _ev("llm_call", {"prompt_tokens": 100}, ts="2026-04-24T09:45:00"),
138
+ _ev("llm_call", {"prompt_tokens": 200}, ts="2026-04-24T10:05:00"),
139
+ _ev("tool_call", {"tool": "bash"}, ts="2026-04-24T10:30:00"),
140
+ _ev("llm_call", {"prompt_tokens": 400}, ts="2026-04-24T11:10:00"),
141
+ ]
142
+ session = _session(events, start="2026-04-24T09:44:00")
143
+ # Only events in [10:00, 11:00) should remain.
144
+ window_start = datetime(2026, 4, 24, 10, 0, 0, tzinfo=timezone.utc)
145
+ window_end = datetime(2026, 4, 24, 11, 0, 0, tzinfo=timezone.utc)
146
+ windowed = mod._filter_session_to_window(session, window_start, window_end)
147
+ assert windowed is not None
148
+ types = [e["event_type"] for e in windowed["events"]]
149
+ assert types == ["llm_call", "tool_call"]
150
+ # Metrics only reflect in-window events.
151
+ m = mod._session_metrics(windowed)
152
+ assert m["tokens_prompt"] == 200
153
+ assert m["llm_calls"] == 1
154
+ assert m["tool_calls_total"] == 0 # tool_call not tool_output
155
+
156
+
157
+ def test_window_filter_returns_none_when_nothing_in_range():
158
+ from datetime import datetime, timezone
159
+ mod = _load()
160
+ events = [_ev("llm_call", {"prompt_tokens": 100}, ts="2026-04-24T09:45:00")]
161
+ session = _session(events)
162
+ window_start = datetime(2026, 4, 24, 10, 0, 0, tzinfo=timezone.utc)
163
+ window_end = datetime(2026, 4, 24, 11, 0, 0, tzinfo=timezone.utc)
164
+ assert mod._filter_session_to_window(session, window_start, window_end) is None
tests/unit/test_build_sft.py ADDED
@@ -0,0 +1,78 @@
1
+ """Smoke test for the SFT reshape — raw passthrough with tags attached."""
2
+
3
+ import importlib.util
4
+ import sys
5
+ from pathlib import Path
6
+
7
+
8
+ def _load():
9
+ path = Path(__file__).parent.parent.parent / "scripts" / "build_sft.py"
10
+ spec = importlib.util.spec_from_file_location("build_sft", path)
11
+ mod = importlib.util.module_from_spec(spec)
12
+ sys.modules["build_sft"] = mod
13
+ spec.loader.exec_module(mod) # type: ignore
14
+ return mod
15
+
16
+
17
+ def _session_row():
18
+ return {
19
+ "session_id": "abc",
20
+ "session_start_time": "2026-04-24T10:00:00",
21
+ "session_end_time": "2026-04-24T10:05:00",
22
+ "model_name": "claude-opus-4-6",
23
+ "messages": [
24
+ {"role": "system", "content": "You are an agent"},
25
+ {"role": "user", "content": "fine-tune llama"},
26
+ {"role": "assistant", "content": None, "tool_calls": [
27
+ {"id": "c1", "type": "function",
28
+ "function": {"name": "hf_jobs", "arguments": '{"script":"from trl import SFTTrainer"}'}},
29
+ ]},
30
+ {"role": "tool", "tool_call_id": "c1", "content": "ok"},
31
+ {"role": "assistant", "content": "done"},
32
+ ],
33
+ "events": [
34
+ {"timestamp": "2026-04-24T10:00:05", "event_type": "tool_call",
35
+ "data": {"tool": "hf_jobs",
36
+ "arguments": {"script": "from trl import SFTTrainer"}}},
37
+ {"timestamp": "2026-04-24T10:00:06", "event_type": "hf_job_submit",
38
+ "data": {"flavor": "a100-large", "push_to_hub": True}},
39
+ {"timestamp": "2026-04-24T10:45:00", "event_type": "hf_job_complete",
40
+ "data": {"flavor": "a100-large", "final_status": "COMPLETED",
41
+ "wall_time_s": 2700}},
42
+ {"timestamp": "2026-04-24T10:45:05", "event_type": "turn_complete",
43
+ "data": {}},
44
+ ],
45
+ "tools": [{"type": "function", "function": {"name": "hf_jobs"}}],
46
+ }
47
+
48
+
49
+ def test_reshape_preserves_messages_and_tools_and_adds_tags():
50
+ mod = _load()
51
+ row = mod._reshape_to_sft(_session_row())
52
+ assert row["session_id"] == "abc"
53
+ assert row["model"] == "claude-opus-4-6"
54
+ assert row["timestamp"] == "2026-04-24T10:00:00"
55
+ # Messages preserved verbatim, in order, with tool_calls + tool role rows.
56
+ assert len(row["messages"]) == 5
57
+ assert row["messages"][2]["tool_calls"][0]["function"]["name"] == "hf_jobs"
58
+ assert row["messages"][3]["role"] == "tool"
59
+ # Tools preserved verbatim.
60
+ assert row["tools"] == [{"type": "function", "function": {"name": "hf_jobs"}}]
61
+ # Tags include the expected signals.
62
+ tags = set(row["tags"])
63
+ assert "tool:hf_jobs" in tags
64
+ assert "hf_job:succeeded" in tags
65
+ assert "hf_job:push_to_hub" in tags
66
+ assert "gpu:a100" in tags
67
+ assert "outcome:completed" in tags
68
+ assert "task:training" in tags
69
+ assert "model:opus" in tags
70
+
71
+
72
+ def test_reshape_handles_missing_tools_field():
73
+ mod = _load()
74
+ row = _session_row()
75
+ del row["tools"]
76
+ out = mod._reshape_to_sft(row)
77
+ assert out["tools"] == []
78
+ assert isinstance(out["tags"], list) # still computes tags
tests/unit/test_heartbeat.py ADDED
@@ -0,0 +1,134 @@
1
+ """Heartbeat + stable-local-path tests for Session.
2
+
3
+ We don't spin up the real agent loop — we build a minimal Session with a
4
+ stubbed config and an in-memory queue, then call send_event repeatedly while
5
+ monkeypatching time.monotonic to simulate seconds passing.
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ from pathlib import Path
11
+ from unittest.mock import patch
12
+
13
+ import pytest
14
+
15
+ from agent.core.session import Event, Session
16
+
17
+
18
+ class _FakeConfig:
19
+ model_name = "claude-opus-4-6"
20
+ save_sessions = True
21
+ session_dataset_repo = "fake/repo"
22
+ auto_save_interval = 1
23
+ heartbeat_interval_s = 60
24
+ max_iterations = 10
25
+ yolo_mode = False
26
+ confirm_cpu_jobs = False
27
+ auto_file_upload = False
28
+ reasoning_effort = None
29
+ mcpServers: dict = {}
30
+
31
+
32
+ def _mk_session(tmp_path: Path) -> Session:
33
+ import os
34
+ os.chdir(tmp_path) # so session_logs/ lands under tmp_path
35
+ # Stub out the context manager to avoid litellm lookups.
36
+ from agent.context_manager.manager import ContextManager
37
+ cm = ContextManager.__new__(ContextManager)
38
+ cm.items = []
39
+ cm.tool_specs = []
40
+ cm.model_max_tokens = 200_000
41
+ cm.running_context_usage = 0
42
+ cm.compact_size = 0.1
43
+ cm.untouched_messages = 5
44
+ cm.hf_token = None
45
+ cm.local_mode = True
46
+ s = Session(
47
+ event_queue=asyncio.Queue(),
48
+ config=_FakeConfig(),
49
+ tool_router=None,
50
+ context_manager=cm,
51
+ hf_token=None,
52
+ local_mode=True,
53
+ )
54
+ return s
55
+
56
+
57
+ def test_heartbeat_fires_after_interval(tmp_path, monkeypatch):
58
+ # Use asyncio.run rather than pytest-asyncio so the test works without the
59
+ # plugin installed (same pattern elsewhere in this repo).
60
+ async def body():
61
+ s = _mk_session(tmp_path)
62
+ calls = []
63
+
64
+ def fake_upload(repo_id):
65
+ calls.append(repo_id)
66
+ return "fake/path.json"
67
+
68
+ monkeypatch.setattr(s, "save_and_upload_detached", fake_upload)
69
+
70
+ # t=0: first event, should NOT trigger (initial _last_heartbeat_ts = now)
71
+ with patch("agent.core.telemetry.time.monotonic", return_value=100.0):
72
+ s._last_heartbeat_ts = 100.0
73
+ await s.send_event(Event(event_type="x"))
74
+ assert calls == []
75
+
76
+ # t=+30s: still under interval → no save
77
+ with patch("agent.core.telemetry.time.monotonic", return_value=130.0):
78
+ await s.send_event(Event(event_type="y"))
79
+ assert calls == []
80
+
81
+ # t=+61s: over 60s → save fires once
82
+ with patch("agent.core.telemetry.time.monotonic", return_value=161.0):
83
+ await s.send_event(Event(event_type="z"))
84
+ # create_task runs on the event loop; wait for the to_thread to complete
85
+ await asyncio.sleep(0.05)
86
+ assert calls == ["fake/repo"]
87
+
88
+ # Next event shortly after → no second save (interval resets to 161)
89
+ with patch("agent.core.telemetry.time.monotonic", return_value=170.0):
90
+ await s.send_event(Event(event_type="w"))
91
+ await asyncio.sleep(0.05)
92
+ assert len(calls) == 1
93
+
94
+ asyncio.run(body())
95
+
96
+
97
+ def test_stable_local_path_overwrites(tmp_path):
98
+ import os
99
+ os.chdir(tmp_path)
100
+ from agent.context_manager.manager import ContextManager
101
+ cm = ContextManager.__new__(ContextManager)
102
+ cm.items = []
103
+ cm.tool_specs = []
104
+ cm.model_max_tokens = 200_000
105
+ cm.running_context_usage = 0
106
+ cm.compact_size = 0.1
107
+ cm.untouched_messages = 5
108
+ cm.hf_token = None
109
+ cm.local_mode = True
110
+
111
+ s = Session(
112
+ event_queue=asyncio.Queue(),
113
+ config=_FakeConfig(),
114
+ tool_router=None,
115
+ context_manager=cm,
116
+ hf_token=None,
117
+ local_mode=True,
118
+ )
119
+
120
+ p1 = s.save_trajectory_local(directory="session_logs")
121
+ assert p1 is not None
122
+ p2 = s.save_trajectory_local(directory="session_logs")
123
+ p3 = s.save_trajectory_local(directory="session_logs")
124
+ # All three saves land on the same file — heartbeat should not spam files.
125
+ assert p1 == p2 == p3
126
+ files = list(Path("session_logs").glob("session_*.json"))
127
+ # Exactly one final file; the .tmp should be renamed away.
128
+ assert len(files) == 1
129
+
130
+ # File is valid JSON (atomic write → no torn content).
131
+ with open(p1) as f:
132
+ data = json.load(f)
133
+ assert data["session_id"] == s.session_id
134
+ assert data["upload_status"] == "pending"
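`HeartbeatSaver` itself lives in `agent/core/telemetry.py` and is not part of this excerpt; the test above only pins its observable behaviour. A minimal sketch of the throttle that behaviour implies (attribute and call names beyond `heartbeat_interval_s`, `_last_heartbeat_ts` and `save_and_upload_detached` are guesses):

```python
# Sketch of the behaviour the test pins down, not the real telemetry module.
import asyncio
import time

def maybe_heartbeat(session) -> None:
    """Call on every emitted event; saves at most once per heartbeat_interval_s."""
    interval = getattr(session.config, "heartbeat_interval_s", 0)
    if not session.config.save_sessions or interval <= 0:
        return
    now = time.monotonic()
    last = getattr(session, "_last_heartbeat_ts", None)
    if last is None:
        session._last_heartbeat_ts = now  # first event only starts the clock
        return
    if now - last < interval:
        return
    session._last_heartbeat_ts = now
    repo = session.config.session_dataset_repo
    # Fire-and-forget: run the (potentially blocking) save off the event loop.
    asyncio.create_task(asyncio.to_thread(session.save_and_upload_detached, repo))
```
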
tests/unit/test_kpis_scheduler.py ADDED
@@ -0,0 +1,107 @@
1
+ """Smoke tests for backend/kpis_scheduler.py.
2
+
3
+ Exercise the pure / fast paths only:
4
+ * token resolution order
5
+ * build_kpis import path
6
+ * start()/shutdown() lifecycle without APScheduler actually running a job
7
+ * backfill() passes the right hour values through to _run_hour
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import asyncio
13
+ import importlib.util
14
+ import sys
15
+ from datetime import datetime, timezone
16
+ from pathlib import Path
17
+
18
+
19
+ def _load():
20
+ path = Path(__file__).parent.parent.parent / "backend" / "kpis_scheduler.py"
21
+ spec = importlib.util.spec_from_file_location("kpis_scheduler", path)
22
+ mod = importlib.util.module_from_spec(spec)
23
+ sys.modules["kpis_scheduler"] = mod
24
+ assert spec.loader is not None
25
+ spec.loader.exec_module(mod)
26
+ return mod
27
+
28
+
29
+ def test_token_resolution_order(monkeypatch):
30
+ mod = _load()
31
+ for var in ("HF_KPI_WRITE_TOKEN", "HF_SESSION_UPLOAD_TOKEN", "HF_TOKEN", "HF_ADMIN_TOKEN"):
32
+ monkeypatch.delenv(var, raising=False)
33
+ assert mod._resolve_token() is None
34
+
35
+ monkeypatch.setenv("HF_ADMIN_TOKEN", "admin")
36
+ assert mod._resolve_token() == "admin"
37
+
38
+ monkeypatch.setenv("HF_TOKEN", "generic")
39
+ assert mod._resolve_token() == "generic"
40
+
41
+ monkeypatch.setenv("HF_SESSION_UPLOAD_TOKEN", "sessions")
42
+ assert mod._resolve_token() == "sessions"
43
+
44
+ monkeypatch.setenv("HF_KPI_WRITE_TOKEN", "kpis")
45
+ assert mod._resolve_token() == "kpis"
46
+
47
+
48
+ def test_load_build_kpis_exposes_run_for_hour():
49
+ mod = _load()
50
+ bk = mod._load_build_kpis()
51
+ assert hasattr(bk, "run_for_hour")
52
+ assert callable(bk.run_for_hour)
53
+
54
+
55
+ def test_backfill_calls_run_hour_for_each_hour(monkeypatch):
56
+ mod = _load()
57
+ monkeypatch.setenv("HF_KPI_WRITE_TOKEN", "x")
58
+ calls: list[datetime] = []
59
+
60
+ async def fake_run_hour(hour_dt):
61
+ calls.append(hour_dt)
62
+
63
+ monkeypatch.setattr(mod, "_run_hour", fake_run_hour)
64
+ asyncio.run(mod.backfill(hours=3))
65
+ assert len(calls) == 3
66
+ # Hours are returned most-recent-first
67
+ assert calls[0] > calls[1] > calls[2]
68
+ # All aligned to the top of the hour
69
+ for c in calls:
70
+ assert c.minute == 0 and c.second == 0 and c.microsecond == 0
71
+ assert c.tzinfo == timezone.utc
72
+
73
+
74
+ def test_start_is_no_op_when_disabled(monkeypatch):
75
+ mod = _load()
76
+ # Ensure clean state — _scheduler is module-global
77
+ mod._scheduler = None
78
+ monkeypatch.setenv("ML_INTERN_KPIS_DISABLED", "1")
79
+ mod.start()
80
+ assert mod._scheduler is None # never instantiated
81
+
82
+
83
+ def test_start_skips_cleanly_without_apscheduler(monkeypatch):
84
+ mod = _load()
85
+ mod._scheduler = None
86
+ monkeypatch.delenv("ML_INTERN_KPIS_DISABLED", raising=False)
87
+
88
+ # Force the apscheduler import to fail — start() should log and return.
89
+ real_import = __import__  # the builtin, captured before it gets monkeypatched
90
+
91
+ def fake_import(name, *args, **kwargs):
92
+ if name.startswith("apscheduler"):
93
+ raise ImportError("apscheduler unavailable in test")
94
+ return real_import(name, *args, **kwargs)
95
+
96
+ monkeypatch.setattr(
97
+ "builtins.__import__",
98
+ fake_import,
99
+ )
100
+ mod.start() # should not raise
101
+ assert mod._scheduler is None
102
+
103
+
104
+ def test_shutdown_is_no_op_when_not_started():
105
+ mod = _load()
106
+ mod._scheduler = None
107
+ asyncio.run(mod.shutdown()) # must not raise
tests/unit/test_redact.py ADDED
@@ -0,0 +1,76 @@
1
+ """Tests for the secret scrubber used before session upload."""
2
+
3
+ from agent.core.redact import scrub, scrub_string
4
+
5
+
6
+ def test_hf_token():
7
+ s = "here is a token hf_" + "A" * 35 + " ok"
8
+ out = scrub_string(s)
9
+ assert "hf_" not in out
10
+ assert "[REDACTED_HF_TOKEN]" in out
11
+
12
+
13
+ def test_anthropic_key():
14
+ s = "key=sk-ant-api03_" + "a" * 40
15
+ out = scrub_string(s)
16
+ # The env-var name prefix matches too; just verify we don't leave the body.
17
+ assert "sk-ant-api03_" not in out
18
+
19
+
20
+ def test_github_token():
21
+ s = "ghp_" + "a" * 40
22
+ out = scrub_string(s)
23
+ assert out == "[REDACTED_GITHUB_TOKEN]"
24
+
25
+
26
+ def test_github_fine_grained_pat():
27
+ # Fine-grained PATs: github_pat_<alphanumeric + underscore>, 36+ chars
28
+ s = "github_pat_" + "A1B2_" * 10
29
+ out = scrub_string(s)
30
+ assert "github_pat_" not in out
31
+ assert "[REDACTED_GITHUB_TOKEN]" in out
32
+
33
+
34
+ def test_aws_key_id():
35
+ s = "AWS_ACCESS_KEY_ID=AKIAABCDEFGHIJKLMNOP"
36
+ out = scrub_string(s)
37
+ assert "AKIAABCDEFGHIJKLMNOP" not in out
38
+
39
+
40
+ def test_bearer_header():
41
+ s = "Authorization: Bearer abcdef0123456789abcdef0123456789"
42
+ out = scrub_string(s)
43
+ assert "abcdef0123456789abcdef0123456789" not in out
44
+ assert "Bearer [REDACTED]" in out
45
+
46
+
47
+ def test_env_var_style():
48
+ s = "HF_TOKEN=hf_" + "x" * 40 + " run"
49
+ out = scrub_string(s)
50
+ # Either the value-scrubber or the HF-token regex should fire.
51
+ assert "hf_xxxx" not in out
52
+
53
+
54
+ def test_scrub_nested_dict_and_list():
55
+ payload = {
56
+ "msg": "token hf_" + "Z" * 35,
57
+ "tools": [
58
+ {"args": {"secret": "ghp_" + "Q" * 40}},
59
+ "no secrets here",
60
+ ],
61
+ "n": 42,
62
+ }
63
+ out = scrub(payload)
64
+ # Original not mutated
65
+ assert "hf_" in payload["msg"]
66
+ # Redacted copy
67
+ assert "[REDACTED_HF_TOKEN]" in out["msg"]
68
+ assert out["tools"][0]["args"]["secret"] == "[REDACTED_GITHUB_TOKEN]"
69
+ assert out["tools"][1] == "no secrets here"
70
+ assert out["n"] == 42
71
+
72
+
73
+ def test_scrub_preserves_non_strings():
74
+ assert scrub(None) is None
75
+ assert scrub(123) == 123
76
+ assert scrub(True) is True
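`agent/core/redact.py` is also outside this excerpt; the patterns below are reverse-engineered from the assertions above and are illustrative only (a subset: the Anthropic-key and env-var cases would need their own rules):

```python
# Illustrative subset of what scrub_string must match to satisfy the tests above.
import re

_PATTERNS: list[tuple[re.Pattern, str]] = [
    (re.compile(r"hf_[A-Za-z0-9]{30,}"), "[REDACTED_HF_TOKEN]"),
    (re.compile(r"(?:ghp|gho|ghs)_[A-Za-z0-9]{36,}"), "[REDACTED_GITHUB_TOKEN]"),
    (re.compile(r"github_pat_[A-Za-z0-9_]{36,}"), "[REDACTED_GITHUB_TOKEN]"),
    (re.compile(r"AKIA[0-9A-Z]{16}"), "[REDACTED_AWS_KEY]"),
    (re.compile(r"Bearer\s+[A-Za-z0-9._\-]{16,}"), "Bearer [REDACTED]"),
]

def scrub_string(s: str) -> str:
    for pattern, replacement in _PATTERNS:
        s = pattern.sub(replacement, s)
    return s
```
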
tests/unit/test_sft_tagger.py ADDED
@@ -0,0 +1,197 @@
+ """Tests for agent.sft.tagger — one test per tag namespace."""
+
+ from agent.sft.tagger import tag_session
+
+
+ def _ev(event_type, data=None, ts="2026-04-24T10:00:00"):
+     return {"timestamp": ts, "event_type": event_type, "data": data or {}}
+
+
+ def _traj(events=None, messages=None, model="claude-opus-4-6"):
+     return {
+         "session_id": "sess-1",
+         "model_name": model,
+         "session_start_time": "2026-04-24T09:59:00",
+         "session_end_time": "2026-04-24T10:05:00",
+         "messages": messages
+         or [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "ok"}],
+         "events": events or [],
+     }
+
+
+ def test_model_family():
+     assert "model:opus" in tag_session(_traj(model="claude-opus-4-6"))
+     assert "model:sonnet" in tag_session(_traj(model="bedrock/claude-sonnet-4-5"))
+     assert "model:kimi" in tag_session(_traj(model="moonshotai/Kimi-K2.6"))
+     assert "model:other" in tag_session(_traj(model="unknown-model-xyz"))
+
+
+ def test_turns_buckets():
+     short = _traj(messages=[{"role": "user", "content": "hi"}])
+     medium = _traj(messages=[{"role": "user", "content": "q"} for _ in range(10)])
+     long = _traj(messages=[{"role": "user", "content": "q"} for _ in range(25)])
+     assert "turns:short" in tag_session(short)
+     assert "turns:medium" in tag_session(medium)
+     assert "turns:long" in tag_session(long)
+
+
+ def test_cost_buckets():
+     cheap = _traj(events=[_ev("llm_call", {"cost_usd": 0.05})])
+     med = _traj(events=[_ev("llm_call", {"cost_usd": 0.5})])
+     expensive = _traj(events=[_ev("llm_call", {"cost_usd": 5.0})])
+     assert "cost:low" in tag_session(cheap)
+     assert "cost:med" in tag_session(med)
+     assert "cost:high" in tag_session(expensive)
+
+
+ def test_tool_tags():
+     events = [
+         _ev("tool_call", {"tool": "hf_jobs", "arguments": {}}),
+         _ev("tool_call", {"tool": "research"}),
+         _ev("tool_call", {"tool": "bash"}),
+     ]
+     tags = tag_session(_traj(events))
+     assert "tool:hf_jobs" in tags
+     assert "tool:research" in tags
+     assert "tool:bash" in tags
+
+
+ def test_outcome_completed():
+     events = [_ev("turn_complete", {"history_size": 10})]
+     assert "outcome:completed" in tag_session(_traj(events))
+
+
+ def test_outcome_errored():
+     events = [_ev("error", {"error": "boom"})]
+     assert "outcome:errored" in tag_session(_traj(events))
+
+
+ def test_outcome_interrupted():
+     events = [_ev("interrupted")]
+     assert "outcome:interrupted" in tag_session(_traj(events))
+
+
+ def test_outcome_ongoing():
+     # No terminal events → session was still running at save time
+     events = [_ev("llm_call", {"cost_usd": 0.01})]
+     assert "outcome:ongoing" in tag_session(_traj(events))
+
+
+ def test_outcome_doom_loop_and_context():
+     events = [
+         _ev("tool_log", {"tool": "system", "log": "Doom loop detected — injecting corrective prompt"}),
+         _ev("compacted", {"old_tokens": 100, "new_tokens": 50}),
+         _ev("turn_complete", {"history_size": 10}),
+     ]
+     tags = tag_session(_traj(events))
+     assert "outcome:doom_loop" in tags
+     assert "outcome:context_exceeded" in tags
+
+
+ def test_hf_job_tags():
+     events = [
+         _ev("tool_call", {"tool": "hf_jobs", "arguments": {"script": "from trl import SFTTrainer"}}),
+         _ev("hf_job_submit", {
+             "flavor": "a100-large", "push_to_hub": True, "job_id": "j1",
+         }),
+         _ev("hf_job_complete", {"flavor": "a100-large", "final_status": "COMPLETED", "wall_time_s": 3600}),
+         _ev("hf_job_submit", {"flavor": "a100x4", "push_to_hub": False}),
+         _ev("hf_job_complete", {"flavor": "a100x4", "final_status": "FAILED"}),
+     ]
+     tags = tag_session(_traj(events))
+     assert "hf_job:submitted" in tags
+     assert "hf_job:multi" in tags
+     assert "hf_job:succeeded" in tags
+     assert "hf_job:failed" in tags
+     assert "hf_job:push_to_hub" in tags
+     assert "gpu:a100" in tags
+     assert "gpu:multi" in tags
+
+
+ def test_hf_job_oom():
+     events = [
+         _ev("tool_call", {"tool": "hf_jobs", "arguments": {}}),
+         _ev("hf_job_submit", {"flavor": "a100-large"}),
+         _ev("tool_output", {"success": False, "output": "RuntimeError: CUDA out of memory. Tried to allocate..."}),
+     ]
+     tags = tag_session(_traj(events))
+     assert "hf_job:oom" in tags
+
+
+ def test_sandbox_tags():
+     events = [
+         _ev("sandbox_create", {"hardware": "t4-small", "sandbox_id": "s1", "create_latency_s": 5}),
+         _ev("sandbox_destroy", {"sandbox_id": "s1", "lifetime_s": 3600}),
+     ]
+     tags = tag_session(_traj(events))
+     assert "sandbox:created" in tags
+     assert "sandbox:gpu" in tags
+     assert "sandbox:long_lived" in tags
+
+
+ def test_sandbox_cpu_short():
+     events = [
+         _ev("sandbox_create", {"hardware": "cpu-basic"}),
+         _ev("sandbox_destroy", {"lifetime_s": 120}),
+     ]
+     tags = tag_session(_traj(events))
+     assert "sandbox:cpu" in tags
+     assert "sandbox:long_lived" not in tags
+
+
+ def test_feedback_tags():
+     up_only = _traj(events=[_ev("feedback", {"rating": "up"})])
+     down_only = _traj(events=[_ev("feedback", {"rating": "down"})])
+     mixed = _traj(events=[_ev("feedback", {"rating": "up"}), _ev("feedback", {"rating": "down"})])
+     none = _traj()
+     assert "feedback:up" in tag_session(up_only)
+     assert "feedback:down" in tag_session(down_only)
+     assert "feedback:mixed" in tag_session(mixed)
+     assert "feedback:none" in tag_session(none)
+
+
+ def test_task_training():
+     events = [
+         _ev("tool_call", {"tool": "hf_jobs", "arguments": {
+             "script": "from trl import SFTTrainer\ntrainer = SFTTrainer(...)"
+         }}),
+         _ev("hf_job_submit", {"flavor": "a100-large"}),
+     ]
+     assert "task:training" in tag_session(_traj(events))
+
+
+ def test_task_research_only():
+     events = [
+         _ev("tool_call", {"tool": "research"}),
+         _ev("tool_call", {"tool": "explore_hf_docs"}),
+     ]
+     assert "task:research_only" in tag_session(_traj(events))
+
+
+ def test_task_data_prep():
+     events = [
+         _ev("tool_call", {"tool": "hf_inspect_dataset", "arguments": {}}),
+         _ev("tool_call", {"tool": "hub_repo_details"}),
+     ]
+     tags = tag_session(_traj(events))
+     assert "task:data_prep" in tags
+
+
+ def test_no_duplicates_and_sorted():
+     events = [
+         _ev("tool_call", {"tool": "hf_jobs"}),
+         _ev("tool_call", {"tool": "hf_jobs"}),  # duplicate
+         _ev("hf_job_submit", {"flavor": "a10g-small"}),
+         _ev("hf_job_submit", {"flavor": "a10g-small"}),
+     ]
+     tags = tag_session(_traj(events))
+     assert tags == sorted(tags)
+     assert len(tags) == len(set(tags))
+
+
+ def test_empty_trajectory_has_required_tags():
+     tags = tag_session(_traj())
+     namespaces = {t.split(":", 1)[0] for t in tags}
+     # Every session must have at least model/turns/cost/outcome/feedback.
+     for required in ("model", "turns", "cost", "outcome", "feedback"):
+         assert required in namespaces, f"missing {required} — got {tags}"
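
For orientation, the bucketing these tests pin down could look roughly like the sketch below. The thresholds, helper name, and structure are assumptions chosen only to be consistent with the fixtures above (1 / 10 / 25 messages, $0.05 / $0.50 / $5.00), not the committed agent.sft.tagger.

def tag_session_sketch(traj: dict) -> list[str]:
    tags = set()

    # model:* buckets by substring match on the model name, with a catch-all.
    model = traj.get("model_name", "").lower()
    for family in ("opus", "sonnet", "kimi"):
        if family in model:
            tags.add(f"model:{family}")
            break
    else:
        tags.add("model:other")

    # turns:* buckets by message count (thresholds are illustrative guesses).
    n_msgs = len(traj.get("messages", []))
    if n_msgs < 5:
        tags.add("turns:short")
    elif n_msgs < 20:
        tags.add("turns:medium")
    else:
        tags.add("turns:long")

    # cost:* sums llm_call costs, then buckets (thresholds are illustrative guesses).
    cost = sum(
        e.get("data", {}).get("cost_usd", 0.0)
        for e in traj.get("events", [])
        if e.get("event_type") == "llm_call"
    )
    if cost < 0.1:
        tags.add("cost:low")
    elif cost < 1.0:
        tags.add("cost:med")
    else:
        tags.add("cost:high")

    # feedback:* derives up / down / mixed / none from feedback events.
    ratings = {
        e.get("data", {}).get("rating")
        for e in traj.get("events", [])
        if e.get("event_type") == "feedback"
    }
    if {"up", "down"} <= ratings:
        tags.add("feedback:mixed")
    elif "up" in ratings:
        tags.add("feedback:up")
    elif "down" in ratings:
        tags.add("feedback:down")
    else:
        tags.add("feedback:none")

    # outcome:*, tool:*, hf_job:*, gpu:*, sandbox:*, and task:* would follow the
    # same event-scanning pattern and are omitted from this sketch.
    return sorted(tags)
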