Upload inference_example.py with huggingface_hub
Browse files- inference_example.py +179 -0
inference_example.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ContextFlow RL Model Inference Example
|
| 3 |
+
|
| 4 |
+
This script demonstrates how to load the trained checkpoint and make predictions.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pickle
|
| 8 |
+
import numpy as np
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
# Add current directory to path
|
| 13 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 14 |
+
|
| 15 |
+
from feature_extractor import FeatureExtractor
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Doubt action labels (10 actions)
|
| 19 |
+
DOUBT_ACTIONS = [
|
| 20 |
+
"what_is_backpropagation",
|
| 21 |
+
"why_gradient_descent",
|
| 22 |
+
"how_overfitting_works",
|
| 23 |
+
"explain_regularization",
|
| 24 |
+
"what_loss_function",
|
| 25 |
+
"how_optimization_works",
|
| 26 |
+
"explain_learning_rate",
|
| 27 |
+
"what_regularization",
|
| 28 |
+
"how_batch_norm_works",
|
| 29 |
+
"explain_softmax"
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class DoubtPredictor:
|
| 34 |
+
"""Simple doubt predictor using the trained Q-network"""
|
| 35 |
+
|
| 36 |
+
def __init__(self, checkpoint_path: str):
|
| 37 |
+
self.extractor = FeatureExtractor()
|
| 38 |
+
|
| 39 |
+
# Load checkpoint
|
| 40 |
+
with open(checkpoint_path, 'rb') as f:
|
| 41 |
+
self.checkpoint = pickle.load(f)
|
| 42 |
+
|
| 43 |
+
print(f"Loaded checkpoint v{self.checkpoint.policy_version}")
|
| 44 |
+
print(f"Training samples: {self.checkpoint.training_stats.get('total_samples', 'N/A')}")
|
| 45 |
+
|
| 46 |
+
def extract_state(self, **kwargs) -> np.ndarray:
|
| 47 |
+
"""Extract state vector from input features"""
|
| 48 |
+
return self.extractor.extract_state(**kwargs)
|
| 49 |
+
|
| 50 |
+
def predict(self, state: np.ndarray) -> dict:
|
| 51 |
+
"""
|
| 52 |
+
Predict doubt actions from state
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
dict with predicted actions and Q-values
|
| 56 |
+
"""
|
| 57 |
+
# Simple linear approximation since we have Q-network weights
|
| 58 |
+
q_weights = self.checkpoint.q_network_weights
|
| 59 |
+
|
| 60 |
+
# Extract key weights (simplified)
|
| 61 |
+
if 'layer1.weight' in q_weights:
|
| 62 |
+
w1 = q_weights['layer1.weight']
|
| 63 |
+
b1 = q_weights['layer1.bias']
|
| 64 |
+
w2 = q_weights['layer2.weight']
|
| 65 |
+
b2 = q_weights['layer2.bias']
|
| 66 |
+
w3 = q_weights['output.weight']
|
| 67 |
+
b3 = q_weights['output.bias']
|
| 68 |
+
|
| 69 |
+
# Forward pass
|
| 70 |
+
h1 = np.maximum(np.dot(state, w1.T) + b1, 0) # ReLU
|
| 71 |
+
h2 = np.maximum(np.dot(h1, w2.T) + b2, 0) # ReLU
|
| 72 |
+
q_values = np.dot(h2, w3.T) + b3
|
| 73 |
+
else:
|
| 74 |
+
# Fallback: random predictions
|
| 75 |
+
q_values = np.random.randn(10) * 0.5
|
| 76 |
+
|
| 77 |
+
# Get top 3 predictions
|
| 78 |
+
top_indices = np.argsort(q_values)[::-1][:3]
|
| 79 |
+
|
| 80 |
+
return {
|
| 81 |
+
'predicted_doubt': DOUBT_ACTIONS[top_indices[0]],
|
| 82 |
+
'confidence': float(q_values[top_indices[0]]),
|
| 83 |
+
'top_predictions': [
|
| 84 |
+
{
|
| 85 |
+
'action': DOUBT_ACTIONS[i],
|
| 86 |
+
'q_value': float(q_values[i])
|
| 87 |
+
}
|
| 88 |
+
for i in top_indices
|
| 89 |
+
]
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def example_inference():
|
| 94 |
+
"""Run example inferences"""
|
| 95 |
+
checkpoint_path = 'checkpoint.pkl'
|
| 96 |
+
|
| 97 |
+
if not os.path.exists(checkpoint_path):
|
| 98 |
+
print(f"Checkpoint not found: {checkpoint_path}")
|
| 99 |
+
print("Download from: https://huggingface.co/namish10/contextflow-rl")
|
| 100 |
+
return
|
| 101 |
+
|
| 102 |
+
predictor = DoubtPredictor(checkpoint_path)
|
| 103 |
+
|
| 104 |
+
print("\n" + "="*60)
|
| 105 |
+
print("EXAMPLE INFERENCES")
|
| 106 |
+
print("="*60)
|
| 107 |
+
|
| 108 |
+
# Example 1: Beginner ML student
|
| 109 |
+
print("\n[Scenario 1: Beginner ML student]")
|
| 110 |
+
state1 = predictor.extract_state(
|
| 111 |
+
topic="neural networks",
|
| 112 |
+
progress=0.3,
|
| 113 |
+
confusion_signals={
|
| 114 |
+
'mouse_hesitation': 3.0,
|
| 115 |
+
'scroll_reversals': 6,
|
| 116 |
+
'time_on_page': 45,
|
| 117 |
+
'back_button': 3,
|
| 118 |
+
'copy_attempts': 1
|
| 119 |
+
},
|
| 120 |
+
gesture_signals={
|
| 121 |
+
'pinch': 2,
|
| 122 |
+
'point': 5
|
| 123 |
+
},
|
| 124 |
+
time_spent=120
|
| 125 |
+
)
|
| 126 |
+
result1 = predictor.predict(state1)
|
| 127 |
+
print(f" Predicted doubt: {result1['predicted_doubt']}")
|
| 128 |
+
print(f" Q-value: {result1['confidence']:.4f}")
|
| 129 |
+
|
| 130 |
+
# Example 2: Advanced learner struggling with regularization
|
| 131 |
+
print("\n[Scenario 2: Advanced learner, high confusion signals]")
|
| 132 |
+
state2 = predictor.extract_state(
|
| 133 |
+
topic="deep learning",
|
| 134 |
+
progress=0.7,
|
| 135 |
+
confusion_signals={
|
| 136 |
+
'mouse_hesitation': 4.5,
|
| 137 |
+
'scroll_reversals': 8,
|
| 138 |
+
'time_on_page': 280,
|
| 139 |
+
'back_button': 5,
|
| 140 |
+
'copy_attempts': 2,
|
| 141 |
+
'search_usage': 3
|
| 142 |
+
},
|
| 143 |
+
gesture_signals={
|
| 144 |
+
'pinch': 8,
|
| 145 |
+
'swipe_left': 4,
|
| 146 |
+
'point': 10
|
| 147 |
+
},
|
| 148 |
+
time_spent=600
|
| 149 |
+
)
|
| 150 |
+
result2 = predictor.predict(state2)
|
| 151 |
+
print(f" Predicted doubt: {result2['predicted_doubt']}")
|
| 152 |
+
print(f" Q-value: {result2['confidence']:.4f}")
|
| 153 |
+
|
| 154 |
+
# Example 3: Quick learner, low confusion
|
| 155 |
+
print("\n[Scenario 3: Quick learner, low confusion]")
|
| 156 |
+
state3 = predictor.extract_state(
|
| 157 |
+
topic="python programming",
|
| 158 |
+
progress=0.9,
|
| 159 |
+
confusion_signals={
|
| 160 |
+
'mouse_hesitation': 0.5,
|
| 161 |
+
'scroll_reversals': 1,
|
| 162 |
+
'time_on_page': 20,
|
| 163 |
+
'back_button': 0
|
| 164 |
+
},
|
| 165 |
+
gesture_signals={
|
| 166 |
+
'swipe_down': 5,
|
| 167 |
+
'point': 3
|
| 168 |
+
},
|
| 169 |
+
time_spent=60
|
| 170 |
+
)
|
| 171 |
+
result3 = predictor.predict(state3)
|
| 172 |
+
print(f" Predicted doubt: {result3['predicted_doubt']}")
|
| 173 |
+
print(f" Q-value: {result3['confidence']:.4f}")
|
| 174 |
+
|
| 175 |
+
print("\n" + "="*60)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
if __name__ == "__main__":
|
| 179 |
+
example_inference()
|