namish10 commited on
Commit
e86ba0b
·
verified ·
1 Parent(s): 788411f

Upload models.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models.py +195 -0
models.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Optional, List, Dict, Any
3
+ from enum import Enum
4
+
5
+
6
+ class TaskDifficulty(str, Enum):
7
+ EASY = "easy"
8
+ MEDIUM = "medium"
9
+ HARD = "hard"
10
+
11
+
12
+ class ActionType(str, Enum):
13
+ PREDICT_CONFUSION = "predict_confusion"
14
+ ANALYZE_BEHAVIOR = "analyze_behavior"
15
+ TRIGGER_INTERVENTION = "trigger_intervention"
16
+ CLASSIFY_DIFFICULTY = "classify_difficulty"
17
+ FUSE_MODALITIES = "fuse_modalities"
18
+
19
+
20
+ class Observation(BaseModel):
21
+ step: int = Field(..., description="Current step in the episode")
22
+ episode_id: str = Field(..., description="Unique episode identifier")
23
+
24
+ learning_context: Dict[str, Any] = Field(
25
+ default_factory=dict,
26
+ description="Current learning context (topic, difficulty, time spent)"
27
+ )
28
+
29
+ learner_state: Dict[str, Any] = Field(
30
+ default_factory=dict,
31
+ description="Learner state signals from all modalities"
32
+ )
33
+
34
+ gaze_features: List[float] = Field(
35
+ default_factory=list,
36
+ description="Gaze tracking features (16 dimensions)"
37
+ )
38
+
39
+ gesture_features: List[float] = Field(
40
+ default_factory=list,
41
+ description="Hand gesture features (21 landmarks x 3 coords)"
42
+ )
43
+
44
+ biometric_features: List[float] = Field(
45
+ default_factory=list,
46
+ description="Biometric features (heart rate, GSR, etc.)"
47
+ )
48
+
49
+ audio_features: List[float] = Field(
50
+ default_factory=list,
51
+ description="Audio features (pitch, tone, pauses)"
52
+ )
53
+
54
+ behavioral_features: List[float] = Field(
55
+ default_factory=list,
56
+ description="Behavioral features (scroll speed, clicks, typing)"
57
+ )
58
+
59
+ confusion_history: List[float] = Field(
60
+ default_factory=list,
61
+ description="Historical confusion probabilities"
62
+ )
63
+
64
+ prediction_window: int = Field(
65
+ default=5,
66
+ description="Steps ahead to predict confusion"
67
+ )
68
+
69
+ available_interventions: List[str] = Field(
70
+ default_factory=list,
71
+ description="Available intervention types"
72
+ )
73
+
74
+ multimodal_fused: bool = Field(
75
+ default=False,
76
+ description="Whether multi-modal fusion is enabled"
77
+ )
78
+
79
+ metadata: Dict[str, Any] = Field(
80
+ default_factory=dict,
81
+ description="Additional metadata"
82
+ )
83
+
84
+
85
+ class Action(BaseModel):
86
+ action_type: ActionType = Field(..., description="Type of action to take")
87
+
88
+ predicted_confusion: Optional[float] = Field(
89
+ None,
90
+ description="Predicted confusion probability (0.0-1.0)",
91
+ ge=0.0,
92
+ le=1.0
93
+ )
94
+
95
+ intervention_type: Optional[str] = Field(
96
+ None,
97
+ description="Intervention to trigger (if action_type is trigger_intervention)"
98
+ )
99
+
100
+ intervention_intensity: Optional[float] = Field(
101
+ None,
102
+ description="Intervention intensity (0.0-1.0)",
103
+ ge=0.0,
104
+ le=1.0
105
+ )
106
+
107
+ difficulty_prediction: Optional[TaskDifficulty] = Field(
108
+ None,
109
+ description="Predicted task difficulty (if action_type is classify_difficulty)"
110
+ )
111
+
112
+ modality_weights: Optional[Dict[str, float]] = Field(
113
+ None,
114
+ description="Weights for multi-modal fusion",
115
+ ge=0.0,
116
+ le=1.0
117
+ )
118
+
119
+ reasoning: Optional[str] = Field(
120
+ None,
121
+ description="Agent's reasoning for the action"
122
+ )
123
+
124
+
125
+ class Reward(BaseModel):
126
+ total: float = Field(..., description="Total reward for this step")
127
+
128
+ confusion_prediction_reward: float = Field(
129
+ default=0.0,
130
+ description="Reward for confusion prediction accuracy"
131
+ )
132
+
133
+ early_detection_reward: float = Field(
134
+ default=0.0,
135
+ description="Reward for early confusion detection"
136
+ )
137
+
138
+ intervention_reward: float = Field(
139
+ default=0.0,
140
+ description="Reward for effective intervention"
141
+ )
142
+
143
+ partial_progress_reward: float = Field(
144
+ default=0.0,
145
+ description="Reward for partial progress toward goals"
146
+ )
147
+
148
+ penalty: float = Field(
149
+ default=0.0,
150
+ description="Penalty for negative behaviors"
151
+ )
152
+
153
+ metadata: Dict[str, Any] = Field(
154
+ default_factory=dict,
155
+ description="Additional reward metadata"
156
+ )
157
+
158
+
159
+ class State(BaseModel):
160
+ episode_id: str = Field(..., description="Unique episode identifier")
161
+ step_count: int = Field(default=0, description="Number of steps taken")
162
+ max_steps: int = Field(default=100, description="Maximum steps per episode")
163
+ task_difficulty: TaskDifficulty = Field(default=TaskDifficulty.MEDIUM)
164
+ ground_truth_confusion: Optional[float] = Field(None, description="Actual confusion level")
165
+ predictions_history: List[Dict[str, Any]] = Field(default_factory=list)
166
+ interventions_history: List[Dict[str, Any]] = Field(default_factory=list)
167
+ episode_reward: float = Field(default=0.0)
168
+ task_complete: bool = Field(default=False)
169
+ task_success: bool = Field(False)
170
+
171
+
172
+ class StepResult(BaseModel):
173
+ observation: Observation
174
+ reward: Reward
175
+ done: bool
176
+ info: Dict[str, Any] = Field(default_factory=dict)
177
+
178
+
179
+ class GraderResult(BaseModel):
180
+ score: float = Field(..., ge=0.0, le=1.0, description="Grader score (0.0-1.0)")
181
+ feedback: str = Field(..., description="Feedback on performance")
182
+ metrics: Dict[str, float] = Field(default_factory=dict)
183
+ passed: bool = Field(..., description="Whether task passed")
184
+
185
+
186
+ __all__ = [
187
+ "Observation",
188
+ "Action",
189
+ "Reward",
190
+ "State",
191
+ "StepResult",
192
+ "GraderResult",
193
+ "TaskDifficulty",
194
+ "ActionType",
195
+ ]