# contextflow-rl / inference_example.py
# Author: namish10
# Uploaded with huggingface_hub (commit bdc2b78, verified)
"""
ContextFlow RL Model Inference Example
This script demonstrates how to load the trained checkpoint and make predictions.
"""
import pickle
import numpy as np
import sys
import os
# Add this script's own directory to the import path so the sibling feature_extractor module can be imported
sys.path.insert(0, os.path.dirname(__file__))
from feature_extractor import FeatureExtractor
# Labels for the 10 doubt actions. DoubtPredictor.predict maps the i-th
# Q-value to DOUBT_ACTIONS[i], so the position of each label is significant.
# NOTE(review): presumably this order must match the action order used when
# the Q-network was trained -- confirm before reordering or editing.
DOUBT_ACTIONS = [
    "what_is_backpropagation",
    "why_gradient_descent",
    "how_overfitting_works",
    "explain_regularization",
    "what_loss_function",
    "how_optimization_works",
    "explain_learning_rate",
    "what_regularization",
    "how_batch_norm_works",
    "explain_softmax",
]
class DoubtPredictor:
    """Predict a learner's most likely doubt from Q-network weights in a pickled checkpoint.

    The loaded checkpoint object must provide the attributes read by this
    class: ``policy_version``, ``training_stats`` (supporting ``.get``) and
    ``q_network_weights`` (a mapping from parameter name to weights).
    """

    def __init__(self, checkpoint_path: str):
        """Create the feature extractor and load the checkpoint.

        Args:
            checkpoint_path: Path to a pickled checkpoint file.

        Raises:
            OSError: If the file cannot be opened. Errors raised by
                ``pickle.load`` on malformed data propagate unchanged.
        """
        self.extractor = FeatureExtractor()
        # SECURITY: unpickling can execute arbitrary code embedded in the file.
        # Only load checkpoints that come from a source you trust.
        with open(checkpoint_path, 'rb') as f:
            self.checkpoint = pickle.load(f)
        print(f"Loaded checkpoint v{self.checkpoint.policy_version}")
        print(f"Training samples: {self.checkpoint.training_stats.get('total_samples', 'N/A')}")

    def extract_state(self, **kwargs) -> np.ndarray:
        """Build a state vector by forwarding *kwargs* to ``FeatureExtractor.extract_state``."""
        return self.extractor.extract_state(**kwargs)

    def predict(self, state: np.ndarray, top_k: int = 3) -> dict:
        """Score every doubt action for *state* and return the best ones.

        When the checkpoint's weights contain ``'layer1.weight'``, Q-values come
        from a forward pass through three dense layers, with ReLU after the
        first two. Otherwise no usable network is available, and the Q-values
        are random noise. In that case a ``RuntimeWarning`` is emitted so the
        result is not mistaken for a real prediction.

        Args:
            state: Feature vector, e.g. as returned by :meth:`extract_state`.
            top_k: Number of entries to include in ``'top_predictions'``
                (default 3). Must be at least 1.

        Returns:
            dict with keys:
                ``'predicted_doubt'``: label of the highest-scoring action.
                ``'confidence'``: that action's raw Q-value (not a probability).
                ``'top_predictions'``: up to *top_k* ``{'action', 'q_value'}``
                dicts, highest Q-value first.

        Raises:
            ValueError: If *top_k* is less than 1.
            KeyError: If ``'layer1.weight'`` is present but another expected
                parameter (``layer1.bias``, ``layer2.*``, ``output.*``) is missing.
        """
        import warnings  # local import keeps the module's import block unchanged

        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")

        q_weights = self.checkpoint.q_network_weights
        if 'layer1.weight' in q_weights:
            # np.asarray accepts weights stored as ndarrays, nested lists or
            # other array-likes; it does not copy data that is already an ndarray.
            w1 = np.asarray(q_weights['layer1.weight'])
            b1 = np.asarray(q_weights['layer1.bias'])
            w2 = np.asarray(q_weights['layer2.weight'])
            b2 = np.asarray(q_weights['layer2.bias'])
            w3 = np.asarray(q_weights['output.weight'])
            b3 = np.asarray(q_weights['output.bias'])
            x = np.asarray(state)
            # Each layer computes ``x @ W.T + b``, so weights are assumed to be
            # stored as (out_features, in_features), i.e. the PyTorch
            # nn.Linear layout. NOTE(review): confirm against the training code.
            h1 = np.maximum(np.dot(x, w1.T) + b1, 0)  # ReLU
            h2 = np.maximum(np.dot(h1, w2.T) + b2, 0)  # ReLU
            q_values = np.dot(h2, w3.T) + b3
        else:
            warnings.warn(
                "Checkpoint has no 'layer1.weight' in q_network_weights; "
                "returning RANDOM Q-values instead of model predictions.",
                RuntimeWarning,
                stacklevel=2,
            )
            # One random score per action label.
            q_values = np.random.randn(len(DOUBT_ACTIONS)) * 0.5

        # Action indices ordered by descending Q-value; ranked[0] is the best.
        ranked = np.argsort(q_values)[::-1]
        best = ranked[0]
        return {
            'predicted_doubt': DOUBT_ACTIONS[best],
            'confidence': float(q_values[best]),
            'top_predictions': [
                {'action': DOUBT_ACTIONS[i], 'q_value': float(q_values[i])}
                for i in ranked[:top_k]
            ],
        }
def example_inference(checkpoint_path: str = 'checkpoint.pkl'):
    """Run three hand-written learner scenarios through DoubtPredictor and print the results.

    Args:
        checkpoint_path: Path of the pickled checkpoint to load. Defaults to
            ``'checkpoint.pkl'``, which is resolved relative to the current
            working directory (not the script's directory).

    If the file does not exist, prints where to download it and returns
    without running any scenario.
    """
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint not found: {checkpoint_path}")
        print("Download from: https://huggingface.co/namish10/contextflow-rl")
        return

    predictor = DoubtPredictor(checkpoint_path)
    print("\n" + "="*60)
    print("EXAMPLE INFERENCES")
    print("="*60)

    # Each entry: (printed scenario title, keyword arguments for extract_state).
    scenarios = [
        # Beginner ML student
        ("Scenario 1: Beginner ML student", {
            'topic': "neural networks",
            'progress': 0.3,
            'confusion_signals': {
                'mouse_hesitation': 3.0,
                'scroll_reversals': 6,
                'time_on_page': 45,
                'back_button': 3,
                'copy_attempts': 1,
            },
            'gesture_signals': {
                'pinch': 2,
                'point': 5,
            },
            'time_spent': 120,
        }),
        # Advanced learner struggling with regularization
        ("Scenario 2: Advanced learner, high confusion signals", {
            'topic': "deep learning",
            'progress': 0.7,
            'confusion_signals': {
                'mouse_hesitation': 4.5,
                'scroll_reversals': 8,
                'time_on_page': 280,
                'back_button': 5,
                'copy_attempts': 2,
                'search_usage': 3,
            },
            'gesture_signals': {
                'pinch': 8,
                'swipe_left': 4,
                'point': 10,
            },
            'time_spent': 600,
        }),
        # Quick learner, low confusion
        ("Scenario 3: Quick learner, low confusion", {
            'topic': "python programming",
            'progress': 0.9,
            'confusion_signals': {
                'mouse_hesitation': 0.5,
                'scroll_reversals': 1,
                'time_on_page': 20,
                'back_button': 0,
            },
            'gesture_signals': {
                'swipe_down': 5,
                'point': 3,
            },
            'time_spent': 60,
        }),
    ]

    for title, features in scenarios:
        print(f"\n[{title}]")
        state = predictor.extract_state(**features)
        result = predictor.predict(state)
        print(f" Predicted doubt: {result['predicted_doubt']}")
        print(f" Q-value: {result['confidence']:.4f}")

    print("\n" + "="*60)


if __name__ == "__main__":
    example_inference()