Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files
server/__init__.py
ADDED
|
File without changes
|
server/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (153 Bytes). View file
|
|
|
server/__pycache__/contextflow_environment.cpython-313.pyc
ADDED
|
Binary file (23.3 kB). View file
|
|
|
server/app.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 2 |
+
from fastapi.responses import JSONResponse
|
| 3 |
+
import asyncio
|
| 4 |
+
import json
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
from models import Observation, Action, Reward, State, StepResult, TaskDifficulty
|
| 8 |
+
from server.contextflow_environment import ContextFlowEnvironment
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
app = FastAPI(title="ContextFlow OpenEnv")

# Live WebSocket connections and active environments, both keyed by episode id.
connections: dict[str, WebSocket] = {}
environments: dict[str, ContextFlowEnvironment] = {}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@app.get("/")
async def root():
    """Service banner: identifies the API and reports its version."""
    return {"message": "ContextFlow OpenEnv Environment", "version": "1.0.0"}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@app.get("/health")
async def health():
    """Liveness probe for orchestrators / load balancers."""
    return {"status": "healthy"}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@app.post("/reset")
async def reset(difficulty: Optional[str] = "medium"):
    """Create a fresh environment episode.

    Args:
        difficulty: "easy" | "medium" | "hard" (case-insensitive). ``None``
            or an unknown value falls back to MEDIUM.

    Returns:
        The initial observation plus the new episode id (used by /step,
        /state and the WebSocket endpoint).
    """
    try:
        # FIX: `difficulty` may arrive as explicit None (Optional[str]);
        # `(difficulty or "")` avoids AttributeError on .lower() and the
        # empty string falls through to the MEDIUM default below.
        difficulty_enum = TaskDifficulty((difficulty or "").lower())
    except ValueError:
        difficulty_enum = TaskDifficulty.MEDIUM

    env = ContextFlowEnvironment(task_difficulty=difficulty_enum)
    observation = env.reset()

    env_id = observation.episode_id
    environments[env_id] = env

    return {
        "observation": observation.model_dump(),
        "episode_id": env_id,
    }
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@app.post("/step")
async def step(action: Action):
    """Advance one episode by a single step; evicts the env once done."""
    env = environments.get(action.episode_id) if action.episode_id else None
    if env is None:
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid or missing episode_id"}
        )

    result = env.step(action)

    # Finished episodes are dropped so stale ids cannot be stepped again.
    if result.done:
        environments.pop(action.episode_id, None)

    return result.model_dump()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@app.get("/state/{episode_id}")
async def get_state(episode_id: str):
    """Return the current State of a live episode, or 404 if unknown."""
    if episode_id not in environments:
        return JSONResponse(
            status_code=404,
            content={"error": "Episode not found"}
        )

    env = environments[episode_id]
    # FIX: the environment exposes get_state(), not state(); the previous
    # call raised AttributeError on every request.
    return env.get_state().model_dump()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@app.websocket("/ws/{episode_id}")
async def websocket_endpoint(websocket: WebSocket, episode_id: str):
    """Interactive episode protocol over a WebSocket.

    Clients send JSON messages with a "type" of "reset", "step" or "state".
    The connection is registered in `connections` and always removed on
    disconnect or error.

    NOTE(review): on "reset" the new environment is stored under the URL's
    episode_id, but env.reset() generates its own episode id inside the
    observation — the two can diverge; confirm which id clients should use.
    """
    await websocket.accept()
    connections[episode_id] = websocket

    if episode_id not in environments:
        await websocket.send_json({"error": "Episode not found"})
        await websocket.close()
        return

    try:
        while True:
            data = await websocket.receive_text()
            message = json.loads(data)
            # FIX: .get() avoids a KeyError (which would tear down the
            # connection) on messages missing the "type" field.
            msg_type = message.get("type")

            if msg_type == "reset":
                difficulty = message.get("difficulty", "medium")
                # FIX: mirror the /reset endpoint — an unknown difficulty
                # falls back to MEDIUM instead of raising ValueError and
                # killing the connection.
                try:
                    difficulty_enum = TaskDifficulty(difficulty)
                except ValueError:
                    difficulty_enum = TaskDifficulty.MEDIUM
                env = ContextFlowEnvironment(task_difficulty=difficulty_enum)
                observation = env.reset()
                environments[episode_id] = env
                await websocket.send_json({
                    "type": "reset",
                    "observation": observation.model_dump()
                })

            elif msg_type == "step":
                if episode_id not in environments:
                    await websocket.send_json({"error": "Episode not found"})
                    continue

                env = environments[episode_id]
                action = Action(**message["action"])
                result = env.step(action)

                if result.done:
                    del environments[episode_id]

                await websocket.send_json({
                    "type": "step",
                    "result": result.model_dump()
                })

            elif msg_type == "state":
                if episode_id not in environments:
                    await websocket.send_json({"error": "Episode not found"})
                    continue

                env = environments[episode_id]
                await websocket.send_json({
                    "type": "state",
                    # FIX: the environment's accessor is get_state(), not
                    # state() — the old call raised AttributeError.
                    "state": env.get_state().model_dump()
                })

    except WebSocketDisconnect:
        pass
    finally:
        if episode_id in connections:
            del connections[episode_id]
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
    # Local development entry point; in production run via
    # `uvicorn server.app:app` instead.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
|
server/contextflow_environment.py
ADDED
|
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Dict, Any, Optional, Tuple
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
from models import (
|
| 7 |
+
Observation, Action, Reward, State, StepResult,
|
| 8 |
+
TaskDifficulty, ActionType, GraderResult
|
| 9 |
+
)
|
| 10 |
+
from agents import (
|
| 11 |
+
ContextFlowAgent,
|
| 12 |
+
AgentPrediction,
|
| 13 |
+
ConfusionLevel,
|
| 14 |
+
InterventionType,
|
| 15 |
+
KnowledgeGraphAgent,
|
| 16 |
+
PeerLearningAgent,
|
| 17 |
+
RecallAgent,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ContextFlowEnvironment:
    """
    OpenEnv environment with full ContextFlow multi-agent system.

    Integrates:
    - RL-based doubt prediction
    - Multi-modal behavioral analysis
    - Gesture recognition
    - Knowledge graphs
    - Peer learning
    - Spaced repetition

    Lifecycle: reset() -> step() x N; the episode ends once the
    difficulty-specific "max_steps" is reached, and _grade_task() scores it.
    """

    # NOTE(review): unused — the effective per-episode limit comes from
    # _get_task_config()["max_steps"]; confirm before relying on this.
    MAX_STEPS = 100

    def __init__(self, task_difficulty: TaskDifficulty = TaskDifficulty.MEDIUM):
        self.task_difficulty = task_difficulty
        self.episode_id: Optional[str] = None
        self.step_count: int = 0
        self._state: Optional[State] = None
        self._last_observation: Optional[Observation] = None

        # Synthetic learner signal that drives rewards and grading.
        self._ground_truth_confusion: float = 0.0
        self._confusion_trajectory: list = []
        self._prediction_history: list = []
        self._intervention_history: list = []
        self._task_config = self._get_task_config()

        # Sub-agents of the ContextFlow system.
        self.agent = ContextFlowAgent()
        self.knowledge_graph = KnowledgeGraphAgent()
        self.peer_learning = PeerLearningAgent()
        self.recall_system = RecallAgent()

    def _get_task_config(self) -> Dict[str, Any]:
        """Return the tuning knobs for the configured difficulty (MEDIUM fallback)."""
        configs = {
            TaskDifficulty.EASY: {
                "prediction_window": 3,
                "noise_level": 0.1,
                "confusion_base": 0.3,
                "intervention_threshold": 0.6,
                "max_steps": 50,
                "confusion_spike_prob": 0.08,
            },
            TaskDifficulty.MEDIUM: {
                "prediction_window": 5,
                "noise_level": 0.2,
                "confusion_base": 0.5,
                "intervention_threshold": 0.5,
                "max_steps": 75,
                "confusion_spike_prob": 0.12,
            },
            TaskDifficulty.HARD: {
                "prediction_window": 7,
                "noise_level": 0.3,
                "confusion_base": 0.6,
                "intervention_threshold": 0.4,
                "max_steps": 100,
                "confusion_spike_prob": 0.15,
            },
        }
        return configs.get(self.task_difficulty, configs[TaskDifficulty.MEDIUM])

    def _generate_synthetic_data(self) -> Tuple[Observation, float]:
        """Synthesize one step of multimodal learner data.

        Updates self._ground_truth_confusion (base + sinusoidal trend +
        Gaussian noise + occasional spike), appends it to the trajectory,
        and returns the observation together with that ground-truth value.
        """
        step = self.step_count
        config = self._task_config

        base_confusion = config["confusion_base"]
        noise = np.random.normal(0, config["noise_level"])

        confusion_trend = np.sin(step * 0.1) * 0.2
        # NOTE(review): the spike *magnitude* equals the spike probability
        # (0.08–0.15), so spikes are tiny — confirm this is intended.
        confusion_spike = config["confusion_spike_prob"] if np.random.random() < config["confusion_spike_prob"] else 0.0

        self._ground_truth_confusion = np.clip(
            base_confusion + confusion_trend + noise + confusion_spike,
            0.0, 1.0
        )
        self._confusion_trajectory.append(self._ground_truth_confusion)

        gaze_features = np.random.randn(16).tolist()
        gesture_features = np.random.randn(63).tolist()
        # Six synthetic biometric channels -- presumably HR, HRV, temperature,
        # respiration, GSR-like and attention-like signals; TODO confirm.
        biometric_features = [
            60 + np.random.randn() * 10,
            0.5 + np.random.randn() * 0.1,
            36.6 + np.random.randn() * 0.5,
            15 + np.random.randn() * 3,
            0.3 + np.random.randn() * 0.1,
            0.7 + np.random.randn() * 0.1,
        ]
        audio_features = [200 + np.random.randn() * 50, 0.3 + np.random.randn() * 0.1]
        behavioral_features = np.random.randn(8).tolist()

        # Leak a correlated signal into the first two behavioral features so
        # the ground truth is partially recoverable by the agent.
        behavioral_features[0] = self._ground_truth_confusion * 0.5
        behavioral_features[1] = self._ground_truth_confusion * 0.3

        learning_context = {
            "topic": np.random.choice(["math", "science", "programming", "language"]),
            "difficulty": self.task_difficulty.value,
            "time_spent": step * 30,
            "content_length": np.random.randint(100, 1000),
            "subtopic": np.random.choice(["basics", "intermediate", "advanced"]),
        }

        learner_state = {
            "engagement": 1.0 - self._ground_truth_confusion,
            "frustration": self._ground_truth_confusion * 0.8,
            "comprehension": 0.7 - self._ground_truth_confusion * 0.3,
            "confusion_level": self._get_confusion_level(self._ground_truth_confusion).value,
        }

        observation = Observation(
            step=self.step_count,
            episode_id=self.episode_id,
            learning_context=learning_context,
            learner_state=learner_state,
            gaze_features=gaze_features,
            gesture_features=gesture_features,
            biometric_features=biometric_features,
            audio_features=audio_features,
            behavioral_features=behavioral_features,
            confusion_history=self._confusion_trajectory[-10:],
            prediction_window=config["prediction_window"],
            available_interventions=[
                "hint", "simplify", "breakdown", "example", "scaffold",
                "peer_connect", "break", "encourage"
            ],
            multimodal_fused=True,
            metadata={
                "knowledge_graph_mastery": self.knowledge_graph.get_prerequisite_mastery(
                    learning_context["topic"]
                ),
                "similar_learners": len(self.peer_learning.find_similar_learners(
                    learning_context["topic"]
                )),
                "recall_cards": len(self.recall_system.cards),
            }
        )

        return observation, self._ground_truth_confusion

    def _get_confusion_level(self, prob: float) -> ConfusionLevel:
        """Bucket a confusion probability into a ConfusionLevel.

        FIX: dropped the redundant function-local `from agents import
        ConfusionLevel` — the name is already imported at module level.
        """
        if prob < 0.25:
            return ConfusionLevel.LOW
        elif prob < 0.5:
            return ConfusionLevel.MEDIUM
        elif prob < 0.75:
            return ConfusionLevel.HIGH
        else:
            return ConfusionLevel.CRITICAL

    def reset(self) -> Observation:
        """Start a new episode: fresh id, histories, agent and State."""
        self.episode_id = str(uuid.uuid4())
        self.step_count = 0
        self._confusion_trajectory = []
        self._prediction_history = []
        self._intervention_history = []
        self._ground_truth_confusion = 0.0

        # A fresh RL agent per episode; the other sub-agents persist.
        self.agent = ContextFlowAgent()

        observation, _ = self._generate_synthetic_data()
        self._last_observation = observation

        self._state = State(
            episode_id=self.episode_id,
            step_count=self.step_count,
            max_steps=self._task_config["max_steps"],
            task_difficulty=self.task_difficulty,
            ground_truth_confusion=self._ground_truth_confusion,
            predictions_history=[],
            interventions_history=[],
            episode_reward=0.0,
            task_complete=False,
            task_success=False,
        )

        return observation

    def step(self, action: Action) -> StepResult:
        """Apply one agent action, advance the simulation, and score it.

        Raises:
            RuntimeError: if called before reset().
        """
        if self._state is None:
            raise RuntimeError("Must call reset() before step()")

        if self._state.task_complete:
            return StepResult(
                observation=self._create_current_observation(),
                reward=Reward(total=0.0),
                done=True,
                info={"message": "Episode already complete"}
            )

        reward = self._calculate_reward(action)

        prediction_record = {
            "step": self.step_count,
            "predicted": action.predicted_confusion,
            "ground_truth": self._ground_truth_confusion,
            "action_type": action.action_type.value,
            "confusion_level": self._get_confusion_level(action.predicted_confusion or 0.5).value,
        }
        self._state.predictions_history.append(prediction_record)
        # FIX: also record on the instance list — _grade_task() and the
        # episode summary read self._prediction_history, which was never
        # populated before, so grading always reported "No predictions made".
        self._prediction_history.append(prediction_record)

        if action.action_type == ActionType.TRIGGER_INTERVENTION and action.intervention_type:
            self._intervention_history.append({
                "step": self.step_count,
                "type": action.intervention_type,
                "intensity": action.intervention_intensity or 0.5,
                "effectiveness": 0.0,  # back-filled once the next sample arrives
            })

            if action.intervention_type == "peer_connect":
                topic = self._last_observation.learning_context.get("topic", "general") if self._last_observation else "general"
                peers = self.peer_learning.find_similar_learners(topic)
                reward.total += 0.1 * min(len(peers), 3)

        # FIX: accumulate episode_reward only after the peer_connect bonus
        # has been folded in; previously the bonus was dropped from the total.
        self._state.episode_reward += reward.total

        self.agent.update(reward.total, self._last_observation.learning_context if self._last_observation else {})

        self.step_count += 1
        self._state.step_count = self.step_count

        observation, new_gt = self._generate_synthetic_data()
        self._last_observation = observation

        self._state.ground_truth_confusion = new_gt
        self._state.interventions_history = self._intervention_history.copy()

        # Crude post-hoc effectiveness for the most recent intervention:
        # did confusion drop relative to two samples before the newest one?
        if len(self._intervention_history) > 0:
            last_idx = len(self._intervention_history) - 1
            if len(self._confusion_trajectory) >= 3:
                prev_confusion = self._confusion_trajectory[-3]
                if new_gt < prev_confusion:
                    self._intervention_history[last_idx]["effectiveness"] = 0.8
                else:
                    self._intervention_history[last_idx]["effectiveness"] = 0.3

        if self.step_count >= self._task_config["max_steps"]:
            self._state.task_complete = True
            self._state.task_success = self._grade_task().passed

        done = self._state.task_complete

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
            info={
                "grader_result": self._grade_task() if done else None,
                "episode_summary": {
                    "total_reward": self._state.episode_reward,
                    "predictions_made": len(self._prediction_history),
                    "interventions_triggered": len(self._intervention_history),
                    "knowledge_graph_active": True,
                    "peer_learning_active": True,
                    "recall_system_active": True,
                },
                "agent_state": {
                    "epsilon": self.agent.epsilon,
                    "recent_avg_reward": np.mean(self.agent.episode_rewards[-10:]) if self.agent.episode_rewards else 0.0,
                }
            }
        )

    def _calculate_reward(self, action: Action) -> Reward:
        """Score one action against the current ground-truth confusion.

        Components (weighted into `total`): prediction accuracy (0.4),
        early detection (0.2), intervention timing (0.2), partial progress
        (0.1), plus a penalty for over-intense interventions.
        """
        gt = self._ground_truth_confusion
        # A missing prediction is treated as the uninformative midpoint.
        pred = action.predicted_confusion if action.predicted_confusion is not None else 0.5

        prediction_error = abs(pred - gt)
        confusion_reward = 1.0 - prediction_error

        early_detection = 0.0
        if len(self._confusion_trajectory) > 1:
            prev_confusion = self._confusion_trajectory[-2]
            if gt > prev_confusion and pred > prev_confusion:
                early_detection = 0.2
                if gt > 0.6 and pred > 0.6:
                    early_detection = 0.3

        intervention_reward = 0.0
        if action.action_type == ActionType.TRIGGER_INTERVENTION:
            if gt > self._task_config["intervention_threshold"]:
                intervention_reward = 0.3
            elif gt < 0.3:
                # Penalize interrupting a learner who is not confused.
                intervention_reward = -0.1

        partial_progress = 0.0
        if len(self._confusion_trajectory) >= 5:
            recent_trend = np.mean(self._confusion_trajectory[-5:])
            if recent_trend < 0.4:
                partial_progress = 0.1

        penalty = 0.0
        if action.intervention_intensity and action.intervention_intensity > 0.9:
            penalty = -0.2

        total = confusion_reward * 0.4 + early_detection * 0.2 + intervention_reward * 0.2 + partial_progress * 0.1 + penalty

        return Reward(
            total=total,
            confusion_prediction_reward=confusion_reward * 0.4,
            early_detection_reward=early_detection,
            intervention_reward=intervention_reward,
            partial_progress_reward=partial_progress,
            penalty=penalty,
            metadata={
                "prediction_error": prediction_error,
                "ground_truth": gt,
                "predicted": pred,
            }
        )

    def _grade_task(self) -> GraderResult:
        """Score the episode on prediction MAE, early detection rate and
        intervention effectiveness; pass/fail vs the difficulty threshold."""
        if not self._prediction_history:
            return GraderResult(
                score=0.0,
                feedback="No predictions made",
                metrics={},
                passed=False
            )

        predictions = self._state.predictions_history
        gt_trajectory = self._confusion_trajectory[:len(predictions)]

        errors = [
            abs(p["predicted"] - gt)
            for p, gt in zip(predictions, gt_trajectory)
            if p["predicted"] is not None
        ]
        # FIX: np.mean([]) is nan and would poison the final score when no
        # usable predictions exist; treat that case as worst-case error.
        mae = float(np.mean(errors)) if errors else 1.0

        confusion_threshold = 0.6
        early_detections = 0
        total_spikes = 0

        for i in range(1, len(gt_trajectory)):
            if gt_trajectory[i] > confusion_threshold:
                total_spikes += 1
                # FIX: skip None predictions — `None > float` raises TypeError.
                if (
                    i < len(predictions)
                    and predictions[i]["predicted"] is not None
                    and predictions[i]["predicted"] > confusion_threshold
                ):
                    if predictions[i]["confusion_level"] in ["high", "critical"]:
                        early_detections += 1

        early_detection_rate = early_detections / max(total_spikes, 1)

        intervention_effectiveness = 0.0
        if self._intervention_history:
            effective_interventions = sum(1 for i in self._intervention_history if i.get("effectiveness", 0) > 0.5)
            intervention_effectiveness = effective_interventions / len(self._intervention_history)

        score = (1 - mae) * 0.4 + early_detection_rate * 0.3 + intervention_effectiveness * 0.3

        feedback_parts = []
        feedback_parts.append(f"MAE: {mae:.3f}")
        feedback_parts.append(f"Early Detection: {early_detection_rate:.1%}")
        feedback_parts.append(f"Intervention Effect: {intervention_effectiveness:.1%}")
        feedback_parts.append(f"Predictions: {len(predictions)}")
        feedback_parts.append(f"Interventions: {len(self._intervention_history)}")

        passed = score >= self._get_passing_threshold()

        return GraderResult(
            score=score,
            feedback=" | ".join(feedback_parts),
            metrics={
                "mae": float(mae),
                "early_detection_rate": float(early_detection_rate),
                "intervention_effectiveness": float(intervention_effectiveness),
                "total_predictions": len(predictions),
                "total_interventions": len(self._intervention_history),
            },
            passed=passed
        )

    def _get_passing_threshold(self) -> float:
        """Minimum grading score counted as success, per difficulty."""
        thresholds = {
            TaskDifficulty.EASY: 0.5,
            TaskDifficulty.MEDIUM: 0.6,
            TaskDifficulty.HARD: 0.7,
        }
        return thresholds.get(self.task_difficulty, 0.6)

    def _create_current_observation(self) -> Observation:
        """Build a terminal/placeholder observation for a finished episode."""
        return Observation(
            step=self.step_count,
            episode_id=self.episode_id,
            learning_context={"topic": "completed"},
            learner_state={"engagement": 0.0},
            gaze_features=[],
            gesture_features=[],
            biometric_features=[],
            audio_features=[],
            behavioral_features=[],
            confusion_history=self._confusion_trajectory,
            prediction_window=self._task_config["prediction_window"],
            available_interventions=[],
            multimodal_fused=True,
        )

    def get_state(self) -> State:
        """Return the live State object.

        Raises:
            RuntimeError: if called before reset().
        """
        if self._state is None:
            raise RuntimeError("Must call reset() before get_state()")
        return self._state

    def get_agent_prediction(self) -> AgentPrediction:
        """Run the RL agent on the most recent observation's features."""
        obs = self._last_observation
        obs_dict = {
            "gaze_features": obs.gaze_features if obs else [],
            "gesture_features": obs.gesture_features if obs else [],
            "biometric_features": obs.biometric_features if obs else [],
            "behavioral_features": obs.behavioral_features if obs else [],
            "audio_features": obs.audio_features if obs else [],
            "learning_context": {"difficulty": self.task_difficulty.value},
        }
        return self.agent.predict(obs_dict)

    def get_grader(self, difficulty: Optional[TaskDifficulty] = None) -> callable:
        """Return a standalone grading closure for the given difficulty.

        The closure rebuilds a throwaway environment from raw trajectories
        and runs _grade_task() on it.
        """
        difficulty = difficulty or self.task_difficulty

        def grade(predictions: list, ground_truth: list, interventions: list) -> GraderResult:
            # FIX: removed the no-op `nonlocal difficulty` — the closure only
            # reads the enclosing variable, it never rebinds it.
            temp_env = ContextFlowEnvironment(task_difficulty=difficulty)
            temp_env._confusion_trajectory = ground_truth.copy()
            temp_env._prediction_history = predictions
            temp_env._intervention_history = interventions
            temp_env._state = State(
                episode_id="grader",
                step_count=len(predictions),
                max_steps=temp_env._task_config["max_steps"],
                task_difficulty=difficulty,
                predictions_history=[
                    {"step": i, "predicted": p, "ground_truth": gt, "action_type": "predict", "confusion_level": "medium"}
                    for i, (p, gt) in enumerate(zip(predictions, ground_truth))
                ],
                interventions_history=interventions,
            )
            return temp_env._grade_task()

        return grade
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def easy_grader(predictions: list, ground_truth: list, interventions: list) -> GraderResult:
    """Grade an episode trace against EASY difficulty thresholds.

    FIX: this was a verbatim copy of the closure built by
    ContextFlowEnvironment.get_grader(); delegate to it so the scoring
    setup lives in exactly one place and cannot silently diverge.
    """
    env = ContextFlowEnvironment(task_difficulty=TaskDifficulty.EASY)
    return env.get_grader()(predictions, ground_truth, interventions)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def medium_grader(predictions: list, ground_truth: list, interventions: list) -> GraderResult:
    """Grade an episode trace against MEDIUM difficulty thresholds.

    FIX: this was a verbatim copy of the closure built by
    ContextFlowEnvironment.get_grader(); delegate to it so the scoring
    setup lives in exactly one place and cannot silently diverge.
    """
    env = ContextFlowEnvironment(task_difficulty=TaskDifficulty.MEDIUM)
    return env.get_grader()(predictions, ground_truth, interventions)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def hard_grader(predictions: list, ground_truth: list, interventions: list) -> GraderResult:
    """Grade an episode trace against HARD difficulty thresholds.

    FIX: this was a verbatim copy of the closure built by
    ContextFlowEnvironment.get_grader(); delegate to it so the scoring
    setup lives in exactly one place and cannot silently diverge.
    """
    env = ContextFlowEnvironment(task_difficulty=TaskDifficulty.HARD)
    return env.get_grader()(predictions, ground_truth, interventions)
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
# Public API of this module.
__all__ = [
    "ContextFlowEnvironment",
    "easy_grader",
    "medium_grader",
    "hard_grader",
]
|