Draken1606 commited on
Commit
9932c2e
·
1 Parent(s): 33279ea

Add seed param to /reset: demo pins to seed=0 per stage for consistent known episodes

Browse files
demo/index.html CHANGED
@@ -239,7 +239,7 @@ async function loadCase(stage) {
239
  const labels = {1:'Stage 1 — Landmark', 2:'Stage 2 — Contested', 3:'Stage 3 — Reversal', 4:'Stage 4 — Schema Drift (BNSS)'};
240
  document.getElementById('stageLabel').textContent = 'Live Demo Case — ' + (labels[stage] || 'Stage '+stage);
241
  try {
242
- const res = await fetch(BASE + '/reset?stage=' + stage, { method: 'POST' });
243
  const data = await res.json();
244
  sessionId = data.session_id;
245
  const obs = data.observation;
@@ -279,8 +279,8 @@ async function runDemo() {
279
  document.getElementById('toolLog').innerHTML = '';
280
  document.getElementById('rewardCard').classList.remove('visible');
281
 
282
- log('POST /reset?stage=' + stage, 'tool');
283
- const res = await fetch(BASE + '/reset?stage=' + stage, { method: 'POST' });
284
  const data = await res.json();
285
  sessionId = data.session_id;
286
  log('Session ready: ' + sessionId.slice(0,8) + '...', 'ok');
 
239
  const labels = {1:'Stage 1 — Landmark', 2:'Stage 2 — Contested', 3:'Stage 3 — Reversal', 4:'Stage 4 — Schema Drift (BNSS)'};
240
  document.getElementById('stageLabel').textContent = 'Live Demo Case — ' + (labels[stage] || 'Stage '+stage);
241
  try {
242
+ const res = await fetch(BASE + '/reset?stage=' + stage + '&seed=0', { method: 'POST' });
243
  const data = await res.json();
244
  sessionId = data.session_id;
245
  const obs = data.observation;
 
279
  document.getElementById('toolLog').innerHTML = '';
280
  document.getElementById('rewardCard').classList.remove('visible');
281
 
282
+ log('POST /reset?stage=' + stage + '&seed=0', 'tool');
283
+ const res = await fetch(BASE + '/reset?stage=' + stage + '&seed=0', { method: 'POST' });
284
  const data = await res.json();
285
  sessionId = data.session_id;
286
  log('Session ready: ' + sessionId.slice(0,8) + '...', 'ok');
server/app.py CHANGED
@@ -69,12 +69,12 @@ def health():
69
 
70
 
71
  @app.post("/reset")
72
- def reset(stage: int = 1, session_id: str = None):
73
  if session_id is None:
74
  session_id = str(uuid.uuid4())
75
  env = get_or_create_env(session_id)
76
  env.set_stage(stage)
77
- obs = env.reset(stage=stage)
78
  return {
79
  "session_id": session_id,
80
  "observation": obs.model_dump(),
 
69
 
70
 
71
  @app.post("/reset")
72
+ def reset(stage: int = 1, session_id: str = None, seed: int = None):
73
  if session_id is None:
74
  session_id = str(uuid.uuid4())
75
  env = get_or_create_env(session_id)
76
  env.set_stage(stage)
77
+ obs = env.reset(stage=stage, seed=seed)
78
  return {
79
  "session_id": session_id,
80
  "observation": obs.model_dump(),
server/dataset.py CHANGED
@@ -282,16 +282,27 @@ class BailDataset:
282
  self,
283
  stage: Optional[int] = None,
284
  apply_drift: bool = True,
 
285
  ) -> Dict[str, Any]:
286
- """Sample an episode from the requested curriculum stage."""
 
 
 
 
 
 
 
287
  s = stage if stage is not None else self._current_stage
288
 
289
  # Fallback: if stage is empty, try adjacent stages
290
  for candidate in [s, s-1, s+1, 1, 2, 3, 4]:
291
  if 1 <= candidate <= 4 and self._episodes[candidate]:
292
  eps = self._episodes[candidate]
293
- idx = self._episode_index[candidate] % len(eps)
294
- self._episode_index[candidate] += 1
 
 
 
295
  ep = eps[idx]
296
  if apply_drift and s == 4:
297
  ep = maybe_apply_drift(ep, probability=0.4)
 
282
  self,
283
  stage: Optional[int] = None,
284
  apply_drift: bool = True,
285
+ seed: Optional[int] = None,
286
  ) -> Dict[str, Any]:
287
+ """Sample an episode from the requested curriculum stage.
288
+
289
+ Args:
290
+ stage: Curriculum stage 1-4. Defaults to current stage.
291
+ apply_drift: Apply BNSS schema drift for stage 4 episodes.
292
+ seed: If set, deterministically picks episode at index (seed % len).
293
+ Used by the demo to always show the same illustrative case.
294
+ """
295
  s = stage if stage is not None else self._current_stage
296
 
297
  # Fallback: if stage is empty, try adjacent stages
298
  for candidate in [s, s-1, s+1, 1, 2, 3, 4]:
299
  if 1 <= candidate <= 4 and self._episodes[candidate]:
300
  eps = self._episodes[candidate]
301
+ if seed is not None:
302
+ idx = seed % len(eps)
303
+ else:
304
+ idx = self._episode_index[candidate] % len(eps)
305
+ self._episode_index[candidate] += 1
306
  ep = eps[idx]
307
  if apply_drift and s == 4:
308
  ep = maybe_apply_drift(ep, probability=0.4)
server/undertrial_environment.py CHANGED
@@ -80,7 +80,7 @@ class UndertriAIEnvironment(Environment):
80
  """Start a new episode. Returns initial case observation."""
81
  self._reset_rubric() if hasattr(self, '_reset_rubric') else None
82
  s = stage or self._current_stage
83
- self._episode = self.dataset.sample_episode(stage=s)
84
  self._episode_id = episode_id or str(uuid.uuid4())
85
  self._step_count = 0
86
  self._flags = []
 
80
  """Start a new episode. Returns initial case observation."""
81
  self._reset_rubric() if hasattr(self, '_reset_rubric') else None
82
  s = stage or self._current_stage
83
+ self._episode = self.dataset.sample_episode(stage=s, seed=seed)
84
  self._episode_id = episode_id or str(uuid.uuid4())
85
  self._step_count = 0
86
  self._flags = []