GuillaumeSalouHF commited on
Commit
f6f7fe8
·
unverified ·
1 Parent(s): d0d08fc

feat(session): include user_id and total_cost_usd in trajectory dump

Browse files

Mirror of huggingface/ml-intern#136 applied surgically to the Space's
own main branch (the github/main merge has unrelated drift that would
pull in too much).

Add user_id (HF username from OAuth) to the Session object and
propagate it from the SessionManager. Surface it in the trajectory
JSON written by save_trajectory_local() and uploaded to the session
dataset, alongside a total_cost_usd aggregate summed from existing
llm_call events.

agent/core/session.py CHANGED
@@ -79,8 +79,10 @@ class Session:
79
  hf_token: str | None = None,
80
  local_mode: bool = False,
81
  stream: bool = True,
 
82
  ):
83
  self.hf_token: Optional[str] = hf_token
 
84
  self.tool_router = tool_router
85
  self.stream = stream
86
  tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
@@ -199,11 +201,21 @@ class Session:
199
  tools = self.tool_router.get_tool_specs_for_llm() or []
200
  except Exception:
201
  tools = []
 
 
 
 
 
 
 
 
202
  return {
203
  "session_id": self.session_id,
 
204
  "session_start_time": self.session_start_time,
205
  "session_end_time": datetime.now().isoformat(),
206
  "model_name": self.config.model_name,
 
207
  "messages": [msg.model_dump() for msg in self.context_manager.items],
208
  "events": self.logged_events,
209
  "tools": tools,
 
79
  hf_token: str | None = None,
80
  local_mode: bool = False,
81
  stream: bool = True,
82
+ user_id: str | None = None,
83
  ):
84
  self.hf_token: Optional[str] = hf_token
85
+ self.user_id: Optional[str] = user_id
86
  self.tool_router = tool_router
87
  self.stream = stream
88
  tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
 
201
  tools = self.tool_router.get_tool_specs_for_llm() or []
202
  except Exception:
203
  tools = []
204
+ # Sum per-call cost from llm_call events so analyzers don't have to
205
+ # walk the events array themselves. Each `llm_call` event already
206
+ # carries cost_usd from `agent.core.telemetry.record_llm_call`.
207
+ total_cost_usd = sum(
208
+ float((e.get("data") or {}).get("cost_usd") or 0.0)
209
+ for e in self.logged_events
210
+ if e.get("event_type") == "llm_call"
211
+ )
212
  return {
213
  "session_id": self.session_id,
214
+ "user_id": self.user_id,
215
  "session_start_time": self.session_start_time,
216
  "session_end_time": datetime.now().isoformat(),
217
  "model_name": self.config.model_name,
218
+ "total_cost_usd": total_cost_usd,
219
  "messages": [msg.model_dump() for msg in self.context_manager.items],
220
  "events": self.logged_events,
221
  "tools": tools,
backend/session_manager.py CHANGED
@@ -192,7 +192,7 @@ class SessionManager:
192
  session_config.model_name = model
193
  session = Session(
194
  event_queue, config=session_config, tool_router=tool_router,
195
- hf_token=hf_token,
196
  )
197
  t1 = _time.monotonic()
198
  logger.info(f"Session initialized in {t1 - t0:.2f}s")
 
192
  session_config.model_name = model
193
  session = Session(
194
  event_queue, config=session_config, tool_router=tool_router,
195
+ hf_token=hf_token, user_id=user_id,
196
  )
197
  t1 = _time.monotonic()
198
  logger.info(f"Session initialized in {t1 - t0:.2f}s")