namish10 committed on
Commit
a2896bf
·
verified ·
1 Parent(s): 86b5863

Upload folder using huggingface_hub

Browse files
server/__init__.py ADDED
File without changes
server/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (153 Bytes). View file
 
server/__pycache__/contextflow_environment.cpython-313.pyc ADDED
Binary file (23.3 kB). View file
 
server/app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
# NOTE(review): asyncio appears unused in this module — confirm before removing.
import asyncio
import json
from typing import Optional

from models import Observation, Action, Reward, State, StepResult, TaskDifficulty
from server.contextflow_environment import ContextFlowEnvironment


# FastAPI application exposing the ContextFlow environment over HTTP and WebSocket.
app = FastAPI(title="ContextFlow OpenEnv")

# Open WebSocket connections, keyed by episode_id.
connections: dict[str, WebSocket] = {}
# Live environment instances, keyed by episode_id; entries are removed
# when an episode finishes (result.done) in /step or the WebSocket handler.
environments: dict[str, ContextFlowEnvironment] = {}
15
+
16
+
17
+ @app.get("/")
18
+ async def root():
19
+ return {"message": "ContextFlow OpenEnv Environment", "version": "1.0.0"}
20
+
21
+
22
+ @app.get("/health")
23
+ async def health():
24
+ return {"status": "healthy"}
25
+
26
+
27
+ @app.post("/reset")
28
+ async def reset(difficulty: Optional[str] = "medium"):
29
+ try:
30
+ difficulty_enum = TaskDifficulty(difficulty.lower())
31
+ except ValueError:
32
+ difficulty_enum = TaskDifficulty.MEDIUM
33
+
34
+ env = ContextFlowEnvironment(task_difficulty=difficulty_enum)
35
+ observation = env.reset()
36
+
37
+ env_id = observation.episode_id
38
+ environments[env_id] = env
39
+
40
+ return {
41
+ "observation": observation.model_dump(),
42
+ "episode_id": env_id,
43
+ }
44
+
45
+
46
+ @app.post("/step")
47
+ async def step(action: Action):
48
+ if not action.episode_id or action.episode_id not in environments:
49
+ return JSONResponse(
50
+ status_code=400,
51
+ content={"error": "Invalid or missing episode_id"}
52
+ )
53
+
54
+ env = environments[action.episode_id]
55
+ result = env.step(action)
56
+
57
+ if result.done:
58
+ del environments[action.episode_id]
59
+
60
+ return result.model_dump()
61
+
62
+
63
+ @app.get("/state/{episode_id}")
64
+ async def get_state(episode_id: str):
65
+ if episode_id not in environments:
66
+ return JSONResponse(
67
+ status_code=404,
68
+ content={"error": "Episode not found"}
69
+ )
70
+
71
+ env = environments[episode_id]
72
+ return env.state().model_dump()
73
+
74
+
75
+ @app.websocket("/ws/{episode_id}")
76
+ async def websocket_endpoint(websocket: WebSocket, episode_id: str):
77
+ await websocket.accept()
78
+ connections[episode_id] = websocket
79
+
80
+ if episode_id not in environments:
81
+ await websocket.send_json({"error": "Episode not found"})
82
+ await websocket.close()
83
+ return
84
+
85
+ try:
86
+ while True:
87
+ data = await websocket.receive_text()
88
+ message = json.loads(data)
89
+
90
+ if message["type"] == "reset":
91
+ difficulty = message.get("difficulty", "medium")
92
+ env = ContextFlowEnvironment(task_difficulty=TaskDifficulty(difficulty))
93
+ observation = env.reset()
94
+ environments[episode_id] = env
95
+ await websocket.send_json({
96
+ "type": "reset",
97
+ "observation": observation.model_dump()
98
+ })
99
+
100
+ elif message["type"] == "step":
101
+ if episode_id not in environments:
102
+ await websocket.send_json({"error": "Episode not found"})
103
+ continue
104
+
105
+ env = environments[episode_id]
106
+ action = Action(**message["action"])
107
+ result = env.step(action)
108
+
109
+ if result.done:
110
+ del environments[episode_id]
111
+
112
+ await websocket.send_json({
113
+ "type": "step",
114
+ "result": result.model_dump()
115
+ })
116
+
117
+ elif message["type"] == "state":
118
+ if episode_id not in environments:
119
+ await websocket.send_json({"error": "Episode not found"})
120
+ continue
121
+
122
+ env = environments[episode_id]
123
+ await websocket.send_json({
124
+ "type": "state",
125
+ "state": env.state().model_dump()
126
+ })
127
+
128
+ except WebSocketDisconnect:
129
+ pass
130
+ finally:
131
+ if episode_id in connections:
132
+ del connections[episode_id]
133
+
134
+
135
+ if __name__ == "__main__":
136
+ import uvicorn
137
+ uvicorn.run(app, host="0.0.0.0", port=8000)
server/contextflow_environment.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import numpy as np
3
+ from typing import Dict, Any, Optional, Tuple
4
+ from datetime import datetime
5
+
6
+ from models import (
7
+ Observation, Action, Reward, State, StepResult,
8
+ TaskDifficulty, ActionType, GraderResult
9
+ )
10
+ from agents import (
11
+ ContextFlowAgent,
12
+ AgentPrediction,
13
+ ConfusionLevel,
14
+ InterventionType,
15
+ KnowledgeGraphAgent,
16
+ PeerLearningAgent,
17
+ RecallAgent,
18
+ )
19
+
20
+
21
class ContextFlowEnvironment:
    """
    OpenEnv environment wrapping the full ContextFlow multi-agent system.

    Integrates:
    - RL-based doubt prediction (ContextFlowAgent)
    - Multi-modal behavioral analysis
    - Gesture recognition
    - Knowledge graphs (KnowledgeGraphAgent)
    - Peer learning (PeerLearningAgent)
    - Spaced repetition (RecallAgent)

    Each episode emits synthetic multimodal observations; the caller predicts
    learner confusion and triggers interventions, earning shaped rewards.
    """

    # NOTE(review): episode termination actually uses the per-difficulty
    # "max_steps" from _get_task_config (50/75/100); MAX_STEPS looks unused
    # in this file — confirm before relying on it.
    MAX_STEPS = 100
35
+
36
+ def __init__(self, task_difficulty: TaskDifficulty = TaskDifficulty.MEDIUM):
37
+ self.task_difficulty = task_difficulty
38
+ self.episode_id: Optional[str] = None
39
+ self.step_count: int = 0
40
+ self._state: Optional[State] = None
41
+ self._last_observation: Optional[Observation] = None
42
+
43
+ self._ground_truth_confusion: float = 0.0
44
+ self._confusion_trajectory: list = []
45
+ self._prediction_history: list = []
46
+ self._intervention_history: list = []
47
+ self._task_config = self._get_task_config()
48
+
49
+ self.agent = ContextFlowAgent()
50
+ self.knowledge_graph = KnowledgeGraphAgent()
51
+ self.peer_learning = PeerLearningAgent()
52
+ self.recall_system = RecallAgent()
53
+
54
+ def _get_task_config(self) -> Dict[str, Any]:
55
+ configs = {
56
+ TaskDifficulty.EASY: {
57
+ "prediction_window": 3,
58
+ "noise_level": 0.1,
59
+ "confusion_base": 0.3,
60
+ "intervention_threshold": 0.6,
61
+ "max_steps": 50,
62
+ "confusion_spike_prob": 0.08,
63
+ },
64
+ TaskDifficulty.MEDIUM: {
65
+ "prediction_window": 5,
66
+ "noise_level": 0.2,
67
+ "confusion_base": 0.5,
68
+ "intervention_threshold": 0.5,
69
+ "max_steps": 75,
70
+ "confusion_spike_prob": 0.12,
71
+ },
72
+ TaskDifficulty.HARD: {
73
+ "prediction_window": 7,
74
+ "noise_level": 0.3,
75
+ "confusion_base": 0.6,
76
+ "intervention_threshold": 0.4,
77
+ "max_steps": 100,
78
+ "confusion_spike_prob": 0.15,
79
+ },
80
+ }
81
+ return configs.get(self.task_difficulty, configs[TaskDifficulty.MEDIUM])
82
+
83
    def _generate_synthetic_data(self) -> Tuple[Observation, float]:
        """Synthesize one multimodal observation plus the hidden ground truth.

        Side effects: updates self._ground_truth_confusion and appends it to
        self._confusion_trajectory.

        Returns:
            (observation, ground_truth_confusion) for the current step.
        """
        step = self.step_count
        config = self._task_config

        # Hidden confusion signal: difficulty baseline + slow sinusoidal
        # drift + Gaussian noise + occasional spike, clipped to [0, 1].
        base_confusion = config["confusion_base"]
        noise = np.random.normal(0, config["noise_level"])

        confusion_trend = np.sin(step * 0.1) * 0.2
        # NOTE(review): the spike *magnitude* equals the spike probability
        # (0.08-0.15), making spikes very small — confirm this is intended.
        confusion_spike = config["confusion_spike_prob"] if np.random.random() < config["confusion_spike_prob"] else 0.0

        self._ground_truth_confusion = np.clip(
            base_confusion + confusion_trend + noise + confusion_spike,
            0.0, 1.0
        )
        self._confusion_trajectory.append(self._ground_truth_confusion)

        # Random feature channels; dimensions: 16 gaze, 63 gesture,
        # 6 biometric, 2 audio, 8 behavioral.
        gaze_features = np.random.randn(16).tolist()
        gesture_features = np.random.randn(63).tolist()
        biometric_features = [
            60 + np.random.randn() * 10,     # presumably heart-rate-like — confirm
            0.5 + np.random.randn() * 0.1,
            36.6 + np.random.randn() * 0.5,  # presumably body-temperature-like — confirm
            15 + np.random.randn() * 3,
            0.3 + np.random.randn() * 0.1,
            0.7 + np.random.randn() * 0.1,
        ]
        audio_features = [200 + np.random.randn() * 50, 0.3 + np.random.randn() * 0.1]
        behavioral_features = np.random.randn(8).tolist()

        # Leak a scaled copy of the ground truth into the first two behavioral
        # features so the hidden signal is learnable from the observation.
        behavioral_features[0] = self._ground_truth_confusion * 0.5
        behavioral_features[1] = self._ground_truth_confusion * 0.3

        learning_context = {
            "topic": np.random.choice(["math", "science", "programming", "language"]),
            "difficulty": self.task_difficulty.value,
            "time_spent": step * 30,  # assumes ~30 time units per step — TODO confirm
            "content_length": np.random.randint(100, 1000),
            "subtopic": np.random.choice(["basics", "intermediate", "advanced"]),
        }

        # Learner state is derived deterministically from the hidden confusion.
        learner_state = {
            "engagement": 1.0 - self._ground_truth_confusion,
            "frustration": self._ground_truth_confusion * 0.8,
            "comprehension": 0.7 - self._ground_truth_confusion * 0.3,
            "confusion_level": self._get_confusion_level(self._ground_truth_confusion).value,
        }

        observation = Observation(
            step=self.step_count,
            episode_id=self.episode_id,
            learning_context=learning_context,
            learner_state=learner_state,
            gaze_features=gaze_features,
            gesture_features=gesture_features,
            biometric_features=biometric_features,
            audio_features=audio_features,
            behavioral_features=behavioral_features,
            confusion_history=self._confusion_trajectory[-10:],  # last 10 steps only
            prediction_window=config["prediction_window"],
            available_interventions=[
                "hint", "simplify", "breakdown", "example", "scaffold",
                "peer_connect", "break", "encourage"
            ],
            multimodal_fused=True,
            metadata={
                "knowledge_graph_mastery": self.knowledge_graph.get_prerequisite_mastery(
                    learning_context["topic"]
                ),
                "similar_learners": len(self.peer_learning.find_similar_learners(
                    learning_context["topic"]
                )),
                "recall_cards": len(self.recall_system.cards),
            }
        )

        return observation, self._ground_truth_confusion
159
+
160
+ def _get_confusion_level(self, prob: float) -> ConfusionLevel:
161
+ from agents import ConfusionLevel
162
+ if prob < 0.25:
163
+ return ConfusionLevel.LOW
164
+ elif prob < 0.5:
165
+ return ConfusionLevel.MEDIUM
166
+ elif prob < 0.75:
167
+ return ConfusionLevel.HIGH
168
+ else:
169
+ return ConfusionLevel.CRITICAL
170
+
171
+ def reset(self) -> Observation:
172
+ self.episode_id = str(uuid.uuid4())
173
+ self.step_count = 0
174
+ self._confusion_trajectory = []
175
+ self._prediction_history = []
176
+ self._intervention_history = []
177
+ self._ground_truth_confusion = 0.0
178
+
179
+ self.agent = ContextFlowAgent()
180
+
181
+ observation, _ = self._generate_synthetic_data()
182
+ self._last_observation = observation
183
+
184
+ self._state = State(
185
+ episode_id=self.episode_id,
186
+ step_count=self.step_count,
187
+ max_steps=self._task_config["max_steps"],
188
+ task_difficulty=self.task_difficulty,
189
+ ground_truth_confusion=self._ground_truth_confusion,
190
+ predictions_history=[],
191
+ interventions_history=[],
192
+ episode_reward=0.0,
193
+ task_complete=False,
194
+ task_success=False,
195
+ )
196
+
197
+ return observation
198
+
199
+ def step(self, action: Action) -> StepResult:
200
+ if self._state is None:
201
+ raise RuntimeError("Must call reset() before step()")
202
+
203
+ if self._state.task_complete:
204
+ return StepResult(
205
+ observation=self._create_current_observation(),
206
+ reward=Reward(total=0.0),
207
+ done=True,
208
+ info={"message": "Episode already complete"}
209
+ )
210
+
211
+ reward = self._calculate_reward(action)
212
+ self._state.episode_reward += reward.total
213
+
214
+ self._state.predictions_history.append({
215
+ "step": self.step_count,
216
+ "predicted": action.predicted_confusion,
217
+ "ground_truth": self._ground_truth_confusion,
218
+ "action_type": action.action_type.value,
219
+ "confusion_level": self._get_confusion_level(action.predicted_confusion or 0.5).value,
220
+ })
221
+
222
+ if action.action_type == ActionType.TRIGGER_INTERVENTION and action.intervention_type:
223
+ self._intervention_history.append({
224
+ "step": self.step_count,
225
+ "type": action.intervention_type,
226
+ "intensity": action.intervention_intensity or 0.5,
227
+ "effectiveness": 0.0,
228
+ })
229
+
230
+ if action.intervention_type == "peer_connect":
231
+ topic = self._last_observation.learning_context.get("topic", "general") if self._last_observation else "general"
232
+ peers = self.peer_learning.find_similar_learners(topic)
233
+ reward.total += 0.1 * min(len(peers), 3)
234
+
235
+ self.agent.update(reward.total, self._last_observation.learning_context if self._last_observation else {})
236
+
237
+ self.step_count += 1
238
+ self._state.step_count = self.step_count
239
+
240
+ observation, new_gt = self._generate_synthetic_data()
241
+ self._last_observation = observation
242
+
243
+ self._state.ground_truth_confusion = new_gt
244
+ self._state.interventions_history = self._intervention_history.copy()
245
+
246
+ if len(self._intervention_history) > 0:
247
+ last_idx = len(self._intervention_history) - 1
248
+ if len(self._confusion_trajectory) >= 3:
249
+ prev_confusion = self._confusion_trajectory[-3]
250
+ if new_gt < prev_confusion:
251
+ self._intervention_history[last_idx]["effectiveness"] = 0.8
252
+ else:
253
+ self._intervention_history[last_idx]["effectiveness"] = 0.3
254
+
255
+ if self.step_count >= self._task_config["max_steps"]:
256
+ self._state.task_complete = True
257
+ self._state.task_success = self._grade_task().passed
258
+
259
+ done = self._state.task_complete
260
+
261
+ return StepResult(
262
+ observation=observation,
263
+ reward=reward,
264
+ done=done,
265
+ info={
266
+ "grader_result": self._grade_task() if done else None,
267
+ "episode_summary": {
268
+ "total_reward": self._state.episode_reward,
269
+ "predictions_made": len(self._prediction_history),
270
+ "interventions_triggered": len(self._intervention_history),
271
+ "knowledge_graph_active": True,
272
+ "peer_learning_active": True,
273
+ "recall_system_active": True,
274
+ },
275
+ "agent_state": {
276
+ "epsilon": self.agent.epsilon,
277
+ "recent_avg_reward": np.mean(self.agent.episode_rewards[-10:]) if self.agent.episode_rewards else 0.0,
278
+ }
279
+ }
280
+ )
281
+
282
+ def _calculate_reward(self, action: Action) -> Reward:
283
+ gt = self._ground_truth_confusion
284
+ pred = action.predicted_confusion if action.predicted_confusion is not None else 0.5
285
+
286
+ prediction_error = abs(pred - gt)
287
+ confusion_reward = 1.0 - prediction_error
288
+
289
+ early_detection = 0.0
290
+ if len(self._confusion_trajectory) > 1:
291
+ prev_confusion = self._confusion_trajectory[-2]
292
+ if gt > prev_confusion and pred > prev_confusion:
293
+ early_detection = 0.2
294
+ if gt > 0.6 and pred > 0.6:
295
+ early_detection = 0.3
296
+
297
+ intervention_reward = 0.0
298
+ if action.action_type == ActionType.TRIGGER_INTERVENTION:
299
+ if gt > self._task_config["intervention_threshold"]:
300
+ intervention_reward = 0.3
301
+ elif gt < 0.3:
302
+ intervention_reward = -0.1
303
+
304
+ partial_progress = 0.0
305
+ if len(self._confusion_trajectory) >= 5:
306
+ recent_trend = np.mean(self._confusion_trajectory[-5:])
307
+ if recent_trend < 0.4:
308
+ partial_progress = 0.1
309
+
310
+ penalty = 0.0
311
+ if action.intervention_intensity and action.intervention_intensity > 0.9:
312
+ penalty = -0.2
313
+
314
+ total = confusion_reward * 0.4 + early_detection * 0.2 + intervention_reward * 0.2 + partial_progress * 0.1 + penalty
315
+
316
+ return Reward(
317
+ total=total,
318
+ confusion_prediction_reward=confusion_reward * 0.4,
319
+ early_detection_reward=early_detection,
320
+ intervention_reward=intervention_reward,
321
+ partial_progress_reward=partial_progress,
322
+ penalty=penalty,
323
+ metadata={
324
+ "prediction_error": prediction_error,
325
+ "ground_truth": gt,
326
+ "predicted": pred,
327
+ }
328
+ )
329
+
330
+ def _grade_task(self) -> GraderResult:
331
+ if not self._prediction_history:
332
+ return GraderResult(
333
+ score=0.0,
334
+ feedback="No predictions made",
335
+ metrics={},
336
+ passed=False
337
+ )
338
+
339
+ predictions = self._state.predictions_history
340
+ gt_trajectory = self._confusion_trajectory[:len(predictions)]
341
+
342
+ mae = np.mean([
343
+ abs(p["predicted"] - gt)
344
+ for p, gt in zip(predictions, gt_trajectory)
345
+ if p["predicted"] is not None
346
+ ])
347
+
348
+ confusion_threshold = 0.6
349
+ early_detections = 0
350
+ total_spikes = 0
351
+
352
+ for i in range(1, len(gt_trajectory)):
353
+ if gt_trajectory[i] > confusion_threshold:
354
+ total_spikes += 1
355
+ if i < len(predictions) and predictions[i]["predicted"] > confusion_threshold:
356
+ if predictions[i]["confusion_level"] in ["high", "critical"]:
357
+ early_detections += 1
358
+
359
+ early_detection_rate = early_detections / max(total_spikes, 1)
360
+
361
+ intervention_effectiveness = 0.0
362
+ if self._intervention_history:
363
+ effective_interventions = sum(1 for i in self._intervention_history if i.get("effectiveness", 0) > 0.5)
364
+ intervention_effectiveness = effective_interventions / len(self._intervention_history)
365
+
366
+ score = (1 - mae) * 0.4 + early_detection_rate * 0.3 + intervention_effectiveness * 0.3
367
+
368
+ feedback_parts = []
369
+ feedback_parts.append(f"MAE: {mae:.3f}")
370
+ feedback_parts.append(f"Early Detection: {early_detection_rate:.1%}")
371
+ feedback_parts.append(f"Intervention Effect: {intervention_effectiveness:.1%}")
372
+ feedback_parts.append(f"Predictions: {len(predictions)}")
373
+ feedback_parts.append(f"Interventions: {len(self._intervention_history)}")
374
+
375
+ passed = score >= self._get_passing_threshold()
376
+
377
+ return GraderResult(
378
+ score=score,
379
+ feedback=" | ".join(feedback_parts),
380
+ metrics={
381
+ "mae": float(mae),
382
+ "early_detection_rate": float(early_detection_rate),
383
+ "intervention_effectiveness": float(intervention_effectiveness),
384
+ "total_predictions": len(predictions),
385
+ "total_interventions": len(self._intervention_history),
386
+ },
387
+ passed=passed
388
+ )
389
+
390
+ def _get_passing_threshold(self) -> float:
391
+ thresholds = {
392
+ TaskDifficulty.EASY: 0.5,
393
+ TaskDifficulty.MEDIUM: 0.6,
394
+ TaskDifficulty.HARD: 0.7,
395
+ }
396
+ return thresholds.get(self.task_difficulty, 0.6)
397
+
398
+ def _create_current_observation(self) -> Observation:
399
+ return Observation(
400
+ step=self.step_count,
401
+ episode_id=self.episode_id,
402
+ learning_context={"topic": "completed"},
403
+ learner_state={"engagement": 0.0},
404
+ gaze_features=[],
405
+ gesture_features=[],
406
+ biometric_features=[],
407
+ audio_features=[],
408
+ behavioral_features=[],
409
+ confusion_history=self._confusion_trajectory,
410
+ prediction_window=self._task_config["prediction_window"],
411
+ available_interventions=[],
412
+ multimodal_fused=True,
413
+ )
414
+
415
+ def get_state(self) -> State:
416
+ if self._state is None:
417
+ raise RuntimeError("Must call reset() before get_state()")
418
+ return self._state
419
+
420
+ def get_agent_prediction(self) -> AgentPrediction:
421
+ obs_dict = {
422
+ "gaze_features": self._last_observation.gaze_features if self._last_observation else [],
423
+ "gesture_features": self._last_observation.gesture_features if self._last_observation else [],
424
+ "biometric_features": self._last_observation.biometric_features if self._last_observation else [],
425
+ "behavioral_features": self._last_observation.behavioral_features if self._last_observation else [],
426
+ "audio_features": self._last_observation.audio_features if self._last_observation else [],
427
+ "learning_context": {"difficulty": self.task_difficulty.value},
428
+ }
429
+ return self.agent.predict(obs_dict)
430
+
431
+ def get_grader(self, difficulty: Optional[TaskDifficulty] = None) -> callable:
432
+ difficulty = difficulty or self.task_difficulty
433
+ def grade(predictions: list, ground_truth: list, interventions: list) -> GraderResult:
434
+ nonlocal difficulty
435
+ temp_env = ContextFlowEnvironment(task_difficulty=difficulty)
436
+ temp_env._confusion_trajectory = ground_truth.copy()
437
+ temp_env._prediction_history = predictions
438
+ temp_env._intervention_history = interventions
439
+ temp_env._state = State(
440
+ episode_id="grader",
441
+ step_count=len(predictions),
442
+ max_steps=temp_env._task_config["max_steps"],
443
+ task_difficulty=difficulty,
444
+ predictions_history=[
445
+ {"step": i, "predicted": p, "ground_truth": gt, "action_type": "predict", "confusion_level": "medium"}
446
+ for i, (p, gt) in enumerate(zip(predictions, ground_truth))
447
+ ],
448
+ interventions_history=interventions,
449
+ )
450
+ return temp_env._grade_task()
451
+ return grade
452
+
453
+
454
def easy_grader(predictions: list, ground_truth: list, interventions: list) -> GraderResult:
    """Grade an episode trace against the EASY passing threshold.

    Deduplication: delegates to ContextFlowEnvironment.get_grader, which
    performs the identical replay-and-grade setup that was previously
    inlined here (only the internal placeholder episode_id differs, and
    the grader never reads it).
    """
    grade = ContextFlowEnvironment(task_difficulty=TaskDifficulty.EASY).get_grader(TaskDifficulty.EASY)
    return grade(predictions, ground_truth, interventions)
471
+
472
+
473
def medium_grader(predictions: list, ground_truth: list, interventions: list) -> GraderResult:
    """Grade an episode trace against the MEDIUM passing threshold.

    Deduplication: delegates to ContextFlowEnvironment.get_grader, which
    performs the identical replay-and-grade setup that was previously
    inlined here (only the internal placeholder episode_id differs, and
    the grader never reads it).
    """
    grade = ContextFlowEnvironment(task_difficulty=TaskDifficulty.MEDIUM).get_grader(TaskDifficulty.MEDIUM)
    return grade(predictions, ground_truth, interventions)
490
+
491
+
492
def hard_grader(predictions: list, ground_truth: list, interventions: list) -> GraderResult:
    """Grade an episode trace against the HARD passing threshold.

    Deduplication: delegates to ContextFlowEnvironment.get_grader, which
    performs the identical replay-and-grade setup that was previously
    inlined here (only the internal placeholder episode_id differs, and
    the grader never reads it).
    """
    grade = ContextFlowEnvironment(task_difficulty=TaskDifficulty.HARD).get_grader(TaskDifficulty.HARD)
    return grade(predictions, ground_truth, interventions)
509
+
510
+
511
+ __all__ = ["ContextFlowEnvironment", "easy_grader", "medium_grader", "hard_grader"]