Spaces:
Running
Running
Commit ·
9932c2e
1
Parent(s): 33279ea
Add seed param to /reset: demo pins to seed=0 per stage for consistent known episodes
Browse files
- demo/index.html +3 -3
- server/app.py +2 -2
- server/dataset.py +14 -3
- server/undertrial_environment.py +1 -1
demo/index.html
CHANGED
|
@@ -239,7 +239,7 @@ async function loadCase(stage) {
|
|
| 239 |
const labels = {1:'Stage 1 — Landmark', 2:'Stage 2 — Contested', 3:'Stage 3 — Reversal', 4:'Stage 4 — Schema Drift (BNSS)'};
|
| 240 |
document.getElementById('stageLabel').textContent = 'Live Demo Case — ' + (labels[stage] || 'Stage '+stage);
|
| 241 |
try {
|
| 242 |
-
const res = await fetch(BASE + '/reset?stage=' + stage, { method: 'POST' });
|
| 243 |
const data = await res.json();
|
| 244 |
sessionId = data.session_id;
|
| 245 |
const obs = data.observation;
|
|
@@ -279,8 +279,8 @@ async function runDemo() {
|
|
| 279 |
document.getElementById('toolLog').innerHTML = '';
|
| 280 |
document.getElementById('rewardCard').classList.remove('visible');
|
| 281 |
|
| 282 |
-
log('POST /reset?stage=' + stage, 'tool');
|
| 283 |
-
const res = await fetch(BASE + '/reset?stage=' + stage, { method: 'POST' });
|
| 284 |
const data = await res.json();
|
| 285 |
sessionId = data.session_id;
|
| 286 |
log('Session ready: ' + sessionId.slice(0,8) + '...', 'ok');
|
|
|
|
| 239 |
const labels = {1:'Stage 1 — Landmark', 2:'Stage 2 — Contested', 3:'Stage 3 — Reversal', 4:'Stage 4 — Schema Drift (BNSS)'};
|
| 240 |
document.getElementById('stageLabel').textContent = 'Live Demo Case — ' + (labels[stage] || 'Stage '+stage);
|
| 241 |
try {
|
| 242 |
+
const res = await fetch(BASE + '/reset?stage=' + stage + '&seed=0', { method: 'POST' });
|
| 243 |
const data = await res.json();
|
| 244 |
sessionId = data.session_id;
|
| 245 |
const obs = data.observation;
|
|
|
|
| 279 |
document.getElementById('toolLog').innerHTML = '';
|
| 280 |
document.getElementById('rewardCard').classList.remove('visible');
|
| 281 |
|
| 282 |
+
log('POST /reset?stage=' + stage + '&seed=0', 'tool');
|
| 283 |
+
const res = await fetch(BASE + '/reset?stage=' + stage + '&seed=0', { method: 'POST' });
|
| 284 |
const data = await res.json();
|
| 285 |
sessionId = data.session_id;
|
| 286 |
log('Session ready: ' + sessionId.slice(0,8) + '...', 'ok');
|
server/app.py
CHANGED
|
@@ -69,12 +69,12 @@ def health():
|
|
| 69 |
|
| 70 |
|
| 71 |
@app.post("/reset")
|
| 72 |
-
def reset(stage: int = 1, session_id: str = None):
|
| 73 |
if session_id is None:
|
| 74 |
session_id = str(uuid.uuid4())
|
| 75 |
env = get_or_create_env(session_id)
|
| 76 |
env.set_stage(stage)
|
| 77 |
-
obs = env.reset(stage=stage)
|
| 78 |
return {
|
| 79 |
"session_id": session_id,
|
| 80 |
"observation": obs.model_dump(),
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
@app.post("/reset")
|
| 72 |
+
def reset(stage: int = 1, session_id: str = None, seed: int = None):
|
| 73 |
if session_id is None:
|
| 74 |
session_id = str(uuid.uuid4())
|
| 75 |
env = get_or_create_env(session_id)
|
| 76 |
env.set_stage(stage)
|
| 77 |
+
obs = env.reset(stage=stage, seed=seed)
|
| 78 |
return {
|
| 79 |
"session_id": session_id,
|
| 80 |
"observation": obs.model_dump(),
|
server/dataset.py
CHANGED
|
@@ -282,16 +282,27 @@ class BailDataset:
|
|
| 282 |
self,
|
| 283 |
stage: Optional[int] = None,
|
| 284 |
apply_drift: bool = True,
|
|
|
|
| 285 |
) -> Dict[str, Any]:
|
| 286 |
-
"""Sample an episode from the requested curriculum stage.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
s = stage if stage is not None else self._current_stage
|
| 288 |
|
| 289 |
# Fallback: if stage is empty, try adjacent stages
|
| 290 |
for candidate in [s, s-1, s+1, 1, 2, 3, 4]:
|
| 291 |
if 1 <= candidate <= 4 and self._episodes[candidate]:
|
| 292 |
eps = self._episodes[candidate]
|
| 293 |
-
|
| 294 |
-
|
|
|
|
|
|
|
|
|
|
| 295 |
ep = eps[idx]
|
| 296 |
if apply_drift and s == 4:
|
| 297 |
ep = maybe_apply_drift(ep, probability=0.4)
|
|
|
|
| 282 |
self,
|
| 283 |
stage: Optional[int] = None,
|
| 284 |
apply_drift: bool = True,
|
| 285 |
+
seed: Optional[int] = None,
|
| 286 |
) -> Dict[str, Any]:
|
| 287 |
+
"""Sample an episode from the requested curriculum stage.
|
| 288 |
+
|
| 289 |
+
Args:
|
| 290 |
+
stage: Curriculum stage 1-4. Defaults to current stage.
|
| 291 |
+
apply_drift: Apply BNSS schema drift for stage 4 episodes.
|
| 292 |
+
seed: If set, deterministically picks episode at index (seed % len).
|
| 293 |
+
Used by the demo to always show the same illustrative case.
|
| 294 |
+
"""
|
| 295 |
s = stage if stage is not None else self._current_stage
|
| 296 |
|
| 297 |
# Fallback: if stage is empty, try adjacent stages
|
| 298 |
for candidate in [s, s-1, s+1, 1, 2, 3, 4]:
|
| 299 |
if 1 <= candidate <= 4 and self._episodes[candidate]:
|
| 300 |
eps = self._episodes[candidate]
|
| 301 |
+
if seed is not None:
|
| 302 |
+
idx = seed % len(eps)
|
| 303 |
+
else:
|
| 304 |
+
idx = self._episode_index[candidate] % len(eps)
|
| 305 |
+
self._episode_index[candidate] += 1
|
| 306 |
ep = eps[idx]
|
| 307 |
if apply_drift and s == 4:
|
| 308 |
ep = maybe_apply_drift(ep, probability=0.4)
|
server/undertrial_environment.py
CHANGED
|
@@ -80,7 +80,7 @@ class UndertriAIEnvironment(Environment):
|
|
| 80 |
"""Start a new episode. Returns initial case observation."""
|
| 81 |
self._reset_rubric() if hasattr(self, '_reset_rubric') else None
|
| 82 |
s = stage or self._current_stage
|
| 83 |
-
self._episode = self.dataset.sample_episode(stage=s)
|
| 84 |
self._episode_id = episode_id or str(uuid.uuid4())
|
| 85 |
self._step_count = 0
|
| 86 |
self._flags = []
|
|
|
|
| 80 |
"""Start a new episode. Returns initial case observation."""
|
| 81 |
self._reset_rubric() if hasattr(self, '_reset_rubric') else None
|
| 82 |
s = stage or self._current_stage
|
| 83 |
+
self._episode = self.dataset.sample_episode(stage=s, seed=seed)
|
| 84 |
self._episode_id = episode_id or str(uuid.uuid4())
|
| 85 |
self._step_count = 0
|
| 86 |
self._flags = []
|