""" ContextFlow RL Model Inference Example This script demonstrates how to load the trained checkpoint and make predictions. """ import pickle import numpy as np import sys import os # Add current directory to path sys.path.insert(0, os.path.dirname(__file__)) from feature_extractor import FeatureExtractor # Doubt action labels (10 actions) DOUBT_ACTIONS = [ "what_is_backpropagation", "why_gradient_descent", "how_overfitting_works", "explain_regularization", "what_loss_function", "how_optimization_works", "explain_learning_rate", "what_regularization", "how_batch_norm_works", "explain_softmax" ] class DoubtPredictor: """Simple doubt predictor using the trained Q-network""" def __init__(self, checkpoint_path: str): self.extractor = FeatureExtractor() # Load checkpoint with open(checkpoint_path, 'rb') as f: self.checkpoint = pickle.load(f) print(f"Loaded checkpoint v{self.checkpoint.policy_version}") print(f"Training samples: {self.checkpoint.training_stats.get('total_samples', 'N/A')}") def extract_state(self, **kwargs) -> np.ndarray: """Extract state vector from input features""" return self.extractor.extract_state(**kwargs) def predict(self, state: np.ndarray) -> dict: """ Predict doubt actions from state Returns: dict with predicted actions and Q-values """ # Simple linear approximation since we have Q-network weights q_weights = self.checkpoint.q_network_weights # Extract key weights (simplified) if 'layer1.weight' in q_weights: w1 = q_weights['layer1.weight'] b1 = q_weights['layer1.bias'] w2 = q_weights['layer2.weight'] b2 = q_weights['layer2.bias'] w3 = q_weights['output.weight'] b3 = q_weights['output.bias'] # Forward pass h1 = np.maximum(np.dot(state, w1.T) + b1, 0) # ReLU h2 = np.maximum(np.dot(h1, w2.T) + b2, 0) # ReLU q_values = np.dot(h2, w3.T) + b3 else: # Fallback: random predictions q_values = np.random.randn(10) * 0.5 # Get top 3 predictions top_indices = np.argsort(q_values)[::-1][:3] return { 'predicted_doubt': DOUBT_ACTIONS[top_indices[0]], 'confidence': float(q_values[top_indices[0]]), 'top_predictions': [ { 'action': DOUBT_ACTIONS[i], 'q_value': float(q_values[i]) } for i in top_indices ] } def example_inference(): """Run example inferences""" checkpoint_path = 'checkpoint.pkl' if not os.path.exists(checkpoint_path): print(f"Checkpoint not found: {checkpoint_path}") print("Download from: https://huggingface.co/namish10/contextflow-rl") return predictor = DoubtPredictor(checkpoint_path) print("\n" + "="*60) print("EXAMPLE INFERENCES") print("="*60) # Example 1: Beginner ML student print("\n[Scenario 1: Beginner ML student]") state1 = predictor.extract_state( topic="neural networks", progress=0.3, confusion_signals={ 'mouse_hesitation': 3.0, 'scroll_reversals': 6, 'time_on_page': 45, 'back_button': 3, 'copy_attempts': 1 }, gesture_signals={ 'pinch': 2, 'point': 5 }, time_spent=120 ) result1 = predictor.predict(state1) print(f" Predicted doubt: {result1['predicted_doubt']}") print(f" Q-value: {result1['confidence']:.4f}") # Example 2: Advanced learner struggling with regularization print("\n[Scenario 2: Advanced learner, high confusion signals]") state2 = predictor.extract_state( topic="deep learning", progress=0.7, confusion_signals={ 'mouse_hesitation': 4.5, 'scroll_reversals': 8, 'time_on_page': 280, 'back_button': 5, 'copy_attempts': 2, 'search_usage': 3 }, gesture_signals={ 'pinch': 8, 'swipe_left': 4, 'point': 10 }, time_spent=600 ) result2 = predictor.predict(state2) print(f" Predicted doubt: {result2['predicted_doubt']}") print(f" Q-value: {result2['confidence']:.4f}") # Example 3: Quick learner, low confusion print("\n[Scenario 3: Quick learner, low confusion]") state3 = predictor.extract_state( topic="python programming", progress=0.9, confusion_signals={ 'mouse_hesitation': 0.5, 'scroll_reversals': 1, 'time_on_page': 20, 'back_button': 0 }, gesture_signals={ 'swipe_down': 5, 'point': 3 }, time_spent=60 ) result3 = predictor.predict(state3) print(f" Predicted doubt: {result3['predicted_doubt']}") print(f" Q-value: {result3['confidence']:.4f}") print("\n" + "="*60) if __name__ == "__main__": example_inference()