Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Aksel Joonas Reedi commited on
feat(observability)
Browse files- agent/config.py +7 -2
- agent/core/agent_loop.py +27 -1
- agent/core/redact.py +68 -0
- agent/core/session.py +46 -4
- agent/core/session_uploader.py +34 -5
- agent/core/telemetry.py +289 -0
- agent/sft/__init__.py +0 -0
- agent/sft/tagger.py +324 -0
- agent/tools/jobs_tool.py +21 -2
- agent/tools/sandbox_tool.py +9 -0
- backend/kpis_scheduler.py +146 -0
- backend/main.py +29 -0
- backend/routes/agent.py +38 -0
- backend/session_manager.py +14 -2
- configs/main_agent_config.json +1 -1
- frontend/src/components/Chat/AssistantMessage.tsx +46 -3
- frontend/src/components/Chat/MessageBubble.tsx +3 -0
- frontend/src/components/Chat/MessageList.tsx +3 -1
- frontend/src/components/SessionChat.tsx +1 -0
- pyproject.toml +1 -0
- scripts/build_kpis.py +517 -0
- scripts/build_sft.py +204 -0
- tests/unit/test_build_kpis.py +164 -0
- tests/unit/test_build_sft.py +78 -0
- tests/unit/test_heartbeat.py +134 -0
- tests/unit/test_kpis_scheduler.py +107 -0
- tests/unit/test_redact.py +76 -0
- tests/unit/test_sft_tagger.py +197 -0
agent/config.py
CHANGED
|
@@ -24,8 +24,13 @@ class Config(BaseModel):
|
|
| 24 |
model_name: str
|
| 25 |
mcpServers: dict[str, MCPServerConfig] = {}
|
| 26 |
save_sessions: bool = True
|
| 27 |
-
session_dataset_repo: str = "
|
| 28 |
-
auto_save_interval: int =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
yolo_mode: bool = False # Auto-approve all tool calls without confirmation
|
| 30 |
max_iterations: int = 300 # Max LLM calls per agent turn (-1 = unlimited)
|
| 31 |
|
|
|
|
| 24 |
model_name: str
|
| 25 |
mcpServers: dict[str, MCPServerConfig] = {}
|
| 26 |
save_sessions: bool = True
|
| 27 |
+
session_dataset_repo: str = "smolagents/ml-intern-sessions"
|
| 28 |
+
auto_save_interval: int = 1 # Save every N user turns (0 = disabled)
|
| 29 |
+
# Mid-turn heartbeat: save + upload every N seconds while events are being
|
| 30 |
+
# emitted. Guards against losing trace data on long-running turns that
|
| 31 |
+
# crash before turn_complete (e.g. a multi-hour hf_jobs wait that OOMs).
|
| 32 |
+
# 0 = disabled. Consumed by agent.core.telemetry.HeartbeatSaver.
|
| 33 |
+
heartbeat_interval_s: int = 60
|
| 34 |
yolo_mode: bool = False # Auto-approve all tool calls without confirmation
|
| 35 |
max_iterations: int = 300 # Max LLM calls per agent turn (-1 = unlimited)
|
| 36 |
|
agent/core/agent_loop.py
CHANGED
|
@@ -6,12 +6,14 @@ import asyncio
|
|
| 6 |
import json
|
| 7 |
import logging
|
| 8 |
import os
|
| 9 |
-
|
|
|
|
| 10 |
|
| 11 |
from litellm import ChatCompletionMessageToolCall, Message, acompletion
|
| 12 |
from litellm.exceptions import ContextWindowExceededError
|
| 13 |
|
| 14 |
from agent.config import Config
|
|
|
|
| 15 |
from agent.core.doom_loop import check_for_doom_loop
|
| 16 |
from agent.core.llm_params import _resolve_llm_params
|
| 17 |
from agent.core.prompt_caching import with_prompt_caching
|
|
@@ -291,6 +293,7 @@ class LLMResult:
|
|
| 291 |
tool_calls_acc: dict[int, dict]
|
| 292 |
token_count: int
|
| 293 |
finish_reason: str | None
|
|
|
|
| 294 |
|
| 295 |
|
| 296 |
async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
|
|
@@ -298,6 +301,7 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 298 |
response = None
|
| 299 |
_healed_effort = False # one-shot safety net per call
|
| 300 |
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
|
|
|
|
| 301 |
for _llm_attempt in range(_MAX_LLM_RETRIES):
|
| 302 |
try:
|
| 303 |
response = await acompletion(
|
|
@@ -339,6 +343,7 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 339 |
tool_calls_acc: dict[int, dict] = {}
|
| 340 |
token_count = 0
|
| 341 |
finish_reason = None
|
|
|
|
| 342 |
|
| 343 |
async for chunk in response:
|
| 344 |
if session.is_cancelled:
|
|
@@ -349,6 +354,7 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 349 |
if not choice:
|
| 350 |
if hasattr(chunk, "usage") and chunk.usage:
|
| 351 |
token_count = chunk.usage.total_tokens
|
|
|
|
| 352 |
continue
|
| 353 |
|
| 354 |
delta = choice.delta
|
|
@@ -379,12 +385,22 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 379 |
|
| 380 |
if hasattr(chunk, "usage") and chunk.usage:
|
| 381 |
token_count = chunk.usage.total_tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
|
| 383 |
return LLMResult(
|
| 384 |
content=full_content or None,
|
| 385 |
tool_calls_acc=tool_calls_acc,
|
| 386 |
token_count=token_count,
|
| 387 |
finish_reason=finish_reason,
|
|
|
|
| 388 |
)
|
| 389 |
|
| 390 |
|
|
@@ -393,6 +409,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
|
|
| 393 |
response = None
|
| 394 |
_healed_effort = False
|
| 395 |
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
|
|
|
|
| 396 |
for _llm_attempt in range(_MAX_LLM_RETRIES):
|
| 397 |
try:
|
| 398 |
response = await acompletion(
|
|
@@ -454,11 +471,20 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
|
|
| 454 |
Event(event_type="assistant_message", data={"content": content})
|
| 455 |
)
|
| 456 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
return LLMResult(
|
| 458 |
content=content,
|
| 459 |
tool_calls_acc=tool_calls_acc,
|
| 460 |
token_count=token_count,
|
| 461 |
finish_reason=finish_reason,
|
|
|
|
| 462 |
)
|
| 463 |
|
| 464 |
|
|
|
|
| 6 |
import json
|
| 7 |
import logging
|
| 8 |
import os
|
| 9 |
+
import time
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
|
| 12 |
from litellm import ChatCompletionMessageToolCall, Message, acompletion
|
| 13 |
from litellm.exceptions import ContextWindowExceededError
|
| 14 |
|
| 15 |
from agent.config import Config
|
| 16 |
+
from agent.core import telemetry
|
| 17 |
from agent.core.doom_loop import check_for_doom_loop
|
| 18 |
from agent.core.llm_params import _resolve_llm_params
|
| 19 |
from agent.core.prompt_caching import with_prompt_caching
|
|
|
|
| 293 |
tool_calls_acc: dict[int, dict]
|
| 294 |
token_count: int
|
| 295 |
finish_reason: str | None
|
| 296 |
+
usage: dict = field(default_factory=dict)
|
| 297 |
|
| 298 |
|
| 299 |
async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
|
|
|
|
| 301 |
response = None
|
| 302 |
_healed_effort = False # one-shot safety net per call
|
| 303 |
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
|
| 304 |
+
t_start = time.monotonic()
|
| 305 |
for _llm_attempt in range(_MAX_LLM_RETRIES):
|
| 306 |
try:
|
| 307 |
response = await acompletion(
|
|
|
|
| 343 |
tool_calls_acc: dict[int, dict] = {}
|
| 344 |
token_count = 0
|
| 345 |
finish_reason = None
|
| 346 |
+
final_usage_chunk = None
|
| 347 |
|
| 348 |
async for chunk in response:
|
| 349 |
if session.is_cancelled:
|
|
|
|
| 354 |
if not choice:
|
| 355 |
if hasattr(chunk, "usage") and chunk.usage:
|
| 356 |
token_count = chunk.usage.total_tokens
|
| 357 |
+
final_usage_chunk = chunk
|
| 358 |
continue
|
| 359 |
|
| 360 |
delta = choice.delta
|
|
|
|
| 385 |
|
| 386 |
if hasattr(chunk, "usage") and chunk.usage:
|
| 387 |
token_count = chunk.usage.total_tokens
|
| 388 |
+
final_usage_chunk = chunk
|
| 389 |
+
|
| 390 |
+
usage = await telemetry.record_llm_call(
|
| 391 |
+
session,
|
| 392 |
+
model=llm_params.get("model", session.config.model_name),
|
| 393 |
+
response=final_usage_chunk,
|
| 394 |
+
latency_ms=int((time.monotonic() - t_start) * 1000),
|
| 395 |
+
finish_reason=finish_reason,
|
| 396 |
+
)
|
| 397 |
|
| 398 |
return LLMResult(
|
| 399 |
content=full_content or None,
|
| 400 |
tool_calls_acc=tool_calls_acc,
|
| 401 |
token_count=token_count,
|
| 402 |
finish_reason=finish_reason,
|
| 403 |
+
usage=usage,
|
| 404 |
)
|
| 405 |
|
| 406 |
|
|
|
|
| 409 |
response = None
|
| 410 |
_healed_effort = False
|
| 411 |
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
|
| 412 |
+
t_start = time.monotonic()
|
| 413 |
for _llm_attempt in range(_MAX_LLM_RETRIES):
|
| 414 |
try:
|
| 415 |
response = await acompletion(
|
|
|
|
| 471 |
Event(event_type="assistant_message", data={"content": content})
|
| 472 |
)
|
| 473 |
|
| 474 |
+
usage = await telemetry.record_llm_call(
|
| 475 |
+
session,
|
| 476 |
+
model=llm_params.get("model", session.config.model_name),
|
| 477 |
+
response=response,
|
| 478 |
+
latency_ms=int((time.monotonic() - t_start) * 1000),
|
| 479 |
+
finish_reason=finish_reason,
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
return LLMResult(
|
| 483 |
content=content,
|
| 484 |
tool_calls_acc=tool_calls_acc,
|
| 485 |
token_count=token_count,
|
| 486 |
finish_reason=finish_reason,
|
| 487 |
+
usage=usage,
|
| 488 |
)
|
| 489 |
|
| 490 |
|
agent/core/redact.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Secret scrubbing for session trajectories before upload.
|
| 2 |
+
|
| 3 |
+
Users frequently paste HF / API / GitHub tokens into the chat, or scripts echo
|
| 4 |
+
them via env dumps. This module applies regex-based redaction to any string
|
| 5 |
+
value found recursively in a trajectory payload. The goal is best-effort —
|
| 6 |
+
strict formats are matched; we won't catch free-form leaks like "my password
|
| 7 |
+
is hunter2".
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import re
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
# Each entry: (compiled regex, replacement placeholder).
|
| 16 |
+
# Patterns are conservative: they only match tokens with the canonical prefix
|
| 17 |
+
# and a minimum body length so we don't paint over normal text.
|
| 18 |
+
_PATTERNS: list[tuple[re.Pattern, str]] = [
|
| 19 |
+
# Hugging Face tokens: hf_[A-Za-z0-9]{30,}
|
| 20 |
+
(re.compile(r"hf_[A-Za-z0-9]{30,}"), "[REDACTED_HF_TOKEN]"),
|
| 21 |
+
# Anthropic: sk-ant-[A-Za-z0-9_\-]{20,}
|
| 22 |
+
(re.compile(r"sk-ant-[A-Za-z0-9_\-]{20,}"), "[REDACTED_ANTHROPIC_KEY]"),
|
| 23 |
+
# OpenAI: sk-[A-Za-z0-9]{40,} (legacy + proj keys)
|
| 24 |
+
(re.compile(r"sk-(?!ant-)[A-Za-z0-9_\-]{40,}"), "[REDACTED_OPENAI_KEY]"),
|
| 25 |
+
# GitHub classic PATs: ghp_, gho_, ghu_, ghs_, ghr_ followed by 36+ chars
|
| 26 |
+
(re.compile(r"gh[pousr]_[A-Za-z0-9]{36,}"), "[REDACTED_GITHUB_TOKEN]"),
|
| 27 |
+
# GitHub fine-grained PATs: github_pat_<alphanumeric_underscore>
|
| 28 |
+
(re.compile(r"github_pat_[A-Za-z0-9_]{36,}"), "[REDACTED_GITHUB_TOKEN]"),
|
| 29 |
+
# AWS access key IDs: AKIA / ASIA + 16 uppercase alnum
|
| 30 |
+
(re.compile(r"\b(?:AKIA|ASIA)[A-Z0-9]{16}\b"), "[REDACTED_AWS_KEY_ID]"),
|
| 31 |
+
# Generic 'Bearer <token>' header values
|
| 32 |
+
(re.compile(r"(?i)bearer\s+[A-Za-z0-9_\-\.=]{20,}"), "Bearer [REDACTED]"),
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
# Env-var-like exports: we scrub the value but keep the name so callers can
|
| 36 |
+
# still see which secret was referenced. Covers `KEY=value` and `KEY: value`
|
| 37 |
+
# when the key looks secret-y.
|
| 38 |
+
_SECRETY_NAMES = re.compile(
|
| 39 |
+
r"(?i)\b(HF_TOKEN|HUGGINGFACEHUB_API_TOKEN|ANTHROPIC_API_KEY|OPENAI_API_KEY|"
|
| 40 |
+
r"GITHUB_TOKEN|AWS_SECRET_ACCESS_KEY|AWS_ACCESS_KEY_ID|PASSWORD|SECRET|API_KEY)"
|
| 41 |
+
r"\s*[:=]\s*([^\s\"']+)"
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def scrub_string(s: str) -> str:
|
| 46 |
+
"""Apply all redaction patterns to a single string. Safe on non-strings."""
|
| 47 |
+
if not isinstance(s, str) or not s:
|
| 48 |
+
return s
|
| 49 |
+
out = s
|
| 50 |
+
for pat, repl in _PATTERNS:
|
| 51 |
+
out = pat.sub(repl, out)
|
| 52 |
+
out = _SECRETY_NAMES.sub(lambda m: f"{m.group(1)}=[REDACTED]", out)
|
| 53 |
+
return out
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def scrub(obj: Any) -> Any:
|
| 57 |
+
"""Recursively scrub every string value in a nested dict/list structure.
|
| 58 |
+
|
| 59 |
+
Returns a new object — inputs are not mutated."""
|
| 60 |
+
if isinstance(obj, str):
|
| 61 |
+
return scrub_string(obj)
|
| 62 |
+
if isinstance(obj, dict):
|
| 63 |
+
return {k: scrub(v) for k, v in obj.items()}
|
| 64 |
+
if isinstance(obj, list):
|
| 65 |
+
return [scrub(v) for v in obj]
|
| 66 |
+
if isinstance(obj, tuple):
|
| 67 |
+
return tuple(scrub(v) for v in obj)
|
| 68 |
+
return obj
|
agent/core/session.py
CHANGED
|
@@ -108,6 +108,11 @@ class Session:
|
|
| 108 |
self.session_start_time = datetime.now().isoformat()
|
| 109 |
self.turn_count: int = 0
|
| 110 |
self.last_auto_save_turn: int = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
# Per-model probed reasoning-effort cache. Populated by the probe
|
| 113 |
# on /model switch, read by ``effective_effort_for`` below. Keys are
|
|
@@ -132,6 +137,10 @@ class Session:
|
|
| 132 |
}
|
| 133 |
)
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
def cancel(self) -> None:
|
| 136 |
"""Signal cancellation to the running agent loop."""
|
| 137 |
self._cancelled.set()
|
|
@@ -184,6 +193,12 @@ class Session:
|
|
| 184 |
|
| 185 |
def get_trajectory(self) -> dict:
|
| 186 |
"""Serialize complete session trajectory for logging"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
return {
|
| 188 |
"session_id": self.session_id,
|
| 189 |
"session_start_time": self.session_start_time,
|
|
@@ -191,6 +206,7 @@ class Session:
|
|
| 191 |
"model_name": self.config.model_name,
|
| 192 |
"messages": [msg.model_dump() for msg in self.context_manager.items],
|
| 193 |
"events": self.logged_events,
|
|
|
|
| 194 |
}
|
| 195 |
|
| 196 |
def save_trajectory_local(
|
|
@@ -216,16 +232,42 @@ class Session:
|
|
| 216 |
|
| 217 |
trajectory = self.get_trajectory()
|
| 218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
# Add upload metadata
|
| 220 |
trajectory["upload_status"] = upload_status
|
| 221 |
trajectory["upload_url"] = dataset_url
|
| 222 |
trajectory["last_save_time"] = datetime.now().isoformat()
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
json.dump(trajectory, f, indent=2)
|
|
|
|
| 229 |
|
| 230 |
return str(filepath)
|
| 231 |
except Exception as e:
|
|
|
|
| 108 |
self.session_start_time = datetime.now().isoformat()
|
| 109 |
self.turn_count: int = 0
|
| 110 |
self.last_auto_save_turn: int = 0
|
| 111 |
+
# Stable local save path so heartbeat saves overwrite one file instead
|
| 112 |
+
# of spamming session_logs/. ``_last_heartbeat_ts`` is owned by
|
| 113 |
+
# ``agent.core.telemetry.HeartbeatSaver`` and lazily initialised there.
|
| 114 |
+
self._local_save_path: Optional[str] = None
|
| 115 |
+
self._last_heartbeat_ts: Optional[float] = None
|
| 116 |
|
| 117 |
# Per-model probed reasoning-effort cache. Populated by the probe
|
| 118 |
# on /model switch, read by ``effective_effort_for`` below. Keys are
|
|
|
|
| 137 |
}
|
| 138 |
)
|
| 139 |
|
| 140 |
+
# Mid-turn heartbeat flush (owned by telemetry module).
|
| 141 |
+
from agent.core.telemetry import HeartbeatSaver
|
| 142 |
+
HeartbeatSaver.maybe_fire(self)
|
| 143 |
+
|
| 144 |
def cancel(self) -> None:
|
| 145 |
"""Signal cancellation to the running agent loop."""
|
| 146 |
self._cancelled.set()
|
|
|
|
| 193 |
|
| 194 |
def get_trajectory(self) -> dict:
|
| 195 |
"""Serialize complete session trajectory for logging"""
|
| 196 |
+
tools: list = []
|
| 197 |
+
if self.tool_router is not None:
|
| 198 |
+
try:
|
| 199 |
+
tools = self.tool_router.get_tool_specs_for_llm() or []
|
| 200 |
+
except Exception:
|
| 201 |
+
tools = []
|
| 202 |
return {
|
| 203 |
"session_id": self.session_id,
|
| 204 |
"session_start_time": self.session_start_time,
|
|
|
|
| 206 |
"model_name": self.config.model_name,
|
| 207 |
"messages": [msg.model_dump() for msg in self.context_manager.items],
|
| 208 |
"events": self.logged_events,
|
| 209 |
+
"tools": tools,
|
| 210 |
}
|
| 211 |
|
| 212 |
def save_trajectory_local(
|
|
|
|
| 232 |
|
| 233 |
trajectory = self.get_trajectory()
|
| 234 |
|
| 235 |
+
# Scrub secrets at save time so session_logs/ never holds raw
|
| 236 |
+
# tokens on disk — a log aggregator, crash dump, or filesystem
|
| 237 |
+
# snapshot between heartbeats would otherwise leak them.
|
| 238 |
+
try:
|
| 239 |
+
from agent.core.redact import scrub
|
| 240 |
+
for key in ("messages", "events", "tools"):
|
| 241 |
+
if key in trajectory:
|
| 242 |
+
trajectory[key] = scrub(trajectory[key])
|
| 243 |
+
except Exception as _e:
|
| 244 |
+
logger.debug("Redact-on-save failed (non-fatal): %s", _e)
|
| 245 |
+
|
| 246 |
# Add upload metadata
|
| 247 |
trajectory["upload_status"] = upload_status
|
| 248 |
trajectory["upload_url"] = dataset_url
|
| 249 |
trajectory["last_save_time"] = datetime.now().isoformat()
|
| 250 |
|
| 251 |
+
# Reuse one stable path per session so heartbeat saves overwrite
|
| 252 |
+
# the same file instead of creating a new timestamped file every
|
| 253 |
+
# minute. The timestamp in the filename is kept for first-save
|
| 254 |
+
# ordering; subsequent saves just rewrite that file.
|
| 255 |
+
if self._local_save_path and Path(self._local_save_path).parent == log_dir:
|
| 256 |
+
filepath = Path(self._local_save_path)
|
| 257 |
+
else:
|
| 258 |
+
filename = (
|
| 259 |
+
f"session_{self.session_id}_"
|
| 260 |
+
f"{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
| 261 |
+
)
|
| 262 |
+
filepath = log_dir / filename
|
| 263 |
+
self._local_save_path = str(filepath)
|
| 264 |
+
|
| 265 |
+
# Atomic-ish write: stage to .tmp then rename so a crash mid-write
|
| 266 |
+
# doesn't leave a truncated JSON that breaks the retry scanner.
|
| 267 |
+
tmp_path = filepath.with_suffix(filepath.suffix + ".tmp")
|
| 268 |
+
with open(tmp_path, "w") as f:
|
| 269 |
json.dump(trajectory, f, indent=2)
|
| 270 |
+
tmp_path.replace(filepath)
|
| 271 |
|
| 272 |
return str(filepath)
|
| 273 |
except Exception as e:
|
agent/core/session_uploader.py
CHANGED
|
@@ -15,8 +15,15 @@ from dotenv import load_dotenv
|
|
| 15 |
|
| 16 |
load_dotenv()
|
| 17 |
|
| 18 |
-
# Token for session uploads
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def upload_session_as_file(
|
|
@@ -58,15 +65,37 @@ def upload_session_as_file(
|
|
| 58 |
json.dump(data, f, indent=2)
|
| 59 |
return False
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
# Prepare JSONL content (single line)
|
| 62 |
-
# Store messages
|
|
|
|
| 63 |
session_row = {
|
| 64 |
"session_id": data["session_id"],
|
| 65 |
"session_start_time": data["session_start_time"],
|
| 66 |
"session_end_time": data["session_end_time"],
|
| 67 |
"model_name": data["model_name"],
|
| 68 |
-
"messages": json.dumps(
|
| 69 |
-
"events": json.dumps(
|
|
|
|
| 70 |
}
|
| 71 |
|
| 72 |
# Create temporary JSONL file
|
|
|
|
| 15 |
|
| 16 |
load_dotenv()
|
| 17 |
|
| 18 |
+
# Token for session uploads. Fallback chain (least-privilege first) — matches
|
| 19 |
+
# backend/kpis_scheduler.py so one write-scoped token on the Space covers every
|
| 20 |
+
# telemetry dataset. Never hardcode tokens in source.
|
| 21 |
+
_SESSION_TOKEN = (
|
| 22 |
+
os.environ.get("HF_SESSION_UPLOAD_TOKEN")
|
| 23 |
+
or os.environ.get("HF_TOKEN")
|
| 24 |
+
or os.environ.get("HF_ADMIN_TOKEN")
|
| 25 |
+
or ""
|
| 26 |
+
)
|
| 27 |
|
| 28 |
|
| 29 |
def upload_session_as_file(
|
|
|
|
| 65 |
json.dump(data, f, indent=2)
|
| 66 |
return False
|
| 67 |
|
| 68 |
+
# Scrub secrets (HF tokens, API keys, etc.) from messages + events
|
| 69 |
+
# before they leave the local disk. Best-effort regex-based redaction —
|
| 70 |
+
# see agent/core/redact.py for the patterns covered.
|
| 71 |
+
try:
|
| 72 |
+
from agent.core.redact import scrub # type: ignore
|
| 73 |
+
except Exception:
|
| 74 |
+
# Fallback for environments where the agent package isn't importable
|
| 75 |
+
# (shouldn't happen in our subprocess, but be defensive).
|
| 76 |
+
import importlib.util
|
| 77 |
+
_spec = importlib.util.spec_from_file_location(
|
| 78 |
+
"_redact",
|
| 79 |
+
Path(__file__).parent / "redact.py",
|
| 80 |
+
)
|
| 81 |
+
_mod = importlib.util.module_from_spec(_spec)
|
| 82 |
+
_spec.loader.exec_module(_mod) # type: ignore
|
| 83 |
+
scrub = _mod.scrub
|
| 84 |
+
scrubbed_messages = scrub(data["messages"])
|
| 85 |
+
scrubbed_events = scrub(data["events"])
|
| 86 |
+
scrubbed_tools = scrub(data.get("tools") or [])
|
| 87 |
+
|
| 88 |
# Prepare JSONL content (single line)
|
| 89 |
+
# Store messages/events/tools as JSON strings to avoid schema conflicts
|
| 90 |
+
# across sessions with different tool rosters.
|
| 91 |
session_row = {
|
| 92 |
"session_id": data["session_id"],
|
| 93 |
"session_start_time": data["session_start_time"],
|
| 94 |
"session_end_time": data["session_end_time"],
|
| 95 |
"model_name": data["model_name"],
|
| 96 |
+
"messages": json.dumps(scrubbed_messages),
|
| 97 |
+
"events": json.dumps(scrubbed_events),
|
| 98 |
+
"tools": json.dumps(scrubbed_tools),
|
| 99 |
}
|
| 100 |
|
| 101 |
# Create temporary JSONL file
|
agent/core/telemetry.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""All agent observability in one module.
|
| 2 |
+
|
| 3 |
+
Every telemetry signal the agent emits — LLM-call usage / cost, hf_jobs
|
| 4 |
+
lifecycle, sandbox lifecycle, user feedback, mid-turn heartbeat saves — is
|
| 5 |
+
defined here so business-logic files stay free of instrumentation noise.
|
| 6 |
+
|
| 7 |
+
Callsites are one-liners::
|
| 8 |
+
|
| 9 |
+
await telemetry.record_llm_call(session, model=..., response=r, ...)
|
| 10 |
+
await telemetry.record_hf_job_submit(session, job, args, image=..., job_type="Python")
|
| 11 |
+
HeartbeatSaver.maybe_fire(session)
|
| 12 |
+
|
| 13 |
+
All ``record_*`` functions emit a single ``Event`` via ``session.send_event``
|
| 14 |
+
and never raise — telemetry is best-effort and must not break the agent.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import asyncio
|
| 20 |
+
import logging
|
| 21 |
+
import time
|
| 22 |
+
from typing import Any
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ── usage extraction ────────────────────────────────────────────────────────
|
| 28 |
+
|
| 29 |
+
def extract_usage(response_or_chunk: Any) -> dict:
    """Flat usage dict from a litellm response or final-chunk usage object.

    Normalizes across providers: Anthropic exposes cache tokens as
    ``cache_read_input_tokens`` / ``cache_creation_input_tokens``, while
    OpenAI nests them under ``prompt_tokens_details.cached_tokens``. Both are
    surfaced via the stable keys ``cache_read_tokens`` /
    ``cache_creation_tokens``. Returns ``{}`` when no usage is present.
    """
    usage_obj = getattr(response_or_chunk, "usage", None)
    if usage_obj is None and isinstance(response_or_chunk, dict):
        usage_obj = response_or_chunk.get("usage")
    if usage_obj is None:
        return {}

    def _read(field, fallback=0):
        # Usage may be a plain dict or an attribute-style object depending on
        # provider / litellm version; falsy values collapse to the fallback.
        if isinstance(usage_obj, dict):
            return usage_obj.get(field, fallback) or fallback
        return getattr(usage_obj, field, fallback) or fallback

    prompt_tokens = _read("prompt_tokens")
    completion_tokens = _read("completion_tokens")
    total_tokens = _read("total_tokens") or (prompt_tokens + completion_tokens)

    cache_read = _read("cache_read_input_tokens")
    cache_creation = _read("cache_creation_input_tokens")

    if not cache_read:
        # OpenAI-style nested detail object (dict or attr-style).
        details = _read("prompt_tokens_details", None)
        if details is not None:
            if isinstance(details, dict):
                cached = details.get("cached_tokens", 0)
            else:
                cached = getattr(details, "cached_tokens", 0)
            cache_read = cached or 0

    return {
        "prompt_tokens": int(prompt_tokens),
        "completion_tokens": int(completion_tokens),
        "total_tokens": int(total_tokens),
        "cache_read_tokens": int(cache_read),
        "cache_creation_tokens": int(cache_creation),
    }
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ── llm_call ────────────────────────────────────────────────────────────────
|
| 73 |
+
|
| 74 |
+
async def record_llm_call(
    session: Any,
    *,
    model: str,
    response: Any = None,
    latency_ms: int,
    finish_reason: str | None,
) -> dict:
    """Emit an ``llm_call`` event and return the extracted usage dict so
    callers can stash it on their result object if they want."""
    usage = {} if response is None else extract_usage(response)

    # Cost is best-effort: litellm may not know this model's pricing, in
    # which case we report zero rather than fail.
    cost_usd = 0.0
    if response is not None:
        try:
            from litellm import completion_cost

            cost_usd = float(completion_cost(completion_response=response) or 0.0)
        except Exception:
            cost_usd = 0.0

    from agent.core.session import Event  # local import to avoid cycle

    try:
        event = Event(
            event_type="llm_call",
            data={
                "model": model,
                "latency_ms": latency_ms,
                "finish_reason": finish_reason,
                "cost_usd": cost_usd,
                **usage,
            },
        )
        await session.send_event(event)
    except Exception as e:
        logger.debug("record_llm_call failed (non-fatal): %s", e)
    return usage
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ── hf_jobs ────────────────────────────────────────────────────────────────
|
| 110 |
+
|
| 111 |
+
def _infer_push_to_hub(script_or_cmd: Any) -> bool:
|
| 112 |
+
if not isinstance(script_or_cmd, str):
|
| 113 |
+
return False
|
| 114 |
+
return (
|
| 115 |
+
"push_to_hub=True" in script_or_cmd
|
| 116 |
+
or "push_to_hub=true" in script_or_cmd
|
| 117 |
+
or "hub_model_id" in script_or_cmd
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
async def record_hf_job_submit(
    session: Any,
    job: Any,
    args: dict,
    *,
    image: str,
    job_type: str,
) -> float:
    """Emit ``hf_job_submit``. Returns the monotonic start timestamp so the
    caller can pass it back into :func:`record_hf_job_complete`."""
    from agent.core.session import Event

    submitted_at = time.monotonic()
    try:
        payload = args.get("script") or args.get("command") or ""
        data = {
            "job_id": getattr(job, "id", None),
            "job_url": getattr(job, "url", None),
            "flavor": args.get("hardware_flavor", "cpu-basic"),
            "timeout": args.get("timeout", "30m"),
            "job_type": job_type,
            "image": image,
            "push_to_hub": _infer_push_to_hub(payload),
        }
        await session.send_event(Event(event_type="hf_job_submit", data=data))
    except Exception as e:
        logger.debug("record_hf_job_submit failed (non-fatal): %s", e)
    return submitted_at
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
async def record_hf_job_complete(
    session: Any,
    job: Any,
    *,
    flavor: str,
    final_status: str,
    submit_ts: float,
) -> None:
    """Emit ``hf_job_complete`` with wall time measured from ``submit_ts``
    (the monotonic timestamp returned by :func:`record_hf_job_submit`)."""
    from agent.core.session import Event

    try:
        elapsed_s = int(time.monotonic() - submit_ts)
        data = {
            "job_id": getattr(job, "id", None),
            "flavor": flavor,
            "final_status": final_status,
            "wall_time_s": elapsed_s,
        }
        await session.send_event(Event(event_type="hf_job_complete", data=data))
    except Exception as e:
        logger.debug("record_hf_job_complete failed (non-fatal): %s", e)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# ── sandbox ─────────────────────────────────────────────────────────────────
|
| 177 |
+
|
| 178 |
+
async def record_sandbox_create(
    session: Any,
    sandbox: Any,
    *,
    hardware: str,
    create_latency_s: int,
) -> None:
    """Emit ``sandbox_create`` and remember when the sandbox came up."""
    from agent.core.session import Event

    try:
        # Pin created-at on the session so record_sandbox_destroy can diff.
        session._sandbox_created_at = time.monotonic() - create_latency_s
        data = {
            "sandbox_id": getattr(sandbox, "space_id", None),
            "hardware": hardware,
            "create_latency_s": int(create_latency_s),
        }
        await session.send_event(Event(event_type="sandbox_create", data=data))
    except Exception as e:
        logger.debug("record_sandbox_create failed (non-fatal): %s", e)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
async def record_sandbox_destroy(session: Any, sandbox: Any) -> None:
    """Emit ``sandbox_destroy`` with the sandbox's lifetime, if known.

    ``_sandbox_created_at`` is pinned on the session by
    :func:`record_sandbox_create`; when the create event was never recorded,
    ``lifetime_s`` is reported as ``None``.
    """
    from agent.core.session import Event

    try:
        created = getattr(session, "_sandbox_created_at", None)
        # `is not None`, not truthiness: a created-at of exactly 0.0 is a
        # legitimate monotonic timestamp (monotonic minus create latency can
        # be ~0 near process start) and must not be treated as "unset".
        lifetime_s = int(time.monotonic() - created) if created is not None else None
        await session.send_event(Event(
            event_type="sandbox_destroy",
            data={
                "sandbox_id": getattr(sandbox, "space_id", None),
                "lifetime_s": lifetime_s,
            },
        ))
    except Exception as e:
        logger.debug("record_sandbox_destroy failed (non-fatal): %s", e)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# ── feedback ───────────────────────────────────────────────────────────────
|
| 218 |
+
|
| 219 |
+
async def record_feedback(
    session: Any,
    *,
    rating: str,
    turn_index: int | None = None,
    message_id: str | None = None,
    comment: str | None = None,
) -> None:
    """Emit a ``feedback`` event; free-text comments are capped at 500 chars."""
    from agent.core.session import Event

    try:
        data = {
            "rating": rating,
            "turn_index": turn_index,
            "message_id": message_id,
            "comment": (comment or "")[:500],
        }
        await session.send_event(Event(event_type="feedback", data=data))
    except Exception as e:
        logger.debug("record_feedback failed (non-fatal): %s", e)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
# ── heartbeat ──────────────────────────────────────────────────────────────
|
| 243 |
+
|
| 244 |
+
# Module-level reference set for fire-and-forget heartbeat tasks. asyncio only
# keeps *weak* references to tasks, so the returned Task would otherwise be
# eligible for GC before running — the task gets discarded and the upload
# silently never happens. Hold strong refs until the task completes.
_heartbeat_tasks: set[asyncio.Task] = set()


class HeartbeatSaver:
    """Time-gated mid-turn flush.

    Called from ``Session.send_event`` after every event. Fires
    ``save_and_upload_detached`` in a worker thread at most once per
    ``heartbeat_interval_s`` (default 60s). Guards against losing trace data
    on long-running turns that crash before ``turn_complete``.
    """

    @staticmethod
    def maybe_fire(session: Any) -> None:
        # Cheap guards first: feature flag off, or interval disabled (<= 0).
        if not getattr(session.config, "save_sessions", False):
            return
        interval = getattr(session.config, "heartbeat_interval_s", 0) or 0
        if interval <= 0:
            return
        now = time.monotonic()
        last = getattr(session, "_last_heartbeat_ts", None)
        if last is None:
            # Initialise on first event; no save yet.
            session._last_heartbeat_ts = now
            return
        if now - last < interval:
            return
        # Bump the timestamp *before* scheduling so a slow/failing save can't
        # cause every subsequent event to re-fire the heartbeat.
        session._last_heartbeat_ts = now
        repo_id = session.config.session_dataset_repo
        try:
            task = asyncio.get_running_loop().create_task(
                asyncio.to_thread(session.save_and_upload_detached, repo_id)
            )
            # Hold a strong reference until the task finishes so asyncio can't
            # GC it. ``set.discard`` is a no-op on missing keys → safe callback.
            _heartbeat_tasks.add(task)
            task.add_done_callback(_heartbeat_tasks.discard)
        except RuntimeError:
            # No running event loop (e.g. called from a plain thread): fall
            # back to a synchronous, best-effort save.
            try:
                session.save_and_upload_detached(repo_id)
            except Exception as e:
                logger.debug("Heartbeat save failed (non-fatal): %s", e)
|
agent/sft/__init__.py
ADDED
|
File without changes
|
agent/sft/tagger.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Derive tags for a session trajectory.
|
| 2 |
+
|
| 3 |
+
``tag_session(trajectory)`` → ``list[str]``. Pure function. No filtering, no
|
| 4 |
+
mutation — tags are purely metadata so downstream pipelines can slice the raw
|
| 5 |
+
SFT dataset (``where 'hf_job:succeeded' in tags``) without re-reading trajectories.
|
| 6 |
+
|
| 7 |
+
Tag namespaces (all tags are ``"<namespace>:<value>"`` strings):
|
| 8 |
+
|
| 9 |
+
* ``tool:<name>`` — every tool called at least once (``tool:hf_jobs``, …)
|
| 10 |
+
* ``outcome:<end>`` — ``completed`` / ``errored`` / ``interrupted`` /
|
| 11 |
+
``ongoing`` / ``doom_loop`` / ``context_exceeded``
|
| 12 |
+
* ``hf_job:<facet>`` — ``submitted``, ``succeeded``, ``failed``,
|
| 13 |
+
``multi`` (>1), ``oom``, ``push_to_hub``
|
| 14 |
+
* ``gpu:<kind>`` — ``none``, ``t4``, ``a10g``, ``a100``, ``l40s``,
|
| 15 |
+
``h100``, plus ``gpu:multi`` for x2/x4/x8 flavors
|
| 16 |
+
* ``sandbox:<facet>`` — ``created``, ``gpu``, ``cpu``, ``long_lived`` (>30 min)
|
| 17 |
+
* ``feedback:<kind>`` — ``up``, ``down``, ``mixed``, ``none``
|
| 18 |
+
* ``model:<family>`` — ``opus`` / ``sonnet`` / ``haiku`` / ``kimi`` /
|
| 19 |
+
``gpt`` / ``deepseek`` / ``qwen`` / ``other``
|
| 20 |
+
* ``turns:<bucket>`` — ``short`` (<5) / ``medium`` (5–20) / ``long`` (>20)
|
| 21 |
+
* ``cost:<bucket>`` — ``low`` (<$0.10) / ``med`` (<$1) / ``high``
|
| 22 |
+
* ``task:<kind>`` — ``training`` / ``inference`` / ``data_prep`` /
|
| 23 |
+
``research_only`` (heuristic on tools + scripts)
|
| 24 |
+
|
| 25 |
+
Tags are deduplicated before returning.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from __future__ import annotations

import re
from typing import Any, Iterable
|
| 31 |
+
|
| 32 |
+
# Flavor → GPU-family mapping. Keep conservative; unknown flavors → "none".
# Keys mirror HF Jobs / Spaces hardware flavor strings — TODO confirm the
# list stays in sync with the flavors the jobs tool actually accepts.
_GPU_FAMILY = {
    "cpu-basic": "none", "cpu-upgrade": "none",
    "t4-small": "t4", "t4-medium": "t4",
    "l4x1": "l40s", "l4x4": "l40s",
    "l40sx1": "l40s", "l40sx4": "l40s", "l40sx8": "l40s",
    "a10g-small": "a10g", "a10g-large": "a10g",
    "a10g-largex2": "a10g", "a10g-largex4": "a10g",
    "a100-large": "a100", "a100x2": "a100",
    "a100x4": "a100", "a100x8": "a100",
    "h100": "h100", "h100x8": "h100",
}

# Substrings that count a flavor as multi-GPU (e.g. "a100x4" → gpu:multi).
_MULTI_GPU_MARKERS = ("x2", "x4", "x8")

# Tool names that don't touch training/inference or sandbox/jobs. If a session
# only used these, we tag it research_only.
_RESEARCH_ONLY_TOOLS = {
    "research", "github_find_examples", "github_read_file", "github_list_repos",
    "hf_papers", "explore_hf_docs", "fetch_hf_docs", "hub_repo_details",
    "plan", "hf_inspect_dataset", "web_search",
}

# Tool names that signal data manipulation workflows (see _infer_task_tag).
_DATA_PREP_TOOLS = {"hf_inspect_dataset", "dataset_tools", "hub_repo_details"}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _model_family(model_name: str | None) -> str:
|
| 61 |
+
if not model_name:
|
| 62 |
+
return "other"
|
| 63 |
+
n = model_name.lower()
|
| 64 |
+
if "opus" in n:
|
| 65 |
+
return "opus"
|
| 66 |
+
if "sonnet" in n:
|
| 67 |
+
return "sonnet"
|
| 68 |
+
if "haiku" in n:
|
| 69 |
+
return "haiku"
|
| 70 |
+
if "kimi" in n:
|
| 71 |
+
return "kimi"
|
| 72 |
+
if "gpt" in n:
|
| 73 |
+
return "gpt"
|
| 74 |
+
if "deepseek" in n:
|
| 75 |
+
return "deepseek"
|
| 76 |
+
if "qwen" in n:
|
| 77 |
+
return "qwen"
|
| 78 |
+
if "llama" in n:
|
| 79 |
+
return "llama"
|
| 80 |
+
return "other"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _turns_bucket(n: int) -> str:
|
| 84 |
+
if n < 5:
|
| 85 |
+
return "short"
|
| 86 |
+
if n <= 20:
|
| 87 |
+
return "medium"
|
| 88 |
+
return "long"
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _cost_bucket(cost_usd: float) -> str:
|
| 92 |
+
if cost_usd < 0.10:
|
| 93 |
+
return "low"
|
| 94 |
+
if cost_usd < 1.0:
|
| 95 |
+
return "med"
|
| 96 |
+
return "high"
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _flavor_to_gpu_tags(flavor: str) -> list[str]:
    """Translate a hardware flavor into ``gpu:*`` tags.

    Always yields one ``gpu:<family>`` tag (``gpu:none`` for unknown or CPU
    flavors) plus ``gpu:multi`` when the flavor string carries a multi-GPU
    marker (x2/x4/x8).
    """
    family = _GPU_FAMILY.get(flavor, "none")
    result = ["gpu:" + family]
    is_multi = any(marker in flavor for marker in _MULTI_GPU_MARKERS)
    if is_multi:
        result.append("gpu:multi")
    return result
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _has_oom_signal(tool_outputs: Iterable[str]) -> bool:
|
| 108 |
+
for out in tool_outputs:
|
| 109 |
+
if not isinstance(out, str):
|
| 110 |
+
continue
|
| 111 |
+
low = out.lower()
|
| 112 |
+
if "outofmemoryerror" in low or "cuda out of memory" in low or "oom" in low:
|
| 113 |
+
return True
|
| 114 |
+
return False
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _infer_task_tag(
    tool_names: set[str],
    hf_job_submit_scripts: list[str],
) -> str | None:
    """Return a ``task:*`` tag or None if we can't tell.

    Heuristic order: training > inference > data_prep > research_only.
    """
    # training: any submitted hf_jobs script mentioning a Trainer/SFT/
    # training API keyword.
    training_markers = (
        "sftconfig", "sfttrainer", "trainer(", "trainingarguments",
        "grpo", "dpo", ".train(", "transformers import",
        "trainer import", "fine-tune", "finetune",
    )
    for script in hf_job_submit_scripts:
        lowered = script.lower()
        if any(marker in lowered for marker in training_markers):
            return "training"

    compute_tools = {"hf_jobs", "sandbox_create", "sandbox_exec"}
    uses_compute = not tool_names.isdisjoint(compute_tools)

    if not uses_compute:
        # inference: inference tools were used but never hf_jobs/sandbox.
        if not tool_names.isdisjoint({"inference", "generate", "run_inference"}):
            return "inference"
        # data_prep: primarily dataset tools and no training/inference.
        if not tool_names.isdisjoint(_DATA_PREP_TOOLS):
            return "data_prep"

    # research_only: every tool used is in the research allow-list.
    if tool_names and tool_names <= _RESEARCH_ONLY_TOOLS:
        return "research_only"

    return None
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def tag_session(trajectory: dict) -> list[str]:
    """Derive tags from a session trajectory. Pure function.

    One forward pass over ``trajectory["events"]`` accumulates counters
    (cost, tools, job/sandbox/feedback/outcome state); a second pass
    collects hf_jobs scripts for the task heuristic. Returns a sorted,
    de-duplicated list of ``namespace:value`` strings (see the module
    docstring for the namespace catalogue). Missing/None trajectory fields
    are treated as empty.
    """
    tags: set[str] = set()

    events: list[dict] = trajectory.get("events") or []
    messages: list[dict] = trajectory.get("messages") or []
    model_name: str | None = trajectory.get("model_name")

    # model
    tags.add(f"model:{_model_family(model_name)}")

    # turns
    user_turns = sum(1 for m in messages if m.get("role") == "user")
    tags.add(f"turns:{_turns_bucket(user_turns)}")

    # cost + tool-name enumeration + outcome detection
    cost_usd = 0.0
    tool_names: set[str] = set()
    tool_outputs: list[str] = []
    hf_job_submit_count = 0
    hf_job_submit_scripts: list[str] = []
    hf_job_success_count = 0
    hf_job_fail_count = 0
    hf_job_push_to_hub = False
    gpu_tags_seen: set[str] = set()

    # Outcome is the *last* terminal signal. Seed with "ongoing" — overridden
    # if we see a terminal event.
    outcome = "ongoing"
    had_error = False
    had_doom_loop = False
    had_compact = False

    feedback_up = 0
    feedback_down = 0

    sandbox_created = False
    sandbox_hardware: str | None = None
    sandbox_lifetime_s: int | None = None

    for ev in events:
        et = ev.get("event_type")
        data = ev.get("data") or {}

        if et == "llm_call":
            cost_usd += float(data.get("cost_usd") or 0.0)

        elif et == "tool_call":
            name = data.get("tool")
            if name:
                tool_names.add(name)

        elif et == "tool_output":
            out = data.get("output")
            if isinstance(out, str):
                tool_outputs.append(out)

        elif et == "hf_job_submit":
            hf_job_submit_count += 1
            if data.get("push_to_hub"):
                hf_job_push_to_hub = True
            # Missing flavor defaults to the CPU baseline.
            flavor = data.get("flavor") or "cpu-basic"
            for t in _flavor_to_gpu_tags(flavor):
                gpu_tags_seen.add(t)

        elif et == "hf_job_complete":
            final = (data.get("final_status") or "").lower()
            if final in ("completed", "succeeded", "success"):
                hf_job_success_count += 1
            elif final in ("failed", "error", "timeout", "cancelled"):
                hf_job_fail_count += 1

        elif et == "sandbox_create":
            sandbox_created = True
            sandbox_hardware = data.get("hardware")

        elif et == "sandbox_destroy":
            lt = data.get("lifetime_s")
            if isinstance(lt, (int, float)):
                sandbox_lifetime_s = int(lt)

        elif et == "feedback":
            rating = data.get("rating")
            if rating == "up":
                feedback_up += 1
            elif rating == "down":
                feedback_down += 1

        elif et == "error":
            had_error = True
        elif et == "turn_complete":
            # A clean turn_complete marks the session completed; once an
            # error was seen, turn_complete no longer upgrades the outcome.
            if not had_error:
                outcome = "completed"
        elif et == "interrupted":
            outcome = "interrupted"
        elif et == "compacted":
            had_compact = True
        elif et == "tool_log":
            log_text = (data.get("log") or "").lower()
            if "doom loop" in log_text:
                had_doom_loop = True

    # Errors that were never followed by a clean completion or an explicit
    # interruption count as "errored".
    if had_error and outcome not in ("completed", "interrupted"):
        outcome = "errored"

    tags.add(f"outcome:{outcome}")
    if had_doom_loop:
        tags.add("outcome:doom_loop")
    if had_compact:
        tags.add("outcome:context_exceeded")

    # tools
    for name in tool_names:
        tags.add(f"tool:{name}")

    # hf_jobs facets
    if hf_job_submit_count >= 1:
        tags.add("hf_job:submitted")
    if hf_job_submit_count > 1:
        tags.add("hf_job:multi")
    if hf_job_success_count > 0:
        tags.add("hf_job:succeeded")
    if hf_job_fail_count > 0:
        tags.add("hf_job:failed")
    if hf_job_push_to_hub:
        tags.add("hf_job:push_to_hub")
    if _has_oom_signal(tool_outputs):
        tags.add("hf_job:oom")

    # gpu tags (from all submitted jobs)
    tags.update(gpu_tags_seen)
    if "gpu:none" in tags and len(gpu_tags_seen) > 1:
        # If any GPU flavor was used, drop the "none" tag for clarity.
        tags.discard("gpu:none")

    # sandbox facets
    if sandbox_created:
        tags.add("sandbox:created")
        if sandbox_hardware:
            fam = _GPU_FAMILY.get(sandbox_hardware, "none")
            tags.add("sandbox:cpu" if fam == "none" else "sandbox:gpu")
    # long_lived: sandbox survived more than 30 minutes.
    if sandbox_lifetime_s is not None and sandbox_lifetime_s > 1800:
        tags.add("sandbox:long_lived")

    # feedback
    if feedback_up and feedback_down:
        tags.add("feedback:mixed")
    elif feedback_up:
        tags.add("feedback:up")
    elif feedback_down:
        tags.add("feedback:down")
    else:
        tags.add("feedback:none")

    # cost bucket
    tags.add(f"cost:{_cost_bucket(cost_usd)}")

    # task heuristic (needs scripts — pull from the hf_job_submit events'
    # matching tool_call arguments in the event list).
    for ev in events:
        if ev.get("event_type") == "tool_call":
            data = ev.get("data") or {}
            if data.get("tool") == "hf_jobs":
                args = data.get("arguments") or {}
                script = args.get("script") or args.get("command") or ""
                if isinstance(script, str):
                    hf_job_submit_scripts.append(script)

    task_tag = _infer_task_tag(tool_names, hf_job_submit_scripts)
    if task_tag:
        tags.add(f"task:{task_tag}")

    return sorted(tags)
|
agent/tools/jobs_tool.py
CHANGED
|
@@ -528,14 +528,16 @@ class HfJobsTool:
|
|
| 528 |
job_type = "Docker"
|
| 529 |
|
| 530 |
# Run the job
|
|
|
|
|
|
|
| 531 |
job = await _async_call(
|
| 532 |
self.api.run_job,
|
| 533 |
image=image,
|
| 534 |
command=command,
|
| 535 |
env=_add_default_env(args.get("env")),
|
| 536 |
secrets=_add_environment_variables(args.get("secrets"), self.hf_token),
|
| 537 |
-
flavor=
|
| 538 |
-
timeout=
|
| 539 |
namespace=self.namespace,
|
| 540 |
)
|
| 541 |
|
|
@@ -557,6 +559,16 @@ class HfJobsTool:
|
|
| 557 |
)
|
| 558 |
)
|
| 559 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 560 |
# Wait for completion and stream logs
|
| 561 |
logger.info(f"{job_type} job started: {job.url}")
|
| 562 |
logger.info("Streaming logs...")
|
|
@@ -566,6 +578,13 @@ class HfJobsTool:
|
|
| 566 |
namespace=self.namespace,
|
| 567 |
)
|
| 568 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
# Untrack job ID (completed or failed, no longer needs cancellation)
|
| 570 |
if self.session:
|
| 571 |
self.session._running_job_ids.discard(job.id)
|
|
|
|
| 528 |
job_type = "Docker"
|
| 529 |
|
| 530 |
# Run the job
|
| 531 |
+
flavor = args.get("hardware_flavor", "cpu-basic")
|
| 532 |
+
timeout_str = args.get("timeout", "30m")
|
| 533 |
job = await _async_call(
|
| 534 |
self.api.run_job,
|
| 535 |
image=image,
|
| 536 |
command=command,
|
| 537 |
env=_add_default_env(args.get("env")),
|
| 538 |
secrets=_add_environment_variables(args.get("secrets"), self.hf_token),
|
| 539 |
+
flavor=flavor,
|
| 540 |
+
timeout=timeout_str,
|
| 541 |
namespace=self.namespace,
|
| 542 |
)
|
| 543 |
|
|
|
|
| 559 |
)
|
| 560 |
)
|
| 561 |
|
| 562 |
+
# Telemetry: job submission + completion (infra consumption signal).
|
| 563 |
+
submit_ts = None
|
| 564 |
+
if self.session:
|
| 565 |
+
from agent.core import telemetry
|
| 566 |
+
submit_ts = await telemetry.record_hf_job_submit(
|
| 567 |
+
self.session, job,
|
| 568 |
+
{**args, "hardware_flavor": flavor, "timeout": timeout_str},
|
| 569 |
+
image=image, job_type=job_type,
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
# Wait for completion and stream logs
|
| 573 |
logger.info(f"{job_type} job started: {job.url}")
|
| 574 |
logger.info("Streaming logs...")
|
|
|
|
| 578 |
namespace=self.namespace,
|
| 579 |
)
|
| 580 |
|
| 581 |
+
if self.session and submit_ts is not None:
|
| 582 |
+
from agent.core import telemetry
|
| 583 |
+
await telemetry.record_hf_job_complete(
|
| 584 |
+
self.session, job,
|
| 585 |
+
flavor=flavor, final_status=final_status, submit_ts=submit_ts,
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
# Untrack job ID (completed or failed, no longer needs cancellation)
|
| 589 |
if self.session:
|
| 590 |
self.session._running_job_ids.discard(job.id)
|
agent/tools/sandbox_tool.py
CHANGED
|
@@ -131,6 +131,8 @@ async def _ensure_sandbox(
|
|
| 131 |
}
|
| 132 |
if hardware != "cpu-basic":
|
| 133 |
kwargs["sleep_time"] = 2700
|
|
|
|
|
|
|
| 134 |
try:
|
| 135 |
sb = await asyncio.to_thread(Sandbox.create, **kwargs)
|
| 136 |
except Sandbox.Cancelled:
|
|
@@ -139,6 +141,13 @@ async def _ensure_sandbox(
|
|
| 139 |
watcher_task.cancel()
|
| 140 |
session.sandbox = sb
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
# Set a descriptive title (template title is inherited on duplicate)
|
| 143 |
from huggingface_hub import metadata_update
|
| 144 |
|
|
|
|
| 131 |
}
|
| 132 |
if hardware != "cpu-basic":
|
| 133 |
kwargs["sleep_time"] = 2700
|
| 134 |
+
import time as _t
|
| 135 |
+
_t_start = _t.monotonic()
|
| 136 |
try:
|
| 137 |
sb = await asyncio.to_thread(Sandbox.create, **kwargs)
|
| 138 |
except Sandbox.Cancelled:
|
|
|
|
| 141 |
watcher_task.cancel()
|
| 142 |
session.sandbox = sb
|
| 143 |
|
| 144 |
+
# Telemetry: sandbox creation (infra consumption signal)
|
| 145 |
+
from agent.core import telemetry
|
| 146 |
+
await telemetry.record_sandbox_create(
|
| 147 |
+
session, sb, hardware=hardware,
|
| 148 |
+
create_latency_s=int(_t.monotonic() - _t_start),
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
# Set a descriptive title (template title is inherited on duplicate)
|
| 152 |
from huggingface_hub import metadata_update
|
| 153 |
|
backend/kpis_scheduler.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""In-process hourly KPI rollup, owned by the backend Space lifespan.
|
| 2 |
+
|
| 3 |
+
Replaces an external GitHub Actions cron so the rollup lives next to the data
|
| 4 |
+
and reuses the Space's existing HF token — no production secrets on the
|
| 5 |
+
public source repo. See ``scripts/build_kpis.py`` for the data-flow diagram
|
| 6 |
+
and metric definitions.
|
| 7 |
+
|
| 8 |
+
Behaviour::
|
| 9 |
+
|
| 10 |
+
lifespan startup → start APScheduler with cron("5 * * * *", UTC)
|
| 11 |
+
→ fire a best-effort 6-hour backfill (fire-and-forget)
|
| 12 |
+
each :05 → run ``build_kpis.run_for_hour`` for the just-completed hour
|
| 13 |
+
lifespan shutdown → scheduler.shutdown(wait=False)
|
| 14 |
+
|
| 15 |
+
Environment::
|
| 16 |
+
|
| 17 |
+
HF_KPI_WRITE_TOKEN | HF_SESSION_UPLOAD_TOKEN | HF_TOKEN | HF_ADMIN_TOKEN
|
| 18 |
+
First one found is used. Least-privilege first.
|
| 19 |
+
KPI_SOURCE_REPO default smolagents/ml-intern-sessions
|
| 20 |
+
KPI_TARGET_REPO default smolagents/ml-intern-kpis
|
| 21 |
+
ML_INTERN_KPIS_DISABLED if truthy, the scheduler is not started
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import asyncio
|
| 27 |
+
import importlib.util
|
| 28 |
+
import logging
|
| 29 |
+
import os
|
| 30 |
+
from datetime import datetime, timedelta, timezone
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Optional
|
| 33 |
+
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 37 |
+
|
| 38 |
+
# Hold strong refs to backfill tasks so asyncio doesn't GC them mid-run.
|
| 39 |
+
_background_tasks: set[asyncio.Task] = set()
|
| 40 |
+
|
| 41 |
+
_scheduler = None # AsyncIOScheduler instance (lazy import)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _resolve_token() -> Optional[str]:
|
| 45 |
+
"""Pick the first available HF token. Least-privilege first."""
|
| 46 |
+
for var in (
|
| 47 |
+
"HF_KPI_WRITE_TOKEN",
|
| 48 |
+
"HF_SESSION_UPLOAD_TOKEN",
|
| 49 |
+
"HF_TOKEN",
|
| 50 |
+
"HF_ADMIN_TOKEN",
|
| 51 |
+
):
|
| 52 |
+
val = os.environ.get(var)
|
| 53 |
+
if val:
|
| 54 |
+
return val
|
| 55 |
+
return None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _load_build_kpis():
    """Import ``scripts/build_kpis.py`` without putting ``scripts/`` on sys.path."""
    script_path = _PROJECT_ROOT / "scripts" / "build_kpis.py"
    spec = importlib.util.spec_from_file_location("build_kpis", script_path)
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
async def _run_hour(hour_dt: datetime) -> None:
    """Run one hourly rollup off the event loop. Best-effort, never raises."""
    token = _resolve_token()
    if not token:
        logger.warning("kpis_scheduler: no HF token available, skipping %s", hour_dt)
        return
    try:
        source_repo = os.environ.get("KPI_SOURCE_REPO", "smolagents/ml-intern-sessions")
        target_repo = os.environ.get("KPI_TARGET_REPO", "smolagents/ml-intern-kpis")
        rollup = _load_build_kpis()
        from huggingface_hub import HfApi
        # run_for_hour does blocking network I/O — push it to a worker thread.
        await asyncio.to_thread(
            rollup.run_for_hour, HfApi(), source_repo, target_repo, hour_dt, token
        )
    except Exception as e:
        logger.warning("kpis_scheduler: rollup for %s failed: %s", hour_dt, e)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
async def run_last_completed_hour() -> None:
    """The scheduled-at-:05 job. Rolls up the previous whole hour."""
    top_of_hour = datetime.now(timezone.utc).replace(
        minute=0, second=0, microsecond=0
    )
    await _run_hour(top_of_hour - timedelta(hours=1))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
async def backfill(hours: int = 6) -> None:
    """Catch-up pass for hours the Space was down. Idempotent (overwrites)."""
    top_of_hour = datetime.now(timezone.utc).replace(
        minute=0, second=0, microsecond=0
    )
    # Most recent hour first, same as the original schedule order.
    for offset in range(1, hours + 1):
        await _run_hour(top_of_hour - timedelta(hours=offset))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def start(backfill_hours: int = 6) -> None:
    """Called from FastAPI lifespan startup.

    Idempotent: a second call while the scheduler is live is a no-op.
    Degrades gracefully when apscheduler is not installed or the feature is
    disabled via ``ML_INTERN_KPIS_DISABLED``. On success, schedules the
    hourly rollup at minute :05 (UTC) and kicks off a non-blocking backfill
    of the last ``backfill_hours`` hours.
    """
    global _scheduler
    if os.environ.get("ML_INTERN_KPIS_DISABLED"):
        logger.info("kpis_scheduler: disabled via ML_INTERN_KPIS_DISABLED")
        return
    if _scheduler is not None:
        # Already running — don't double-schedule.
        return

    # apscheduler is an optional dependency; import lazily and skip if absent.
    try:
        from apscheduler.schedulers.asyncio import AsyncIOScheduler
        from apscheduler.triggers.cron import CronTrigger
    except ImportError:
        logger.warning("kpis_scheduler: apscheduler not installed, skipping")
        return

    _scheduler = AsyncIOScheduler(timezone="UTC")
    _scheduler.add_job(
        run_last_completed_hour,
        CronTrigger(minute=5),
        id="kpis_hourly",
        misfire_grace_time=600,  # tolerate a 10-min misfire window
        coalesce=True,  # collapse multiple missed fires into one
        max_instances=1,
        replace_existing=True,
    )
    _scheduler.start()
    logger.info("kpis_scheduler: started (cron '5 * * * *' UTC)")

    # Non-blocking backfill. Hold a strong ref until done so asyncio doesn't
    # GC the task before it finishes.
    try:
        task = asyncio.get_running_loop().create_task(backfill(backfill_hours))
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
    except RuntimeError:
        # Not in an event loop (tests); skip backfill.
        pass
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
async def shutdown() -> None:
    """Called from FastAPI lifespan shutdown. No-op if never started."""
    global _scheduler
    sched = _scheduler
    if sched is None:
        return
    # wait=False: don't block shutdown on an in-flight rollup.
    sched.shutdown(wait=False)
    _scheduler = None
    logger.info("kpis_scheduler: stopped")
|
backend/main.py
CHANGED
|
@@ -27,8 +27,37 @@ logger = logging.getLogger(__name__)
|
|
| 27 |
async def lifespan(app: FastAPI):
|
| 28 |
"""Application lifespan handler."""
|
| 29 |
logger.info("Starting HF Agent backend...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
yield
|
|
|
|
| 31 |
logger.info("Shutting down HF Agent backend...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
app = FastAPI(
|
|
|
|
| 27 |
async def lifespan(app: FastAPI):
    """Application lifespan handler.

    Startup: launch the in-process hourly KPI scheduler (best-effort).
    Shutdown: stop the scheduler, then flush every still-active session so
    trace data survives a server restart.
    """
    logger.info("Starting HF Agent backend...")
    # Start in-process hourly KPI rollup. Replaces an external cron so the
    # rollup lives next to the data and reuses the Space's HF token.
    try:
        import kpis_scheduler
        kpis_scheduler.start()
    except Exception as e:
        # KPI rollups are observability, not core function — never block boot.
        logger.warning("KPI scheduler failed to start: %s", e)

    yield

    logger.info("Shutting down HF Agent backend...")
    try:
        import kpis_scheduler
        await kpis_scheduler.shutdown()
    except Exception as e:
        logger.warning("KPI scheduler shutdown failed: %s", e)

    # Final-flush: save every still-active session so we don't lose traces on
    # server restart. Uploads are detached subprocesses — this is fast.
    try:
        from session_manager import session_manager
        for sid, agent_session in list(session_manager.sessions.items()):
            sess = agent_session.session
            if sess.config.save_sessions:
                try:
                    sess.save_and_upload_detached(sess.config.session_dataset_repo)
                    logger.info("Flushed session %s on shutdown", sid)
                except Exception as e:
                    # Per-session failures must not abort the flush loop.
                    logger.warning("Failed to flush session %s: %s", sid, e)
    except Exception as e:
        logger.warning("Lifespan final-flush skipped: %s", e)
|
| 61 |
|
| 62 |
|
| 63 |
app = FastAPI(
|
backend/routes/agent.py
CHANGED
|
@@ -693,3 +693,41 @@ async def shutdown_session(
|
|
| 693 |
return {"status": "shutdown_requested", "session_id": session_id}
|
| 694 |
|
| 695 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 693 |
return {"status": "shutdown_requested", "session_id": session_id}
|
| 694 |
|
| 695 |
|
| 696 |
+
@router.post("/feedback/{session_id}")
async def submit_feedback(
    session_id: str,
    body: dict,
    user: dict = Depends(get_current_user),
) -> dict:
    """Attach a user feedback signal to a session's event log.

    Body: {rating: "up"|"down"|"outcome_success"|"outcome_fail",
           turn_index?: int, comment?: str, message_id?: str}
    Appended as a `feedback` event and saved with the session trajectory.
    """
    _check_session_access(session_id, user)
    agent_session = session_manager.sessions.get(session_id)
    if agent_session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    rating = body.get("rating")
    allowed_ratings = {"up", "down", "outcome_success", "outcome_fail"}
    if rating not in allowed_ratings:
        raise HTTPException(status_code=400, detail="invalid rating")

    from agent.core import telemetry

    sess = agent_session.session
    await telemetry.record_feedback(
        sess,
        rating=rating,
        turn_index=body.get("turn_index"),
        message_id=body.get("message_id"),
        comment=body.get("comment"),
    )
    # Fire-and-forget save so feedback reaches the dataset even if the user
    # closes the tab right after clicking.
    if sess.config.save_sessions:
        sess.save_and_upload_detached(sess.config.session_dataset_repo)
    return {"status": "ok"}
|
| 732 |
+
|
| 733 |
+
|
backend/session_manager.py
CHANGED
|
@@ -290,11 +290,14 @@ class SessionManager:
|
|
| 290 |
"""Delete the sandbox Space if one was created for this session."""
|
| 291 |
sandbox = getattr(session, "sandbox", None)
|
| 292 |
if sandbox and getattr(sandbox, "_owns_space", False):
|
|
|
|
| 293 |
try:
|
| 294 |
-
logger.info(f"Deleting sandbox {
|
| 295 |
await asyncio.to_thread(sandbox.delete)
|
|
|
|
|
|
|
| 296 |
except Exception as e:
|
| 297 |
-
logger.warning(f"Failed to delete sandbox {
|
| 298 |
|
| 299 |
async def _run_session(
|
| 300 |
self,
|
|
@@ -356,6 +359,15 @@ class SessionManager:
|
|
| 356 |
|
| 357 |
await self._cleanup_sandbox(session)
|
| 358 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
async with self._lock:
|
| 360 |
if session_id in self.sessions:
|
| 361 |
self.sessions[session_id].is_active = False
|
|
|
|
| 290 |
"""Delete the sandbox Space if one was created for this session."""
|
| 291 |
sandbox = getattr(session, "sandbox", None)
|
| 292 |
if sandbox and getattr(sandbox, "_owns_space", False):
|
| 293 |
+
space_id = getattr(sandbox, "space_id", None)
|
| 294 |
try:
|
| 295 |
+
logger.info(f"Deleting sandbox {space_id}...")
|
| 296 |
await asyncio.to_thread(sandbox.delete)
|
| 297 |
+
from agent.core import telemetry
|
| 298 |
+
await telemetry.record_sandbox_destroy(session, sandbox)
|
| 299 |
except Exception as e:
|
| 300 |
+
logger.warning(f"Failed to delete sandbox {space_id}: {e}")
|
| 301 |
|
| 302 |
async def _run_session(
|
| 303 |
self,
|
|
|
|
| 359 |
|
| 360 |
await self._cleanup_sandbox(session)
|
| 361 |
|
| 362 |
+
# Final-flush: always save on session death so we capture ended
|
| 363 |
+
# sessions even if the client disconnects without /shutdown.
|
| 364 |
+
# Idempotent via session_id key; detached subprocess.
|
| 365 |
+
if session.config.save_sessions:
|
| 366 |
+
try:
|
| 367 |
+
session.save_and_upload_detached(session.config.session_dataset_repo)
|
| 368 |
+
except Exception as e:
|
| 369 |
+
logger.warning(f"Final-flush failed for {session_id}: {e}")
|
| 370 |
+
|
| 371 |
async with self._lock:
|
| 372 |
if session_id in self.sessions:
|
| 373 |
self.sessions[session_id].is_active = False
|
configs/main_agent_config.json
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
{
|
| 2 |
"model_name": "bedrock/us.anthropic.claude-opus-4-6-v1",
|
| 3 |
"save_sessions": true,
|
| 4 |
-
"session_dataset_repo": "
|
| 5 |
"yolo_mode": false,
|
| 6 |
"confirm_cpu_jobs": true,
|
| 7 |
"auto_file_upload": true,
|
|
|
|
| 1 |
{
|
| 2 |
"model_name": "bedrock/us.anthropic.claude-opus-4-6-v1",
|
| 3 |
"save_sessions": true,
|
| 4 |
+
"session_dataset_repo": "smolagents/ml-intern-sessions",
|
| 5 |
"yolo_mode": false,
|
| 6 |
"confirm_cpu_jobs": true,
|
| 7 |
"auto_file_upload": true,
|
frontend/src/components/Chat/AssistantMessage.tsx
CHANGED
|
@@ -1,13 +1,19 @@
|
|
| 1 |
-
import { useMemo } from 'react';
|
| 2 |
-
import { Box, Stack, Typography } from '@mui/material';
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import MarkdownContent from './MarkdownContent';
|
| 4 |
import ToolCallGroup from './ToolCallGroup';
|
|
|
|
| 5 |
import type { UIMessage } from 'ai';
|
| 6 |
import type { MessageMeta } from '@/types/agent';
|
| 7 |
|
| 8 |
interface AssistantMessageProps {
|
| 9 |
message: UIMessage;
|
| 10 |
isStreaming?: boolean;
|
|
|
|
| 11 |
approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
|
| 12 |
}
|
| 13 |
|
|
@@ -43,8 +49,27 @@ function groupParts(parts: UIMessage['parts']) {
|
|
| 43 |
return groups;
|
| 44 |
}
|
| 45 |
|
| 46 |
-
export default function AssistantMessage({ message, isStreaming = false, approveTools }: AssistantMessageProps) {
|
| 47 |
const groups = useMemo(() => groupParts(message.parts), [message.parts]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
// Find the last text group index for streaming cursor
|
| 50 |
let lastTextIdx = -1;
|
|
@@ -114,6 +139,24 @@ export default function AssistantMessage({ message, isStreaming = false, approve
|
|
| 114 |
return null;
|
| 115 |
})}
|
| 116 |
</Box>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
</Box>
|
| 118 |
);
|
| 119 |
}
|
|
|
|
| 1 |
+
import { useMemo, useState } from 'react';
|
| 2 |
+
import { Box, IconButton, Stack, Tooltip, Typography } from '@mui/material';
|
| 3 |
+
import ThumbUpOutlined from '@mui/icons-material/ThumbUpOutlined';
|
| 4 |
+
import ThumbUp from '@mui/icons-material/ThumbUp';
|
| 5 |
+
import ThumbDownOutlined from '@mui/icons-material/ThumbDownOutlined';
|
| 6 |
+
import ThumbDown from '@mui/icons-material/ThumbDown';
|
| 7 |
import MarkdownContent from './MarkdownContent';
|
| 8 |
import ToolCallGroup from './ToolCallGroup';
|
| 9 |
+
import { apiFetch } from '@/utils/api';
|
| 10 |
import type { UIMessage } from 'ai';
|
| 11 |
import type { MessageMeta } from '@/types/agent';
|
| 12 |
|
| 13 |
interface AssistantMessageProps {
|
| 14 |
message: UIMessage;
|
| 15 |
isStreaming?: boolean;
|
| 16 |
+
sessionId?: string | null;
|
| 17 |
approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
|
| 18 |
}
|
| 19 |
|
|
|
|
| 49 |
return groups;
|
| 50 |
}
|
| 51 |
|
| 52 |
+
export default function AssistantMessage({ message, isStreaming = false, sessionId, approveTools }: AssistantMessageProps) {
|
| 53 |
const groups = useMemo(() => groupParts(message.parts), [message.parts]);
|
| 54 |
+
const [feedback, setFeedback] = useState<'up' | 'down' | null>(null);
|
| 55 |
+
const [feedbackBusy, setFeedbackBusy] = useState(false);
|
| 56 |
+
|
| 57 |
+
const sendFeedback = async (rating: 'up' | 'down') => {
|
| 58 |
+
if (!sessionId || feedbackBusy) return;
|
| 59 |
+
setFeedbackBusy(true);
|
| 60 |
+
// Optimistic toggle — feedback is observability, not a hard requirement.
|
| 61 |
+
setFeedback(rating);
|
| 62 |
+
try {
|
| 63 |
+
await apiFetch(`/api/feedback/${sessionId}`, {
|
| 64 |
+
method: 'POST',
|
| 65 |
+
body: JSON.stringify({ rating, message_id: message.id }),
|
| 66 |
+
});
|
| 67 |
+
} catch {
|
| 68 |
+
// Silently swallow — don't block chat UX on a telemetry write.
|
| 69 |
+
} finally {
|
| 70 |
+
setFeedbackBusy(false);
|
| 71 |
+
}
|
| 72 |
+
};
|
| 73 |
|
| 74 |
// Find the last text group index for streaming cursor
|
| 75 |
let lastTextIdx = -1;
|
|
|
|
| 139 |
return null;
|
| 140 |
})}
|
| 141 |
</Box>
|
| 142 |
+
{!isStreaming && sessionId && (
|
| 143 |
+
<Stack
|
| 144 |
+
direction="row"
|
| 145 |
+
spacing={0.5}
|
| 146 |
+
sx={{ mt: 0.5, ml: 0.5, opacity: feedback ? 1 : 0.5, '&:hover': { opacity: 1 } }}
|
| 147 |
+
>
|
| 148 |
+
<Tooltip title="Helpful">
|
| 149 |
+
<IconButton size="small" disabled={feedbackBusy} onClick={() => sendFeedback('up')}>
|
| 150 |
+
{feedback === 'up' ? <ThumbUp fontSize="inherit" /> : <ThumbUpOutlined fontSize="inherit" />}
|
| 151 |
+
</IconButton>
|
| 152 |
+
</Tooltip>
|
| 153 |
+
<Tooltip title="Not helpful">
|
| 154 |
+
<IconButton size="small" disabled={feedbackBusy} onClick={() => sendFeedback('down')}>
|
| 155 |
+
{feedback === 'down' ? <ThumbDown fontSize="inherit" /> : <ThumbDownOutlined fontSize="inherit" />}
|
| 156 |
+
</IconButton>
|
| 157 |
+
</Tooltip>
|
| 158 |
+
</Stack>
|
| 159 |
+
)}
|
| 160 |
</Box>
|
| 161 |
);
|
| 162 |
}
|
frontend/src/components/Chat/MessageBubble.tsx
CHANGED
|
@@ -9,6 +9,7 @@ interface MessageBubbleProps {
|
|
| 9 |
onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
|
| 10 |
isProcessing?: boolean;
|
| 11 |
isStreaming?: boolean;
|
|
|
|
| 12 |
approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
|
| 13 |
}
|
| 14 |
|
|
@@ -19,6 +20,7 @@ export default function MessageBubble({
|
|
| 19 |
onEditAndRegenerate,
|
| 20 |
isProcessing = false,
|
| 21 |
isStreaming = false,
|
|
|
|
| 22 |
approveTools,
|
| 23 |
}: MessageBubbleProps) {
|
| 24 |
if (message.role === 'user') {
|
|
@@ -38,6 +40,7 @@ export default function MessageBubble({
|
|
| 38 |
<AssistantMessage
|
| 39 |
message={message}
|
| 40 |
isStreaming={isStreaming}
|
|
|
|
| 41 |
approveTools={approveTools}
|
| 42 |
/>
|
| 43 |
);
|
|
|
|
| 9 |
onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
|
| 10 |
isProcessing?: boolean;
|
| 11 |
isStreaming?: boolean;
|
| 12 |
+
sessionId?: string | null;
|
| 13 |
approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
|
| 14 |
}
|
| 15 |
|
|
|
|
| 20 |
onEditAndRegenerate,
|
| 21 |
isProcessing = false,
|
| 22 |
isStreaming = false,
|
| 23 |
+
sessionId,
|
| 24 |
approveTools,
|
| 25 |
}: MessageBubbleProps) {
|
| 26 |
if (message.role === 'user') {
|
|
|
|
| 40 |
<AssistantMessage
|
| 41 |
message={message}
|
| 42 |
isStreaming={isStreaming}
|
| 43 |
+
sessionId={sessionId}
|
| 44 |
approveTools={approveTools}
|
| 45 |
/>
|
| 46 |
);
|
frontend/src/components/Chat/MessageList.tsx
CHANGED
|
@@ -8,6 +8,7 @@ import type { UIMessage } from 'ai';
|
|
| 8 |
interface MessageListProps {
|
| 9 |
messages: UIMessage[];
|
| 10 |
isProcessing: boolean;
|
|
|
|
| 11 |
approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
|
| 12 |
onUndoLastTurn: () => void | Promise<void>;
|
| 13 |
onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
|
|
@@ -57,7 +58,7 @@ function WelcomeGreeting() {
|
|
| 57 |
);
|
| 58 |
}
|
| 59 |
|
| 60 |
-
export default function MessageList({ messages, isProcessing, approveTools, onUndoLastTurn, onEditAndRegenerate }: MessageListProps) {
|
| 61 |
const scrollContainerRef = useRef<HTMLDivElement>(null);
|
| 62 |
const stickToBottom = useRef(true);
|
| 63 |
|
|
@@ -139,6 +140,7 @@ export default function MessageList({ messages, isProcessing, approveTools, onUn
|
|
| 139 |
onEditAndRegenerate={onEditAndRegenerate}
|
| 140 |
isProcessing={isProcessing}
|
| 141 |
isStreaming={isProcessing && msg.id === lastAssistantId}
|
|
|
|
| 142 |
approveTools={approveTools}
|
| 143 |
/>
|
| 144 |
))
|
|
|
|
| 8 |
interface MessageListProps {
|
| 9 |
messages: UIMessage[];
|
| 10 |
isProcessing: boolean;
|
| 11 |
+
sessionId?: string | null;
|
| 12 |
approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
|
| 13 |
onUndoLastTurn: () => void | Promise<void>;
|
| 14 |
onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
|
|
|
|
| 58 |
);
|
| 59 |
}
|
| 60 |
|
| 61 |
+
export default function MessageList({ messages, isProcessing, sessionId, approveTools, onUndoLastTurn, onEditAndRegenerate }: MessageListProps) {
|
| 62 |
const scrollContainerRef = useRef<HTMLDivElement>(null);
|
| 63 |
const stickToBottom = useRef(true);
|
| 64 |
|
|
|
|
| 140 |
onEditAndRegenerate={onEditAndRegenerate}
|
| 141 |
isProcessing={isProcessing}
|
| 142 |
isStreaming={isProcessing && msg.id === lastAssistantId}
|
| 143 |
+
sessionId={sessionId}
|
| 144 |
approveTools={approveTools}
|
| 145 |
/>
|
| 146 |
))
|
frontend/src/components/SessionChat.tsx
CHANGED
|
@@ -102,6 +102,7 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
|
|
| 102 |
<MessageList
|
| 103 |
messages={messages}
|
| 104 |
isProcessing={busy}
|
|
|
|
| 105 |
approveTools={approveTools}
|
| 106 |
onUndoLastTurn={undoLastTurn}
|
| 107 |
onEditAndRegenerate={editAndRegenerate}
|
|
|
|
| 102 |
<MessageList
|
| 103 |
messages={messages}
|
| 104 |
isProcessing={busy}
|
| 105 |
+
sessionId={sessionId}
|
| 106 |
approveTools={approveTools}
|
| 107 |
onUndoLastTurn={undoLastTurn}
|
| 108 |
onEditAndRegenerate={editAndRegenerate}
|
pyproject.toml
CHANGED
|
@@ -26,6 +26,7 @@ dependencies = [
|
|
| 26 |
"uvicorn[standard]>=0.32.0",
|
| 27 |
"httpx>=0.27.0",
|
| 28 |
"websockets>=13.0",
|
|
|
|
| 29 |
]
|
| 30 |
|
| 31 |
[project.optional-dependencies]
|
|
|
|
| 26 |
"uvicorn[standard]>=0.32.0",
|
| 27 |
"httpx>=0.27.0",
|
| 28 |
"websockets>=13.0",
|
| 29 |
+
"apscheduler>=3.10,<4",
|
| 30 |
]
|
| 31 |
|
| 32 |
[project.optional-dependencies]
|
scripts/build_kpis.py
ADDED
|
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Hourly KPI rollup for the session-trajectory dataset.
|
| 3 |
+
|
| 4 |
+
================================================================================
|
| 5 |
+
Data flow
|
| 6 |
+
================================================================================
|
| 7 |
+
|
| 8 |
+
┌────────────────────┐ heartbeat ┌────────────────────────────────┐
|
| 9 |
+
│ agent (CLI/web) │ ───────────────▶ │ hf-agent-sessions (dataset) │
|
| 10 |
+
│ Session.send_event│ │ sessions/YYYY-MM-DD/<id>.jsonl│
|
| 11 |
+
└────────────────────┘ └───────────────┬────────────────┘
|
| 12 |
+
│ cron @:05 each hour
|
| 13 |
+
▼
|
| 14 |
+
┌──────────────────────────────────┐
|
| 15 |
+
│ scripts/build_kpis.py │
|
| 16 |
+
│ (GitHub Actions) │
|
| 17 |
+
└───────────────┬──────────────────┘
|
| 18 |
+
│ upload CSV
|
| 19 |
+
▼
|
| 20 |
+
┌──────────────────────────────────┐
|
| 21 |
+
│ hf-agent-kpis (dataset) │
|
| 22 |
+
│ hourly/YYYY-MM-DD/HH.csv │
|
| 23 |
+
└──────────────────────────────────┘
|
| 24 |
+
|
| 25 |
+
Each hourly run reads today's + yesterday's session folders (to cover sessions
|
| 26 |
+
that crossed midnight), filters events into the target hour window
|
| 27 |
+
``[hour, hour+1h)``, computes aggregates, and writes one CSV at
|
| 28 |
+
``hourly/<date>/<HH>.csv`` in the target dataset. Uploads are idempotent —
|
| 29 |
+
re-running the same hour overwrites.
|
| 30 |
+
|
| 31 |
+
================================================================================
|
| 32 |
+
Metrics (one row per hour)
|
| 33 |
+
================================================================================
|
| 34 |
+
|
| 35 |
+
sessions — distinct session_ids with ≥1 event in window
|
| 36 |
+
users — distinct user ids (when present on session rows)
|
| 37 |
+
turns — sum of user-message counts across active sessions
|
| 38 |
+
llm_calls — count of llm_call events
|
| 39 |
+
tokens_prompt / _completion / _cache_read / _cache_creation
|
| 40 |
+
cost_usd — sum of llm_call.cost_usd
|
| 41 |
+
cache_hit_ratio — cache_read / (cache_read + prompt)
|
| 42 |
+
tool_success_rate — tool_output success=True / total tool_output
|
| 43 |
+
failure_rate — sessions that ended with an `error` event / sessions
|
| 44 |
+
regenerate_rate — sessions with any `undo_complete` event / sessions
|
| 45 |
+
time_to_first_action_s_p50 / _p95 — from session_start to first tool_call
|
| 46 |
+
thumbs_up / thumbs_down
|
| 47 |
+
hf_jobs_submitted / _succeeded
|
| 48 |
+
gpu_hours_by_flavor_json — JSON-serialised {flavor: gpu-hours}
|
| 49 |
+
|
| 50 |
+
================================================================================
|
| 51 |
+
Usage
|
| 52 |
+
================================================================================
|
| 53 |
+
|
| 54 |
+
# Run for the most recently completed hour (default — the cron path):
|
| 55 |
+
python scripts/build_kpis.py
|
| 56 |
+
|
| 57 |
+
# Backfill last 24 hours:
|
| 58 |
+
python scripts/build_kpis.py --hours 24
|
| 59 |
+
|
| 60 |
+
# Explicit hour (UTC):
|
| 61 |
+
python scripts/build_kpis.py --datetime 2026-04-24T14
|
| 62 |
+
|
| 63 |
+
Env:
|
| 64 |
+
HF_TOKEN (or HF_KPI_WRITE_TOKEN) — write access to the target dataset.
|
| 65 |
+
|
| 66 |
+
================================================================================
|
| 67 |
+
Deploy
|
| 68 |
+
================================================================================
|
| 69 |
+
|
| 70 |
+
See ``.github/workflows/build-kpis.yml`` — runs every hour at :05. To provision:
|
| 71 |
+
|
| 72 |
+
1. Create the target dataset (once):
|
| 73 |
+
huggingface-cli repo create hf-agent-kpis --type dataset
|
| 74 |
+
2. Put ``HF_KPI_WRITE_TOKEN`` (or ``HF_TOKEN``) into repo Actions secrets.
|
| 75 |
+
3. Merge this file; the first scheduled run fires within the hour.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
from __future__ import annotations
|
| 79 |
+
|
| 80 |
+
import argparse
|
| 81 |
+
import io
|
| 82 |
+
import json
|
| 83 |
+
import logging
|
| 84 |
+
import os
|
| 85 |
+
import sys
|
| 86 |
+
import tempfile
|
| 87 |
+
from collections import defaultdict
|
| 88 |
+
from datetime import date, datetime, timedelta, timezone
|
| 89 |
+
from pathlib import Path
|
| 90 |
+
from typing import Any, Iterable
|
| 91 |
+
|
| 92 |
+
logger = logging.getLogger("build_kpis")
|
| 93 |
+
|
| 94 |
+
# Rough gpu-hour pricing for hf_jobs flavor strings. Keep conservative; used
# only to compute gpu-hours (not dollars) — wall_time_s * flavor_gpu_count.
# Flavors missing from this table contribute 0 gpu-hours at the lookup site
# (`.get(flavor, 0)`), so new/unmapped flavors under-count rather than guess.
_FLAVOR_GPU_COUNT = {
    "cpu-basic": 0, "cpu-upgrade": 0,
    "t4-small": 1, "t4-medium": 1,
    "l4x1": 1, "l4x4": 4,
    "l40sx1": 1, "l40sx4": 4, "l40sx8": 8,
    "a10g-small": 1, "a10g-large": 1, "a10g-largex2": 2, "a10g-largex4": 4,
    "a100-large": 1, "a100x2": 2, "a100x4": 4, "a100x8": 8,
    "h100": 1, "h100x8": 8,
}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _percentile(values: list[float], p: float) -> float:
|
| 108 |
+
if not values:
|
| 109 |
+
return 0.0
|
| 110 |
+
values = sorted(values)
|
| 111 |
+
k = (len(values) - 1) * p
|
| 112 |
+
f = int(k)
|
| 113 |
+
c = min(f + 1, len(values) - 1)
|
| 114 |
+
if f == c:
|
| 115 |
+
return float(values[f])
|
| 116 |
+
return float(values[f] + (values[c] - values[f]) * (k - f))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _parse_ts(s: Any) -> datetime | None:
|
| 120 |
+
if not s or not isinstance(s, str):
|
| 121 |
+
return None
|
| 122 |
+
try:
|
| 123 |
+
dt = datetime.fromisoformat(s)
|
| 124 |
+
except Exception:
|
| 125 |
+
return None
|
| 126 |
+
# Normalise to aware UTC so comparisons work against window bounds.
|
| 127 |
+
if dt.tzinfo is None:
|
| 128 |
+
dt = dt.replace(tzinfo=timezone.utc)
|
| 129 |
+
return dt
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _iter_session_files(api, repo_id: str, day: date, token: str) -> Iterable[str]:
|
| 133 |
+
"""Yield repo-relative paths for all sessions under ``sessions/YYYY-MM-DD/``."""
|
| 134 |
+
prefix = f"sessions/{day.isoformat()}/"
|
| 135 |
+
try:
|
| 136 |
+
files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", token=token)
|
| 137 |
+
except Exception as e:
|
| 138 |
+
logger.warning("list_repo_files(%s) failed: %s", repo_id, e)
|
| 139 |
+
return []
|
| 140 |
+
return [f for f in files if f.startswith(prefix) and f.endswith(".jsonl")]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _download_session(repo_id: str, path: str, token: str) -> dict | None:
    """Fetch one session JSONL from the dataset and decode its single row.

    Returns the decoded row with the ``messages``/``events``/``tools``
    fields unpacked from their JSON-string wire form, or ``None`` on any
    download or parse failure (logged, never raised).

    ``hf_hub_download`` caches; a second run within the same process /
    runner directory is near-free.
    """
    from huggingface_hub import hf_hub_download

    try:
        local = hf_hub_download(
            repo_id=repo_id, filename=path, repo_type="dataset", token=token,
        )
    except Exception as e:
        logger.warning("hf_hub_download(%s) failed: %s", path, e)
        return None
    try:
        # Session files are JSON, i.e. UTF-8 by spec — read them as such
        # explicitly instead of relying on the platform default encoding
        # (which is e.g. cp1252 on Windows runners).
        with open(local, "r", encoding="utf-8") as f:
            line = f.readline().strip()
        if not line:
            return None
        row = json.loads(line)
        # Session uploader stores messages/events as JSON strings — unpack.
        for key in ("messages", "events", "tools"):
            v = row.get(key)
            if isinstance(v, str):
                try:
                    row[key] = json.loads(v)
                except Exception:
                    row[key] = []
        return row
    except Exception as e:
        logger.warning("parse(%s) failed: %s", path, e)
        return None
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _filter_session_to_window(
    session: dict, start: datetime, end: datetime,
) -> dict | None:
    """Copy *session* keeping only events timestamped inside ``[start, end)``.

    Events with a missing or unparseable timestamp are dropped. Returns
    ``None`` when no event lands in the window, signalling the caller to
    drop the session from this hour's aggregate.
    """
    kept = [
        ev
        for ev in (session.get("events") or [])
        if (ts := _parse_ts(ev.get("timestamp"))) is not None and start <= ts < end
    ]
    if not kept:
        return None
    return {**session, "events": kept}
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _session_metrics(session: dict) -> dict:
    """Reduce a single session trajectory to its KPI contributions.

    Assumes ``events`` are already filtered to the target window by the caller.
    Returns a flat dict of per-session counters plus two private carrier
    keys consumed by the aggregator: ``_gpu_hours_by_flavor`` and ``_user``.
    """
    # Pre-seed every numeric key so downstream aggregation can sum without
    # having to special-case empty sessions.
    out: dict = {
        "sessions": 0, "turns": 0, "llm_calls": 0,
        "tokens_prompt": 0, "tokens_completion": 0,
        "tokens_cache_read": 0, "tokens_cache_creation": 0,
        "cost_usd": 0.0,
        "tool_calls_total": 0, "tool_calls_success": 0,
        "failures": 0, "regenerate_sessions": 0,
        "thumbs_up": 0, "thumbs_down": 0,
        "hf_jobs_submitted": 0, "hf_jobs_succeeded": 0,
        # -1 = "no tool was ever used"; the aggregator filters these out.
        "first_tool_s": -1,
    }
    events = session.get("events") or []
    messages = session.get("messages") or []

    # One "turn" per user message in the transcript.
    turn_count = sum(1 for m in messages if m.get("role") == "user")
    out["turns"] = turn_count
    out["sessions"] = 1

    tool_success = 0
    tool_total = 0
    had_error = False
    had_undo = False
    first_tool_ts = None  # seconds from session start to the first tool event
    session_start = session.get("session_start_time")
    gpu_hours_by_flavor: dict[str, float] = defaultdict(float)
    jobs_submitted = 0
    jobs_succeeded = 0
    thumbs_up = 0
    thumbs_down = 0

    start_dt = _parse_ts(session_start)

    for ev in events:
        et = ev.get("event_type")
        data = ev.get("data") or {}
        ts = _parse_ts(ev.get("timestamp"))

        if et == "llm_call":
            # Token/cost fields may be absent or null — `or 0` coerces both.
            out["llm_calls"] += 1
            out["tokens_prompt"] += int(data.get("prompt_tokens") or 0)
            out["tokens_completion"] += int(data.get("completion_tokens") or 0)
            out["tokens_cache_read"] += int(data.get("cache_read_tokens") or 0)
            out["tokens_cache_creation"] += int(data.get("cache_creation_tokens") or 0)
            out["cost_usd"] += float(data.get("cost_usd") or 0.0)

        elif et == "tool_output":
            tool_total += 1
            if data.get("success"):
                tool_success += 1
            # Time-to-first-action: taken from whichever tool event appears
            # first (tool_call or tool_output); needs a parseable start time.
            if first_tool_ts is None and ts is not None and start_dt is not None:
                first_tool_ts = (ts - start_dt).total_seconds()

        elif et == "tool_call":
            if first_tool_ts is None and ts is not None and start_dt is not None:
                first_tool_ts = (ts - start_dt).total_seconds()

        elif et == "error":
            had_error = True

        elif et == "undo_complete":
            # An undo implies the user regenerated/rolled back a turn.
            had_undo = True

        elif et == "feedback":
            # Only thumbs are tallied here; outcome_* ratings pass through
            # uncounted.
            rating = data.get("rating")
            if rating == "up":
                thumbs_up += 1
            elif rating == "down":
                thumbs_down += 1

        elif et == "hf_job_submit":
            jobs_submitted += 1

        elif et == "hf_job_complete":
            flavor = data.get("flavor") or "unknown"
            status = (data.get("final_status") or "").lower()
            wall = float(data.get("wall_time_s") or 0.0)
            # Unknown flavors map to 0 GPUs, so they add no gpu-hours.
            gpus = _FLAVOR_GPU_COUNT.get(flavor, 0)
            gpu_hours_by_flavor[flavor] += wall * gpus / 3600.0
            if status in ("completed", "succeeded", "success"):
                jobs_succeeded += 1

    out["tool_calls_total"] = tool_total
    out["tool_calls_success"] = tool_success
    # failures / regenerate_sessions are per-session flags (0/1), not counts.
    out["failures"] = 1 if had_error else 0
    out["regenerate_sessions"] = 1 if had_undo else 0
    out["thumbs_up"] = thumbs_up
    out["thumbs_down"] = thumbs_down
    out["hf_jobs_submitted"] = jobs_submitted
    out["hf_jobs_succeeded"] = jobs_succeeded
    out["first_tool_s"] = first_tool_ts if first_tool_ts is not None else -1
    out["_gpu_hours_by_flavor"] = dict(gpu_hours_by_flavor)
    # Fall back to session_id so anonymous sessions still count as a "user".
    out["_user"] = session.get("user_id") or session.get("session_id")
    return dict(out)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def _aggregate(per_session: list[dict]) -> dict:
    """Collapse a bucket's worth of session rollups into the final KPI row.

    Ratios are rounded to 4 decimal places and default to 0.0 when their
    denominator is zero (empty bucket, no tool calls, no tokens).
    """

    def total(key: str):
        # Sum one counter key across all session rollups.
        return sum(s[key] for s in per_session)

    def ratio(num, den) -> float:
        return round(num / den, 4) if den > 0 else 0.0

    n_sessions = total("sessions")
    prompt_toks = total("tokens_prompt")
    cache_read_toks = total("tokens_cache_read")
    tools_total = total("tool_calls_total")

    # first_tool_s == -1 marks "never used a tool"; exclude from percentiles.
    ttfa = [s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0]

    gpu_hours: dict[str, float] = defaultdict(float)
    for s in per_session:
        for flavor, hours in (s.get("_gpu_hours_by_flavor") or {}).items():
            gpu_hours[flavor] += hours

    distinct_users = {s.get("_user") for s in per_session if s.get("_user")}

    return {
        "sessions": n_sessions,
        "users": len(distinct_users),
        "turns": total("turns"),
        "llm_calls": int(total("llm_calls")),
        "tokens_prompt": int(prompt_toks),
        "tokens_completion": int(total("tokens_completion")),
        "tokens_cache_read": int(cache_read_toks),
        "tokens_cache_creation": int(total("tokens_cache_creation")),
        "cost_usd": round(total("cost_usd"), 4),
        "cache_hit_ratio": ratio(cache_read_toks, cache_read_toks + prompt_toks),
        "tool_success_rate": ratio(total("tool_calls_success"), tools_total),
        "failure_rate": ratio(total("failures"), n_sessions),
        "regenerate_rate": ratio(total("regenerate_sessions"), n_sessions),
        "time_to_first_action_s_p50": round(_percentile(ttfa, 0.5), 2),
        "time_to_first_action_s_p95": round(_percentile(ttfa, 0.95), 2),
        "thumbs_up": int(total("thumbs_up")),
        "thumbs_down": int(total("thumbs_down")),
        "hf_jobs_submitted": int(total("hf_jobs_submitted")),
        "hf_jobs_succeeded": int(total("hf_jobs_succeeded")),
        "gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
    }


# Back-compat alias: older tests call _aggregate_day.
_aggregate_day = _aggregate
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def _csv_cell(v: Any) -> str:
|
| 352 |
+
s = str(v)
|
| 353 |
+
if "," in s or '"' in s or "\n" in s:
|
| 354 |
+
return '"' + s.replace('"', '""') + '"'
|
| 355 |
+
return s
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def _write_csv(
    api, row: dict, bucket_key: str, path_in_repo: str, target_repo: str, token: str,
) -> None:
    """Serialize ``row`` as a one-line CSV with a leading ``bucket`` column
    and upload it to ``target_repo``.

    ``bucket_key`` is the hour string (ISO ``YYYY-MM-DDTHH``) or date string;
    it is written as the ``bucket`` column so downstream consumers can union
    all CSVs without date-parsing paths. ``api`` is the caller's ``HfApi``
    instance — reused so we don't spin up a fresh one per CSV.
    """
    header_cells = ["bucket", *row.keys()]
    value_cells = [bucket_key, *(_csv_cell(v) for v in row.values())]
    content = ",".join(header_cells) + "\n" + ",".join(value_cells) + "\n"

    # Write to a temp file because upload_file wants a path; always unlink it.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        api.create_repo(
            repo_id=target_repo, repo_type="dataset", exist_ok=True, token=token,
        )
        api.upload_file(
            path_or_fileobj=tmp_path,
            path_in_repo=path_in_repo,
            repo_id=target_repo,
            repo_type="dataset",
            token=token,
            commit_message=f"KPIs for {bucket_key}",
        )
    finally:
        try:
            os.unlink(tmp_path)
        except Exception:
            pass
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def run_for_hour(
    api, source_repo: str, target_repo: str, hour_dt: datetime, token: str,
) -> dict:
    """Roll up one UTC hour [hour_dt, hour_dt+1h) into an hourly KPI CSV.

    Reads today's + yesterday's session folders so sessions that crossed
    midnight land in the right hourly bucket. Returns the aggregated row,
    or {} when the window contained no session events.

    Fix: the final log call used ``per_session and len(per_session)`` — a
    confusing redundancy (the list is guaranteed non-empty at that point);
    replaced with a plain ``len(per_session)``.
    """
    if hour_dt.tzinfo is None:
        hour_dt = hour_dt.replace(tzinfo=timezone.utc)
    window_start = hour_dt.replace(minute=0, second=0, microsecond=0)
    window_end = window_start + timedelta(hours=1)

    # Sessions partition by session_start_time date. A session that started
    # at 23:50 yesterday can still emit events in today's first hours, so we
    # look at both folders.
    candidate_dates = {window_start.date(), (window_start - timedelta(days=1)).date()}

    per_session: list[dict] = []
    for d in sorted(candidate_dates):
        for path in _iter_session_files(api, source_repo, d, token):
            sess = _download_session(source_repo, path, token)
            if not sess:
                continue
            windowed = _filter_session_to_window(sess, window_start, window_end)
            if windowed is None:
                continue
            per_session.append(_session_metrics(windowed))

    if not per_session:
        logger.info("No sessions in window %s — skipping", window_start.isoformat())
        return {}

    row = _aggregate(per_session)
    bucket_key = window_start.strftime("%Y-%m-%dT%H")
    path_in_repo = f"hourly/{window_start.strftime('%Y-%m-%d')}/{window_start.strftime('%H')}.csv"
    _write_csv(api, row, bucket_key, path_in_repo, target_repo, token)
    logger.info("Wrote KPIs for %s (%d sessions): %s",
                bucket_key, len(per_session), row)
    return row
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
# Back-compat for daily backfills — unchanged behaviour.
def run_for_day(api, source_repo: str, target_repo: str, day: date, token: str) -> dict:
    """Aggregate one whole day of sessions into ``daily/<date>.csv``.

    Returns the aggregated row, or {} when the day had no parseable sessions.
    """
    downloaded = (
        _download_session(source_repo, path, token)
        for path in _iter_session_files(api, source_repo, day, token)
    )
    per_session = [_session_metrics(sess) for sess in downloaded if sess]
    if not per_session:
        logger.info("No sessions found for %s — skipping", day)
        return {}
    row = _aggregate(per_session)
    day_key = day.isoformat()
    _write_csv(api, row, day_key, f"daily/{day_key}.csv", target_repo, token)
    return row
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def _parse_hour_arg(s: str) -> datetime:
|
| 457 |
+
"""Accept ``YYYY-MM-DDTHH`` or full ISO — always pinned to the start of the hour, UTC."""
|
| 458 |
+
dt = datetime.fromisoformat(s)
|
| 459 |
+
if dt.tzinfo is None:
|
| 460 |
+
dt = dt.replace(tzinfo=timezone.utc)
|
| 461 |
+
return dt.replace(minute=0, second=0, microsecond=0)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: roll up trailing hours (or a whole day) of KPIs.

    Returns a process exit code: 0 on success, 1 when no HF token is set.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    parser = argparse.ArgumentParser()
    parser.add_argument("--source", default="smolagents/ml-intern-sessions")
    parser.add_argument("--target", default="smolagents/ml-intern-kpis")
    parser.add_argument(
        "--hours", type=int, default=1,
        help="Number of trailing hours to roll up (default: 1 = last completed hour).",
    )
    parser.add_argument(
        "--datetime", type=str, default=None,
        help="Single hour, ISO ``YYYY-MM-DDTHH`` (UTC); overrides --hours.",
    )
    parser.add_argument(
        "--daily-backfill", type=str, default=None,
        help="Escape hatch: aggregate a whole day at once (YYYY-MM-DD). "
             "Writes to daily/<date>.csv. Use for historical backfill only.",
    )
    args = parser.parse_args(argv)

    # First env var that is set (non-empty) wins — same precedence as before.
    token = next(
        (
            os.environ[name]
            for name in (
                "HF_KPI_WRITE_TOKEN",
                "HF_SESSION_UPLOAD_TOKEN",
                "HF_TOKEN",
                "HF_ADMIN_TOKEN",
            )
            if os.environ.get(name)
        ),
        None,
    )
    if not token:
        logger.error(
            "No HF token found. Set one of: HF_KPI_WRITE_TOKEN, "
            "HF_SESSION_UPLOAD_TOKEN, HF_TOKEN, HF_ADMIN_TOKEN."
        )
        return 1

    from huggingface_hub import HfApi
    api = HfApi()

    if args.daily_backfill:
        run_for_day(api, args.source, args.target, date.fromisoformat(args.daily_backfill), token)
        return 0

    if args.datetime:
        target_hours = [_parse_hour_arg(args.datetime)]
    else:
        # Roll up *completed* hours: start from the hour before ``now``.
        this_hour = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0)
        target_hours = [this_hour - timedelta(hours=back) for back in range(1, args.hours + 1)]

    for hour in target_hours:
        run_for_hour(api, args.source, args.target, hour, token)
    return 0
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
# Script entry point: propagate main()'s status code to the shell.
if __name__ == "__main__":
    sys.exit(main())
|
scripts/build_sft.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Export session trajectories as raw multi-turn tool-calling SFT data.
|
| 3 |
+
|
| 4 |
+
Reads the source sessions dataset (JSONL, one file per session at
|
| 5 |
+
``sessions/YYYY-MM-DD/<session_id>.jsonl``) and writes a re-shaped row to a
|
| 6 |
+
target dataset at ``sft/YYYY-MM-DD/<session_id>.jsonl``.
|
| 7 |
+
|
| 8 |
+
**No filtering, no cleaning, no dedup.** Raw passthrough of messages + tools,
|
| 9 |
+
with session-level metadata and derived tags (see ``agent/sft/tagger.py``)
|
| 10 |
+
attached for downstream slicing.
|
| 11 |
+
|
| 12 |
+
Output row schema::
|
| 13 |
+
|
| 14 |
+
{
|
| 15 |
+
"session_id": "...",
|
| 16 |
+
"model": "claude-opus-4-6",
|
| 17 |
+
"timestamp": "2026-04-24T...",
|
| 18 |
+
"tags": ["tool:hf_jobs", "gpu:a100", "hf_job:succeeded", ...],
|
| 19 |
+
"messages": [...], # OpenAI / TRL SFTTrainer format
|
| 20 |
+
"tools": [...] # OpenAI tool schemas the session had access to
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
Usage::
|
| 24 |
+
|
| 25 |
+
python scripts/build_sft.py \\
|
| 26 |
+
--source smolagents/ml-intern-sessions \\
|
| 27 |
+
--target smolagents/ml-intern-sft \\
|
| 28 |
+
--days 7
|
| 29 |
+
|
| 30 |
+
Env:
|
| 31 |
+
HF_TOKEN (or HF_SFT_WRITE_TOKEN) — write access to target dataset.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
from __future__ import annotations
|
| 35 |
+
|
| 36 |
+
import argparse
|
| 37 |
+
import json
|
| 38 |
+
import logging
|
| 39 |
+
import os
|
| 40 |
+
import sys
|
| 41 |
+
import tempfile
|
| 42 |
+
from datetime import date, datetime, timedelta, timezone
|
| 43 |
+
from typing import Iterable
|
| 44 |
+
|
| 45 |
+
# Make ``agent`` importable when this script is run outside the project venv.
# (Prepends the repo root — the parent of ``scripts/`` — to sys.path.)
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from agent.sft.tagger import tag_session  # noqa: E402

# Module-level logger; configured by logging.basicConfig in main().
logger = logging.getLogger("build_sft")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _iter_session_files(api, repo_id: str, day: date, token: str) -> Iterable[str]:
|
| 54 |
+
prefix = f"sessions/{day.isoformat()}/"
|
| 55 |
+
try:
|
| 56 |
+
files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", token=token)
|
| 57 |
+
except Exception as e:
|
| 58 |
+
logger.warning("list_repo_files(%s) failed: %s", repo_id, e)
|
| 59 |
+
return []
|
| 60 |
+
return [f for f in files if f.startswith(prefix) and f.endswith(".jsonl")]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _download_and_parse(repo_id: str, path: str, token: str) -> dict | None:
    """Download one session JSONL file and parse its single row.

    Returns the parsed row with ``messages``/``events``/``tools`` decoded from
    their JSON-string form, or None when the download or parse fails
    (failures are logged and skipped, never raised).

    Fix: the file was opened with the platform's default encoding; JSON is
    UTF-8 by spec, so the encoding is now pinned to avoid locale-dependent
    decode errors (e.g. on Windows).
    """
    from huggingface_hub import hf_hub_download
    try:
        local = hf_hub_download(
            repo_id=repo_id, filename=path, repo_type="dataset", token=token,
        )
    except Exception as e:
        logger.warning("hf_hub_download(%s) failed: %s", path, e)
        return None
    try:
        # Session files are JSONL with exactly one row; read only that line.
        with open(local, "r", encoding="utf-8") as f:
            line = f.readline().strip()
        if not line:
            return None
        row = json.loads(line)
        # Session uploader stores messages/events/tools as JSON strings.
        for key in ("messages", "events", "tools"):
            v = row.get(key)
            if isinstance(v, str):
                try:
                    row[key] = json.loads(v)
                except Exception:
                    row[key] = []  # malformed nested JSON → empty, keep the row
        return row
    except Exception as e:
        logger.warning("parse(%s) failed: %s", path, e)
        return None
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _reshape_to_sft(row: dict) -> dict:
    """Raw passthrough: reshape one session row into SFT format + tags.

    No filtering or cleaning — messages and tools are copied through verbatim.
    Trajectories predating the ``tools`` addition to ``get_trajectory`` will
    have an empty tools list — still valid, just less useful downstream.
    """
    trajectory_view = {
        "events": row.get("events") or [],
        "messages": row.get("messages") or [],
        "model_name": row.get("model_name"),
    }
    return {
        "session_id": row.get("session_id"),
        "model": row.get("model_name"),
        "timestamp": row.get("session_start_time"),
        "tags": tag_session(trajectory_view),
        "messages": row.get("messages") or [],
        "tools": row.get("tools") or [],
    }
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _upload_row(api, row: dict, day: date, target_repo: str, token: str) -> None:
|
| 114 |
+
session_id = row["session_id"]
|
| 115 |
+
path_in_repo = f"sft/{day.isoformat()}/{session_id}.jsonl"
|
| 116 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as tmp:
|
| 117 |
+
json.dump(row, tmp, ensure_ascii=False)
|
| 118 |
+
tmp_path = tmp.name
|
| 119 |
+
try:
|
| 120 |
+
api.create_repo(
|
| 121 |
+
repo_id=target_repo, repo_type="dataset", exist_ok=True, token=token,
|
| 122 |
+
)
|
| 123 |
+
api.upload_file(
|
| 124 |
+
path_or_fileobj=tmp_path,
|
| 125 |
+
path_in_repo=path_in_repo,
|
| 126 |
+
repo_id=target_repo,
|
| 127 |
+
repo_type="dataset",
|
| 128 |
+
token=token,
|
| 129 |
+
commit_message=f"Add SFT row {session_id}",
|
| 130 |
+
)
|
| 131 |
+
finally:
|
| 132 |
+
try:
|
| 133 |
+
os.unlink(tmp_path)
|
| 134 |
+
except Exception:
|
| 135 |
+
pass
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def run_for_day(
    api, source_repo: str, target_repo: str, day: date, token: str,
) -> int:
    """Export every parseable session from ``day`` as one SFT row each.

    Returns the number of rows successfully uploaded; individual failures are
    logged and skipped.
    """
    exported = 0
    for path in _iter_session_files(api, source_repo, day, token):
        sess = _download_and_parse(source_repo, path, token)
        if not sess:
            continue
        sft_row = _reshape_to_sft(sess)
        if not sft_row.get("session_id"):
            continue  # no stable filename without a session id
        try:
            _upload_row(api, sft_row, day, target_repo, token)
        except Exception as e:
            logger.warning("upload failed for %s: %s", sft_row["session_id"], e)
        else:
            exported += 1
    logger.info("Exported %d sessions for %s", exported, day)
    return exported
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: export the requested day(s) of sessions as SFT rows.

    Returns a process exit code: 0 on success, 1 when no HF token is set.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    parser = argparse.ArgumentParser()
    parser.add_argument("--source", default="smolagents/ml-intern-sessions")
    parser.add_argument("--target", default="smolagents/ml-intern-sft")
    parser.add_argument(
        "--days", type=int, default=1,
        help="Number of trailing days to export (default: 1 = yesterday).",
    )
    parser.add_argument(
        "--date", type=str, default=None,
        help="Single YYYY-MM-DD to export; overrides --days.",
    )
    args = parser.parse_args(argv)

    # First env var that is set (non-empty) wins — same precedence as before.
    token = next(
        (
            os.environ[name]
            for name in (
                "HF_SFT_WRITE_TOKEN",
                "HF_SESSION_UPLOAD_TOKEN",
                "HF_TOKEN",
                "HF_ADMIN_TOKEN",
            )
            if os.environ.get(name)
        ),
        None,
    )
    if not token:
        logger.error(
            "No HF token found. Set one of: HF_SFT_WRITE_TOKEN, "
            "HF_SESSION_UPLOAD_TOKEN, HF_TOKEN, HF_ADMIN_TOKEN."
        )
        return 1

    from huggingface_hub import HfApi
    api = HfApi()

    if args.date:
        target_days = [date.fromisoformat(args.date)]
    else:
        today = datetime.now(timezone.utc).date()
        target_days = [today - timedelta(days=back) for back in range(1, args.days + 1)]

    total = sum(run_for_day(api, args.source, args.target, d, token) for d in target_days)
    logger.info("Total exported: %d sessions", total)
    return 0
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# Script entry point: propagate main()'s status code to the shell.
if __name__ == "__main__":
    sys.exit(main())
|
tests/unit/test_build_kpis.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for the KPI rollup math.
|
| 2 |
+
|
| 3 |
+
We exercise the pure functions (``_session_metrics`` and ``_aggregate_day``)
|
| 4 |
+
on hand-crafted session trajectories — no network, no HF Hub.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import importlib.util
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _load():
    """Load ``scripts/build_kpis.py`` without treating ``scripts`` as a package."""
    script = Path(__file__).parent.parent.parent / "scripts" / "build_kpis.py"
    spec = importlib.util.spec_from_file_location("build_kpis", script)
    module = importlib.util.module_from_spec(spec)
    sys.modules["build_kpis"] = module
    spec.loader.exec_module(module)  # type: ignore
    return module
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _ev(event_type, data=None, ts="2026-04-24T10:00:00"):
|
| 23 |
+
return {"timestamp": ts, "event_type": event_type, "data": data or {}}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _session(events, user_id="u1", start="2026-04-24T09:59:00"):
|
| 27 |
+
return {
|
| 28 |
+
"session_id": "sess-" + user_id,
|
| 29 |
+
"session_start_time": start,
|
| 30 |
+
"session_end_time": "2026-04-24T10:05:00",
|
| 31 |
+
"model_name": "claude-opus-4-6",
|
| 32 |
+
"messages": [{"role": "user", "content": "hi"}],
|
| 33 |
+
"events": events,
|
| 34 |
+
"user_id": user_id,
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_llm_call_accumulates_tokens_and_cost():
    """Two llm_call events sum their token counters and cost."""
    kpis = _load()
    metrics = kpis._session_metrics(_session([
        _ev("llm_call", {
            "prompt_tokens": 100, "completion_tokens": 50,
            "cache_read_tokens": 40, "cache_creation_tokens": 10,
            "cost_usd": 0.01,
        }),
        _ev("llm_call", {
            "prompt_tokens": 200, "completion_tokens": 100,
            "cache_read_tokens": 80, "cost_usd": 0.02,
        }),
    ]))
    expected = {
        "llm_calls": 2,
        "tokens_prompt": 300,
        "tokens_completion": 150,
        "tokens_cache_read": 120,
        "tokens_cache_creation": 10,
    }
    for key, want in expected.items():
        assert metrics[key] == want
    assert abs(metrics["cost_usd"] - 0.03) < 1e-9
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def test_tool_success_rate_and_first_action():
    """Tool outputs count toward the success rate; first tool_call sets TTFA."""
    kpis = _load()
    metrics = kpis._session_metrics(_session([
        _ev("tool_call", {"tool": "bash"}, ts="2026-04-24T10:00:05"),
        _ev("tool_output", {"success": True}),
        _ev("tool_output", {"success": False}),
    ]))
    # One success out of two tool outputs.
    assert metrics["tool_calls_total"] == 2
    assert metrics["tool_calls_success"] == 1
    # Session starts 09:59:00, first tool_call at 10:00:05 → 65 seconds.
    assert metrics["first_tool_s"] == 65
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_hf_job_gpu_hours():
    """A completed HF job contributes wall time as GPU-hours for its flavor."""
    kpis = _load()
    job_events = [
        _ev("hf_job_submit", {"flavor": "a100-large", "job_id": "j1"}),
        _ev("hf_job_complete", {
            "flavor": "a100-large",
            "final_status": "COMPLETED",
            "wall_time_s": 3600,
        }),
    ]
    metrics = kpis._session_metrics(_session(job_events))
    assert metrics["hf_jobs_submitted"] == 1
    assert metrics["hf_jobs_succeeded"] == 1
    # a100-large = 1 gpu * 1 hour of wall time = 1 gpu-hour
    assert abs(metrics["_gpu_hours_by_flavor"]["a100-large"] - 1.0) < 1e-6
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def test_feedback_counts():
    """Thumbs up/down tallies come straight from feedback events."""
    kpis = _load()
    ratings = ["up", "up", "down"]
    metrics = kpis._session_metrics(
        _session([_ev("feedback", {"rating": r}) for r in ratings])
    )
    assert metrics["thumbs_up"] == 2
    assert metrics["thumbs_down"] == 1
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def test_aggregate_day_cache_hit_and_users():
    """Aggregation sums tokens across sessions and derives the cache hit ratio."""
    kpis = _load()
    per_user = {
        "u1": {"prompt_tokens": 100, "cache_read_tokens": 400, "cost_usd": 0.5},
        "u2": {"prompt_tokens": 200, "cache_read_tokens": 100, "cost_usd": 1.0},
    }
    metrics = [
        kpis._session_metrics(_session([_ev("llm_call", data)], user_id=uid))
        for uid, data in per_user.items()
    ]
    row = kpis._aggregate_day(metrics)
    assert row["sessions"] == 2
    assert row["users"] == 2
    assert row["tokens_prompt"] == 300
    assert row["tokens_cache_read"] == 500
    # cache hit ratio = 500 / (500 + 300) = 0.625
    assert abs(row["cache_hit_ratio"] - 0.625) < 1e-9
    assert abs(row["cost_usd"] - 1.5) < 1e-9
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def test_failure_and_regenerate_rates():
    """One error session + one undo session out of three → 1/3 for each rate."""
    kpis = _load()
    sessions = [
        _session([_ev("error", {"error": "boom"})], user_id="a"),
        _session([_ev("undo_complete")], user_id="b"),
        _session([], user_id="c"),
    ]
    row = kpis._aggregate_day([kpis._session_metrics(s) for s in sessions])
    third = round(1 / 3, 4)
    assert row["failure_rate"] == third
    assert row["regenerate_rate"] == third
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def test_window_filter_keeps_only_events_in_range():
    """Only events with timestamps inside [start, end) survive the filter."""
    from datetime import datetime, timezone
    kpis = _load()
    session = _session([
        _ev("llm_call", {"prompt_tokens": 100}, ts="2026-04-24T09:45:00"),
        _ev("llm_call", {"prompt_tokens": 200}, ts="2026-04-24T10:05:00"),
        _ev("tool_call", {"tool": "bash"}, ts="2026-04-24T10:30:00"),
        _ev("llm_call", {"prompt_tokens": 400}, ts="2026-04-24T11:10:00"),
    ], start="2026-04-24T09:44:00")
    # Only events in [10:00, 11:00) should remain.
    start = datetime(2026, 4, 24, 10, 0, 0, tzinfo=timezone.utc)
    end = datetime(2026, 4, 24, 11, 0, 0, tzinfo=timezone.utc)
    windowed = kpis._filter_session_to_window(session, start, end)
    assert windowed is not None
    assert [e["event_type"] for e in windowed["events"]] == ["llm_call", "tool_call"]
    # Metrics only reflect in-window events.
    metrics = kpis._session_metrics(windowed)
    assert metrics["tokens_prompt"] == 200
    assert metrics["llm_calls"] == 1
    # A tool_call with no tool_output does not count as a completed call.
    assert metrics["tool_calls_total"] == 0
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def test_window_filter_returns_none_when_nothing_in_range():
    """An entirely out-of-window session filters down to None."""
    from datetime import datetime, timezone
    kpis = _load()
    session = _session(
        [_ev("llm_call", {"prompt_tokens": 100}, ts="2026-04-24T09:45:00")]
    )
    start = datetime(2026, 4, 24, 10, 0, 0, tzinfo=timezone.utc)
    end = datetime(2026, 4, 24, 11, 0, 0, tzinfo=timezone.utc)
    assert kpis._filter_session_to_window(session, start, end) is None
|
tests/unit/test_build_sft.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Smoke test for the SFT reshape — raw passthrough with tags attached."""
|
| 2 |
+
|
| 3 |
+
import importlib.util
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _load():
    """Load ``scripts/build_sft.py`` without treating ``scripts`` as a package."""
    script = Path(__file__).parent.parent.parent / "scripts" / "build_sft.py"
    spec = importlib.util.spec_from_file_location("build_sft", script)
    module = importlib.util.module_from_spec(spec)
    sys.modules["build_sft"] = module
    spec.loader.exec_module(module)  # type: ignore
    return module
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _session_row():
|
| 18 |
+
return {
|
| 19 |
+
"session_id": "abc",
|
| 20 |
+
"session_start_time": "2026-04-24T10:00:00",
|
| 21 |
+
"session_end_time": "2026-04-24T10:05:00",
|
| 22 |
+
"model_name": "claude-opus-4-6",
|
| 23 |
+
"messages": [
|
| 24 |
+
{"role": "system", "content": "You are an agent"},
|
| 25 |
+
{"role": "user", "content": "fine-tune llama"},
|
| 26 |
+
{"role": "assistant", "content": None, "tool_calls": [
|
| 27 |
+
{"id": "c1", "type": "function",
|
| 28 |
+
"function": {"name": "hf_jobs", "arguments": '{"script":"from trl import SFTTrainer"}'}},
|
| 29 |
+
]},
|
| 30 |
+
{"role": "tool", "tool_call_id": "c1", "content": "ok"},
|
| 31 |
+
{"role": "assistant", "content": "done"},
|
| 32 |
+
],
|
| 33 |
+
"events": [
|
| 34 |
+
{"timestamp": "2026-04-24T10:00:05", "event_type": "tool_call",
|
| 35 |
+
"data": {"tool": "hf_jobs",
|
| 36 |
+
"arguments": {"script": "from trl import SFTTrainer"}}},
|
| 37 |
+
{"timestamp": "2026-04-24T10:00:06", "event_type": "hf_job_submit",
|
| 38 |
+
"data": {"flavor": "a100-large", "push_to_hub": True}},
|
| 39 |
+
{"timestamp": "2026-04-24T10:45:00", "event_type": "hf_job_complete",
|
| 40 |
+
"data": {"flavor": "a100-large", "final_status": "COMPLETED",
|
| 41 |
+
"wall_time_s": 2700}},
|
| 42 |
+
{"timestamp": "2026-04-24T10:45:05", "event_type": "turn_complete",
|
| 43 |
+
"data": {}},
|
| 44 |
+
],
|
| 45 |
+
"tools": [{"type": "function", "function": {"name": "hf_jobs"}}],
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def test_reshape_preserves_messages_and_tools_and_adds_tags():
    """The reshape is a raw passthrough plus derived tags."""
    sft = _load()
    out = sft._reshape_to_sft(_session_row())
    assert out["session_id"] == "abc"
    assert out["model"] == "claude-opus-4-6"
    assert out["timestamp"] == "2026-04-24T10:00:00"
    # Messages preserved verbatim, in order, with tool_calls + tool role rows.
    assert len(out["messages"]) == 5
    assert out["messages"][2]["tool_calls"][0]["function"]["name"] == "hf_jobs"
    assert out["messages"][3]["role"] == "tool"
    # Tools preserved verbatim.
    assert out["tools"] == [{"type": "function", "function": {"name": "hf_jobs"}}]
    # Every expected derived tag is present.
    expected_tags = {
        "tool:hf_jobs", "hf_job:succeeded", "hf_job:push_to_hub",
        "gpu:a100", "outcome:completed", "task:training", "model:opus",
    }
    assert expected_tags <= set(out["tags"])
|
| 71 |
+
|
| 72 |
+
def test_reshape_handles_missing_tools_field():
    """Older rows without a ``tools`` field reshape to an empty tools list."""
    sft = _load()
    session = _session_row()
    session.pop("tools")
    out = sft._reshape_to_sft(session)
    assert out["tools"] == []
    assert isinstance(out["tags"], list)  # tags are still computed
|
tests/unit/test_heartbeat.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Heartbeat + stable-local-path tests for Session.
|
| 2 |
+
|
| 3 |
+
We don't spin up the real agent loop — we build a minimal Session with a
|
| 4 |
+
stubbed config and an in-memory queue, then call send_event repeatedly while
|
| 5 |
+
monkeypatching time.monotonic to simulate seconds passing.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import json
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from unittest.mock import patch
|
| 12 |
+
|
| 13 |
+
import pytest
|
| 14 |
+
|
| 15 |
+
from agent.core.session import Event, Session
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class _FakeConfig:
    """Stand-in for the agent config with only the attributes Session reads."""

    model_name = "claude-opus-4-6"
    save_sessions = True  # enable the save/upload path under test
    session_dataset_repo = "fake/repo"
    auto_save_interval = 1
    heartbeat_interval_s = 60  # seconds between mid-turn heartbeat saves
    max_iterations = 10
    yolo_mode = False
    confirm_cpu_jobs = False
    auto_file_upload = False
    reasoning_effort = None
    mcpServers: dict = {}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _mk_session(tmp_path: Path) -> Session:
    """Build a Session wired to stubs: fake config, bare ContextManager."""
    import os
    os.chdir(tmp_path)  # so session_logs/ lands under tmp_path
    # Bypass ContextManager.__init__ to avoid litellm lookups; set only the
    # attributes the Session code path touches.
    from agent.context_manager.manager import ContextManager
    cm = ContextManager.__new__(ContextManager)
    stub_attrs = {
        "items": [],
        "tool_specs": [],
        "model_max_tokens": 200_000,
        "running_context_usage": 0,
        "compact_size": 0.1,
        "untouched_messages": 5,
        "hf_token": None,
        "local_mode": True,
    }
    for name, value in stub_attrs.items():
        setattr(cm, name, value)
    return Session(
        event_queue=asyncio.Queue(),
        config=_FakeConfig(),
        tool_router=None,
        context_manager=cm,
        hf_token=None,
        local_mode=True,
    )
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def test_heartbeat_fires_after_interval(tmp_path, monkeypatch):
    """Heartbeat saves fire once per interval of simulated monotonic time.

    time.monotonic is patched (as seen by agent.core.telemetry) to step a
    fake clock; the heartbeat interval in _FakeConfig is 60s, so only the
    event at t=+61s should trigger a save, and the follow-up at t=+70s
    should not (the interval resets on fire).
    """
    # Use asyncio.run rather than pytest-asyncio so the test works without the
    # plugin installed (same pattern elsewhere in this repo).
    async def body():
        s = _mk_session(tmp_path)
        calls = []

        # Replaces the real Hub upload; records the repo_id it was asked for.
        def fake_upload(repo_id):
            calls.append(repo_id)
            return "fake/path.json"

        monkeypatch.setattr(s, "save_and_upload_detached", fake_upload)

        # t=0: first event, should NOT trigger (initial _last_heartbeat_ts = now)
        with patch("agent.core.telemetry.time.monotonic", return_value=100.0):
            s._last_heartbeat_ts = 100.0
            await s.send_event(Event(event_type="x"))
        assert calls == []

        # t=+30s: still under interval → no save
        with patch("agent.core.telemetry.time.monotonic", return_value=130.0):
            await s.send_event(Event(event_type="y"))
        assert calls == []

        # t=+61s: over 60s → save fires once
        with patch("agent.core.telemetry.time.monotonic", return_value=161.0):
            await s.send_event(Event(event_type="z"))
        # create_task runs on the event loop; wait for the to_thread to complete
        await asyncio.sleep(0.05)
        assert calls == ["fake/repo"]

        # Next event shortly after → no second save (interval resets to 161)
        with patch("agent.core.telemetry.time.monotonic", return_value=170.0):
            await s.send_event(Event(event_type="w"))
        await asyncio.sleep(0.05)
        assert len(calls) == 1

    asyncio.run(body())
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def test_stable_local_path_overwrites(tmp_path):
    """Repeated local saves land on one stable file — heartbeats must not
    spam a new file per save, and the atomic write must leave valid JSON.
    """
    # Fix: reuse the shared `_mk_session` fixture builder instead of
    # duplicating its entire Session/ContextManager wiring inline (the
    # previous body was a byte-for-byte copy-paste of _mk_session).
    s = _mk_session(tmp_path)

    p1 = s.save_trajectory_local(directory="session_logs")
    assert p1 is not None
    p2 = s.save_trajectory_local(directory="session_logs")
    p3 = s.save_trajectory_local(directory="session_logs")
    # All three saves land on the same file — heartbeat should not spam files.
    assert p1 == p2 == p3
    files = list(Path("session_logs").glob("session_*.json"))
    # Exactly one final file; the .tmp should be renamed away.
    assert len(files) == 1

    # File is valid JSON (atomic write → no torn content).
    with open(p1) as f:
        data = json.load(f)
    assert data["session_id"] == s.session_id
    assert data["upload_status"] == "pending"
|
tests/unit/test_kpis_scheduler.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Smoke tests for backend/kpis_scheduler.py.
|
| 2 |
+
|
| 3 |
+
Exercise the pure / fast paths only:
|
| 4 |
+
* token resolution order
|
| 5 |
+
* build_kpis import path
|
| 6 |
+
* start()/shutdown() lifecycle without APScheduler actually running a job
|
| 7 |
+
* backfill() passes the right hour values through to _run_hour
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import asyncio
|
| 13 |
+
import importlib.util
|
| 14 |
+
import sys
|
| 15 |
+
from datetime import datetime, timezone
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _load():
    """Import backend/kpis_scheduler.py by path and register it in sys.modules."""
    module_path = Path(__file__).parent.parent.parent / "backend" / "kpis_scheduler.py"
    spec = importlib.util.spec_from_file_location("kpis_scheduler", module_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules["kpis_scheduler"] = module
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_token_resolution_order(monkeypatch):
    """_resolve_token prefers the most specific env var, in a fixed order."""
    mod = _load()
    all_vars = ("HF_KPI_WRITE_TOKEN", "HF_SESSION_UPLOAD_TOKEN", "HF_TOKEN", "HF_ADMIN_TOKEN")
    for var in all_vars:
        monkeypatch.delenv(var, raising=False)
    assert mod._resolve_token() is None

    # Set variables from least- to most-preferred; each new one must win.
    for var, value in (
        ("HF_ADMIN_TOKEN", "admin"),
        ("HF_TOKEN", "generic"),
        ("HF_SESSION_UPLOAD_TOKEN", "sessions"),
        ("HF_KPI_WRITE_TOKEN", "kpis"),
    ):
        monkeypatch.setenv(var, value)
        assert mod._resolve_token() == value
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def test_load_build_kpis_exposes_run_for_hour():
    """The dynamically loaded build_kpis module exposes a callable run_for_hour."""
    build_kpis = _load()._load_build_kpis()
    runner = getattr(build_kpis, "run_for_hour", None)
    assert runner is not None
    assert callable(runner)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_backfill_calls_run_hour_for_each_hour(monkeypatch):
    """backfill(hours=N) invokes _run_hour once per hour, newest first."""
    mod = _load()
    monkeypatch.setenv("HF_KPI_WRITE_TOKEN", "x")
    seen: list[datetime] = []

    async def capture(hour_dt):
        seen.append(hour_dt)

    monkeypatch.setattr(mod, "_run_hour", capture)
    asyncio.run(mod.backfill(hours=3))

    assert len(seen) == 3
    # Hours are returned most-recent-first.
    assert seen[0] > seen[1] > seen[2]
    for hour in seen:
        # All aligned to the top of the hour, in UTC.
        assert hour.minute == 0 and hour.second == 0 and hour.microsecond == 0
        assert hour.tzinfo == timezone.utc
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_start_is_no_op_when_disabled(monkeypatch):
    """With ML_INTERN_KPIS_DISABLED set, start() never builds a scheduler."""
    mod = _load()
    mod._scheduler = None  # reset module-global state between tests
    monkeypatch.setenv("ML_INTERN_KPIS_DISABLED", "1")
    mod.start()
    assert mod._scheduler is None  # never instantiated
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def test_start_skips_cleanly_without_apscheduler(monkeypatch):
    """start() must log-and-return, not raise, when apscheduler is missing."""
    import builtins

    mod = _load()
    mod._scheduler = None
    monkeypatch.delenv("ML_INTERN_KPIS_DISABLED", raising=False)

    # Force the apscheduler import to fail — start() should log and return.
    # Fix: grab the real import via the `builtins` module. The previous
    # `__builtins__["__import__"] if isinstance(__builtins__, dict) else ...`
    # dance relied on `__builtins__`, a CPython implementation detail that is
    # a module in __main__ but a dict in imported modules.
    real_import = builtins.__import__

    def fake_import(name, *args, **kwargs):
        if name.startswith("apscheduler"):
            raise ImportError("apscheduler unavailable in test")
        return real_import(name, *args, **kwargs)

    monkeypatch.setattr("builtins.__import__", fake_import)
    mod.start()  # should not raise
    assert mod._scheduler is None
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def test_shutdown_is_no_op_when_not_started():
    """shutdown() on a never-started scheduler completes without raising."""
    module = _load()
    module._scheduler = None
    asyncio.run(module.shutdown())  # must not raise
|
tests/unit/test_redact.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the secret scrubber used before session upload."""
|
| 2 |
+
|
| 3 |
+
from agent.core.redact import scrub, scrub_string
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_hf_token():
    """Hugging Face tokens (hf_…) are replaced with a named placeholder."""
    text = "here is a token hf_" + "A" * 35 + " ok"
    scrubbed = scrub_string(text)
    assert "hf_" not in scrubbed
    assert "[REDACTED_HF_TOKEN]" in scrubbed
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_anthropic_key():
    """Anthropic API key bodies never survive scrubbing."""
    scrubbed = scrub_string("key=sk-ant-api03_" + "a" * 40)
    # The env-var name prefix matches too; just verify we don't leave the body.
    assert "sk-ant-api03_" not in scrubbed
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_github_token():
    """A classic ghp_ token is replaced wholesale by the placeholder."""
    scrubbed = scrub_string("ghp_" + "a" * 40)
    assert scrubbed == "[REDACTED_GITHUB_TOKEN]"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def test_github_fine_grained_pat():
    """Fine-grained PATs (github_pat_<alnum + underscore>, 36+ chars) are caught."""
    token = "github_pat_" + "A1B2_" * 10
    scrubbed = scrub_string(token)
    assert "github_pat_" not in scrubbed
    assert "[REDACTED_GITHUB_TOKEN]" in scrubbed
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test_aws_key_id():
    """AWS access-key IDs (AKIA…) are removed even in env-var assignment form."""
    scrubbed = scrub_string("AWS_ACCESS_KEY_ID=AKIAABCDEFGHIJKLMNOP")
    assert "AKIAABCDEFGHIJKLMNOP" not in scrubbed
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_bearer_header():
    """Authorization bearer values are blanked while the header shape remains."""
    header = "Authorization: Bearer abcdef0123456789abcdef0123456789"
    scrubbed = scrub_string(header)
    assert "abcdef0123456789abcdef0123456789" not in scrubbed
    assert "Bearer [REDACTED]" in scrubbed
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def test_env_var_style():
    """KEY=value style secrets are scrubbed by at least one rule."""
    scrubbed = scrub_string("HF_TOKEN=hf_" + "x" * 40 + " run")
    # Either the value-scrubber or the HF-token regex should fire.
    assert "hf_xxxx" not in scrubbed
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def test_scrub_nested_dict_and_list():
    """scrub() walks nested dicts/lists, redacts strings, and never mutates input."""
    payload = {
        "msg": "token hf_" + "Z" * 35,
        "tools": [
            {"args": {"secret": "ghp_" + "Q" * 40}},
            "no secrets here",
        ],
        "n": 42,
    }
    cleaned = scrub(payload)
    # Original not mutated.
    assert "hf_" in payload["msg"]
    # Redacted copy.
    assert "[REDACTED_HF_TOKEN]" in cleaned["msg"]
    assert cleaned["tools"][0]["args"]["secret"] == "[REDACTED_GITHUB_TOKEN]"
    assert cleaned["tools"][1] == "no secrets here"
    assert cleaned["n"] == 42
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def test_scrub_preserves_non_strings():
    """Non-string scalars pass through scrub() untouched."""
    assert scrub(None) is None
    assert scrub(123) == 123
    assert scrub(True) is True
|
tests/unit/test_sft_tagger.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for agent.sft.tagger — one test per tag namespace."""
|
| 2 |
+
|
| 3 |
+
from agent.sft.tagger import tag_session
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _ev(event_type, data=None, ts="2026-04-24T10:00:00"):
|
| 7 |
+
return {"timestamp": ts, "event_type": event_type, "data": data or {}}
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _traj(events=None, messages=None, model="claude-opus-4-6"):
|
| 11 |
+
return {
|
| 12 |
+
"session_id": "sess-1",
|
| 13 |
+
"model_name": model,
|
| 14 |
+
"session_start_time": "2026-04-24T09:59:00",
|
| 15 |
+
"session_end_time": "2026-04-24T10:05:00",
|
| 16 |
+
"messages": messages
|
| 17 |
+
or [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "ok"}],
|
| 18 |
+
"events": events or [],
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test_model_family():
    """Model names map to a family tag, with a catch-all model:other."""
    cases = {
        "claude-opus-4-6": "model:opus",
        "bedrock/claude-sonnet-4-5": "model:sonnet",
        "moonshotai/Kimi-K2.6": "model:kimi",
        "unknown-model-xyz": "model:other",
    }
    for model_name, expected in cases.items():
        assert expected in tag_session(_traj(model=model_name))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_turns_buckets():
    """Message counts bucket into short / medium / long turn tags."""

    def repeated(n):
        return _traj(messages=[{"role": "user", "content": "q"} for _ in range(n)])

    assert "turns:short" in tag_session(_traj(messages=[{"role": "user", "content": "hi"}]))
    assert "turns:medium" in tag_session(repeated(10))
    assert "turns:long" in tag_session(repeated(25))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_cost_buckets():
    """llm_call cost_usd values bucket into low / med / high cost tags."""
    for cost_usd, expected in ((0.05, "cost:low"), (0.5, "cost:med"), (5.0, "cost:high")):
        traj = _traj(events=[_ev("llm_call", {"cost_usd": cost_usd})])
        assert expected in tag_session(traj)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def test_tool_tags():
    """Each distinct tool invoked yields its own tool:<name> tag."""
    tool_events = [
        _ev("tool_call", {"tool": "hf_jobs", "arguments": {}}),
        _ev("tool_call", {"tool": "research"}),
        _ev("tool_call", {"tool": "bash"}),
    ]
    tags = tag_session(_traj(tool_events))
    for expected in ("tool:hf_jobs", "tool:research", "tool:bash"):
        assert expected in tags
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def test_outcome_completed():
    """A turn_complete event marks the session outcome:completed."""
    traj = _traj([_ev("turn_complete", {"history_size": 10})])
    assert "outcome:completed" in tag_session(traj)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_outcome_errored():
    """An error event marks the session outcome:errored."""
    traj = _traj([_ev("error", {"error": "boom"})])
    assert "outcome:errored" in tag_session(traj)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def test_outcome_interrupted():
    """An interrupted event marks the session outcome:interrupted."""
    assert "outcome:interrupted" in tag_session(_traj([_ev("interrupted")]))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_outcome_ongoing():
    """No terminal event → the session was still running at save time."""
    traj = _traj([_ev("llm_call", {"cost_usd": 0.01})])
    assert "outcome:ongoing" in tag_session(traj)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def test_outcome_doom_loop_and_context():
    """Doom-loop logs and compaction events each add their own outcome tag."""
    traj = _traj([
        _ev("tool_log", {"tool": "system", "log": "Doom loop detected — injecting corrective prompt"}),
        _ev("compacted", {"old_tokens": 100, "new_tokens": 50}),
        _ev("turn_complete", {"history_size": 10}),
    ])
    tags = tag_session(traj)
    assert "outcome:doom_loop" in tags
    assert "outcome:context_exceeded" in tags
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def test_hf_job_tags():
    """HF job submit/complete events drive the hf_job:* and gpu:* tags."""
    job_events = [
        _ev("tool_call", {"tool": "hf_jobs", "arguments": {"script": "from trl import SFTTrainer"}}),
        _ev("hf_job_submit", {
            "flavor": "a100-large", "push_to_hub": True, "job_id": "j1",
        }),
        _ev("hf_job_complete", {"flavor": "a100-large", "final_status": "COMPLETED", "wall_time_s": 3600}),
        _ev("hf_job_submit", {"flavor": "a100x4", "push_to_hub": False}),
        _ev("hf_job_complete", {"flavor": "a100x4", "final_status": "FAILED"}),
    ]
    tags = tag_session(_traj(job_events))
    expected_tags = (
        "hf_job:submitted",
        "hf_job:multi",
        "hf_job:succeeded",
        "hf_job:failed",
        "hf_job:push_to_hub",
        "gpu:a100",
        "gpu:multi",
    )
    for expected in expected_tags:
        assert expected in tags
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def test_hf_job_oom():
    """A CUDA out-of-memory tool output in a job session tags hf_job:oom."""
    traj = _traj([
        _ev("tool_call", {"tool": "hf_jobs", "arguments": {}}),
        _ev("hf_job_submit", {"flavor": "a100-large"}),
        _ev("tool_output", {"success": False, "output": "RuntimeError: CUDA out of memory. Tried to allocate..."}),
    ])
    assert "hf_job:oom" in tag_session(traj)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def test_sandbox_tags():
    """Sandbox create/destroy events produce hardware and lifetime tags."""
    traj = _traj([
        _ev("sandbox_create", {"hardware": "t4-small", "sandbox_id": "s1", "create_latency_s": 5}),
        _ev("sandbox_destroy", {"sandbox_id": "s1", "lifetime_s": 3600}),
    ])
    tags = tag_session(traj)
    for expected in ("sandbox:created", "sandbox:gpu", "sandbox:long_lived"):
        assert expected in tags
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def test_sandbox_cpu_short():
    """CPU hardware tags sandbox:cpu; a short lifetime is not long_lived."""
    traj = _traj([
        _ev("sandbox_create", {"hardware": "cpu-basic"}),
        _ev("sandbox_destroy", {"lifetime_s": 120}),
    ])
    tags = tag_session(traj)
    assert "sandbox:cpu" in tags
    assert "sandbox:long_lived" not in tags
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def test_feedback_tags():
    """Feedback events collapse into up / down / mixed / none tags."""

    def with_ratings(*ratings):
        return _traj(events=[_ev("feedback", {"rating": r}) for r in ratings])

    assert "feedback:up" in tag_session(with_ratings("up"))
    assert "feedback:down" in tag_session(with_ratings("down"))
    assert "feedback:mixed" in tag_session(with_ratings("up", "down"))
    assert "feedback:none" in tag_session(_traj())
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def test_task_training():
    """A TRL training script submitted as an HF job tags task:training."""
    traj = _traj([
        _ev("tool_call", {"tool": "hf_jobs", "arguments": {
            "script": "from trl import SFTTrainer\ntrainer = SFTTrainer(...)"
        }}),
        _ev("hf_job_submit", {"flavor": "a100-large"}),
    ])
    assert "task:training" in tag_session(traj)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def test_task_research_only():
    """Sessions using only research-style tools tag task:research_only."""
    traj = _traj([
        _ev("tool_call", {"tool": "research"}),
        _ev("tool_call", {"tool": "explore_hf_docs"}),
    ])
    assert "task:research_only" in tag_session(traj)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def test_task_data_prep():
    """Dataset-inspection tool calls tag the session task:data_prep."""
    traj = _traj([
        _ev("tool_call", {"tool": "hf_inspect_dataset", "arguments": {}}),
        _ev("tool_call", {"tool": "hub_repo_details"}),
    ])
    assert "task:data_prep" in tag_session(traj)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def test_no_duplicates_and_sorted():
    """tag_session returns a sorted list with no duplicate tags."""
    traj = _traj([
        _ev("tool_call", {"tool": "hf_jobs"}),
        _ev("tool_call", {"tool": "hf_jobs"}),  # duplicate
        _ev("hf_job_submit", {"flavor": "a10g-small"}),
        _ev("hf_job_submit", {"flavor": "a10g-small"}),
    ])
    tags = tag_session(traj)
    assert tags == sorted(tags)
    assert len(tags) == len(set(tags))
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def test_empty_trajectory_has_required_tags():
    """Every session carries at least the core tag namespaces."""
    tags = tag_session(_traj())
    namespaces = {tag.split(":", 1)[0] for tag in tags}
    # Every session must have at least model/turns/cost/outcome/feedback.
    for required in ("model", "turns", "cost", "outcome", "feedback"):
        assert required in namespaces, f"missing {required} — got {tags}"
|