namish10 committed on
Commit
bdc2b78
·
verified ·
1 Parent(s): 357af64

Upload inference_example.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference_example.py +179 -0
inference_example.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ContextFlow RL Model Inference Example
3
+
4
+ This script demonstrates how to load the trained checkpoint and make predictions.
5
+ """
6
+
7
+ import pickle
8
+ import numpy as np
9
+ import sys
10
+ import os
11
+
12
+ # Add current directory to path
13
+ sys.path.insert(0, os.path.dirname(__file__))
14
+
15
+ from feature_extractor import FeatureExtractor
16
+
17
+
18
# Doubt action labels (10 actions).
# predict() maps Q-network output index i to DOUBT_ACTIONS[i], so the order of
# this list must stay aligned with the network's output layer.
DOUBT_ACTIONS = [
    "what_is_backpropagation",
    "why_gradient_descent",
    "how_overfitting_works",
    "explain_regularization",
    "what_loss_function",
    "how_optimization_works",
    "explain_learning_rate",
    "what_regularization",
    "how_batch_norm_works",
    "explain_softmax",
]


class DoubtPredictor:
    """Simple doubt predictor using the trained Q-network.

    The checkpoint is an unpickled object that must expose the attributes
    ``policy_version``, ``training_stats`` (a mapping) and
    ``q_network_weights`` (a mapping of layer name to weight array).
    """

    def __init__(self, checkpoint_path: str):
        """Create the feature extractor and load the checkpoint.

        Args:
            checkpoint_path: Path of the pickled checkpoint file.
        """
        self.extractor = FeatureExtractor()

        # SECURITY NOTE: pickle.load can execute arbitrary code embedded in
        # the file. Only load checkpoints obtained from a source you trust.
        with open(checkpoint_path, 'rb') as f:
            self.checkpoint = pickle.load(f)

        print(f"Loaded checkpoint v{self.checkpoint.policy_version}")
        print(f"Training samples: {self.checkpoint.training_stats.get('total_samples', 'N/A')}")

    def extract_state(self, **kwargs) -> np.ndarray:
        """Extract state vector from input features.

        All keyword arguments are forwarded unchanged to
        ``FeatureExtractor.extract_state``.
        """
        return self.extractor.extract_state(**kwargs)

    def predict(self, state: np.ndarray, top_k: int = 3) -> dict:
        """Predict doubt actions from state.

        Args:
            state: Feature vector fed to the first layer. A single-row 2-D
                array is accepted as well; the resulting Q-values are
                flattened before ranking.
            top_k: Number of highest-scoring actions to report (default 3).

        Returns:
            dict with keys ``predicted_doubt`` (best action label),
            ``confidence`` (the raw Q-value of that action, not a
            probability) and ``top_predictions`` (list of
            ``{'action', 'q_value'}`` dicts, best first).

        Raises:
            ValueError: If ``top_k`` is smaller than 1, or the network yields
                more Q-values than there are labels in ``DOUBT_ACTIONS``.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")

        q_weights = self.checkpoint.q_network_weights

        if 'layer1.weight' in q_weights:
            # np.asarray lets the weights be stored as arrays, nested lists or
            # any other array-like without breaking the `.T` below.
            w1 = np.asarray(q_weights['layer1.weight'])
            b1 = np.asarray(q_weights['layer1.bias'])
            w2 = np.asarray(q_weights['layer2.weight'])
            b2 = np.asarray(q_weights['layer2.bias'])
            w3 = np.asarray(q_weights['output.weight'])
            b3 = np.asarray(q_weights['output.bias'])

            # Forward pass: two ReLU hidden layers followed by a linear output.
            h1 = np.maximum(np.dot(state, w1.T) + b1, 0)
            h2 = np.maximum(np.dot(h1, w2.T) + b2, 0)
            q_values = np.dot(h2, w3.T) + b3
        else:
            # Fallback: random predictions. Warn so that noise is never
            # mistaken for real model output.
            import warnings

            warnings.warn(
                "Checkpoint has no 'layer1.weight'; returning random Q-values",
                RuntimeWarning,
                stacklevel=2,
            )
            q_values = np.random.randn(len(DOUBT_ACTIONS)) * 0.5

        # Flatten so that a (1, n_actions) result from a single-row batch is
        # ranked the same way as a plain 1-D result.
        q_values = np.asarray(q_values, dtype=float).reshape(-1)
        if q_values.size > len(DOUBT_ACTIONS):
            raise ValueError(
                f"Got {q_values.size} Q-values but only "
                f"{len(DOUBT_ACTIONS)} doubt actions are defined"
            )

        # Indices of the top_k largest Q-values, best first.
        top_indices = [int(i) for i in np.argsort(q_values)[::-1][:top_k]]
        best = top_indices[0]

        return {
            'predicted_doubt': DOUBT_ACTIONS[best],
            'confidence': float(q_values[best]),
            'top_predictions': [
                {
                    'action': DOUBT_ACTIONS[i],
                    'q_value': float(q_values[i]),
                }
                for i in top_indices
            ],
        }
91
+
92
+
93
def example_inference(checkpoint_path: str = 'checkpoint.pkl'):
    """Run example inferences.

    Loads the checkpoint, then runs three hand-written learner scenarios
    through the predictor and prints the top predicted doubt and its Q-value
    for each one.

    Args:
        checkpoint_path: Location of the pickled checkpoint. Defaults to
            ``'checkpoint.pkl'`` in the current working directory.

    Returns:
        None. If the checkpoint file does not exist, a download hint is
        printed and the function returns without running any inference.
    """
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint not found: {checkpoint_path}")
        print("Download from: https://huggingface.co/namish10/contextflow-rl")
        return

    predictor = DoubtPredictor(checkpoint_path)

    # Each entry: (heading printed for the scenario, keyword arguments passed
    # to predictor.extract_state).
    scenarios = [
        (
            "Scenario 1: Beginner ML student",
            {
                'topic': "neural networks",
                'progress': 0.3,
                'confusion_signals': {
                    'mouse_hesitation': 3.0,
                    'scroll_reversals': 6,
                    'time_on_page': 45,
                    'back_button': 3,
                    'copy_attempts': 1,
                },
                'gesture_signals': {
                    'pinch': 2,
                    'point': 5,
                },
                'time_spent': 120,
            },
        ),
        (
            "Scenario 2: Advanced learner, high confusion signals",
            {
                'topic': "deep learning",
                'progress': 0.7,
                'confusion_signals': {
                    'mouse_hesitation': 4.5,
                    'scroll_reversals': 8,
                    'time_on_page': 280,
                    'back_button': 5,
                    'copy_attempts': 2,
                    'search_usage': 3,
                },
                'gesture_signals': {
                    'pinch': 8,
                    'swipe_left': 4,
                    'point': 10,
                },
                'time_spent': 600,
            },
        ),
        (
            "Scenario 3: Quick learner, low confusion",
            {
                'topic': "python programming",
                'progress': 0.9,
                'confusion_signals': {
                    'mouse_hesitation': 0.5,
                    'scroll_reversals': 1,
                    'time_on_page': 20,
                    'back_button': 0,
                },
                'gesture_signals': {
                    'swipe_down': 5,
                    'point': 3,
                },
                'time_spent': 60,
            },
        ),
    ]

    banner = "=" * 60
    print("\n" + banner)
    print("EXAMPLE INFERENCES")
    print(banner)

    for title, features in scenarios:
        print(f"\n[{title}]")
        state = predictor.extract_state(**features)
        result = predictor.predict(state)
        print(f" Predicted doubt: {result['predicted_doubt']}")
        print(f" Q-value: {result['confidence']:.4f}")

    print("\n" + banner)
176
+
177
+
178
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    example_inference()