""" Online Learning Module for ContextFlow Implements continuous model improvement from real user interactions. Addresses: Online Learning requirement """ import numpy as np import pickle from typing import Dict, List, Optional, Any, Tuple from dataclasses import dataclass, field from collections import deque import threading import time import json @dataclass class InteractionSample: """A single interaction sample for online learning""" state: np.ndarray action: int reward: float next_state: np.ndarray done: bool timestamp: float user_id: str confidence: float = 0.0 def to_dict(self) -> Dict: return { 'state': self.state.tolist(), 'action': self.action, 'reward': self.reward, 'next_state': self.next_state.tolist(), 'done': self.done, 'timestamp': self.timestamp, 'user_id': self.user_id, 'confidence': self.confidence } @dataclass class OnlineQNetwork: """Q-Network for online learning""" weights: Dict[str, np.ndarray] biases: Dict[str, np.ndarray] version: int = 1 def forward(self, state: np.ndarray) -> np.ndarray: """Forward pass through network""" # Layer 1 h1 = np.maximum(np.dot(state, self.weights['l1']) + self.biases['b1'], 0) # Layer 2 h2 = np.maximum(np.dot(h1, self.weights['l2']) + self.biases['b2'], 0) # Output q_values = np.dot(h2, self.weights['l3']) + self.biases['b3'] return q_values def clone_from(self, source: 'OnlineQNetwork'): """Clone weights from another network""" self.weights = {k: v.copy() for k, v in source.weights.items()} self.biases = {k: v.copy() for k, v in source.biases.items()} self.version = source.version + 1 class OnlineLearningEngine: """ Online learning engine for continuous model improvement. Features: - Incremental updates from user feedback - Experience replay buffer - Target network for stability - Periodic checkpointing """ def __init__( self, state_dim: int = 64, action_dim: int = 10, hidden_dim: int = 128, learning_rate: float = 0.001, gamma: float = 0.95, batch_size: int = 32, buffer_size: int = 10000, target_update_freq: int = 100 ): self.state_dim = state_dim self.action_dim = action_dim self.learning_rate = learning_rate self.gamma = gamma self.batch_size = batch_size self.target_update_freq = target_update_freq # Initialize networks self.q_network = self._init_network() self.target_network = self._init_network() self._sync_target() # Experience replay buffer self.replay_buffer = deque(maxlen=buffer_size) # Training stats self.total_updates = 0 self.update_count = 0 # Lock for thread safety self.lock = threading.Lock() # Callbacks for events self.on_checkpoint = None self.on_update = None def _init_network(self) -> OnlineQNetwork: """Initialize network weights""" np.random.seed(42) return OnlineQNetwork( weights={ 'l1': np.random.randn(self.state_dim, self.hidden_dim) * 0.1, 'l2': np.random.randn(self.hidden_dim, self.hidden_dim) * 0.1, 'l3': np.random.randn(self.hidden_dim, self.action_dim) * 0.1 }, biases={ 'b1': np.zeros(self.hidden_dim), 'b2': np.zeros(self.hidden_dim), 'b3': np.zeros(self.action_dim) }, version=1 ) def _sync_target(self): """Copy Q-network to target network""" self.target_network.clone_from(self.q_network) def add_interaction( self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool, user_id: str = 'anonymous', confidence: float = 0.0 ): """Add a new interaction to the replay buffer""" sample = InteractionSample( state=state, action=action, reward=reward, next_state=next_state, done=done, timestamp=time.time(), user_id=user_id, confidence=confidence ) with self.lock: 
            self.replay_buffer.append(sample)

        # Trigger an online update once enough samples are buffered.
        # This runs outside the lock because update() acquires it again
        # and threading.Lock is not re-entrant.
        if len(self.replay_buffer) >= self.batch_size:
            self.update()

    def update(self) -> Optional[Dict]:
        """Perform a single online update"""
        with self.lock:
            if len(self.replay_buffer) < self.batch_size:
                return None

            # Sample batch
            indices = np.random.choice(len(self.replay_buffer), self.batch_size, replace=False)
            batch = [self.replay_buffer[i] for i in indices]

            # Extract batch arrays
            states = np.array([s.state for s in batch])
            actions = np.array([s.action for s in batch])
            rewards = np.array([s.reward for s in batch])
            next_states = np.array([s.next_state for s in batch])
            dones = np.array([s.done for s in batch])

            # Compute targets
            current_q = self.q_network.forward(states)
            next_q = self.target_network.forward(next_states)

            targets = current_q.copy()
            max_next_q = np.max(next_q, axis=1)
            for i in range(self.batch_size):
                if dones[i]:
                    targets[i, actions[i]] = rewards[i]
                else:
                    targets[i, actions[i]] = rewards[i] + self.gamma * max_next_q[i]

            # Compute gradients and update (simplified SGD)
            # In production, would use PyTorch autograd
            errors = targets - current_q

            # Gradient step on layer 3 only (recompute hidden activations)
            h1 = np.maximum(np.dot(states, self.q_network.weights['l1']) + self.q_network.biases['b1'], 0)
            h2 = np.maximum(np.dot(h1, self.q_network.weights['l2']) + self.q_network.biases['b2'], 0)

            for i in range(self.batch_size):
                grad_l3 = np.outer(h2[i], errors[i])
                grad_b3 = errors[i]
                self.q_network.weights['l3'] += self.learning_rate * grad_l3
                self.q_network.biases['b3'] += self.learning_rate * grad_b3

            # Update target network periodically
            self.update_count += 1
            if self.update_count % self.target_update_freq == 0:
                self._sync_target()

            self.total_updates += 1
            loss = np.mean(errors ** 2)

            result = {
                'loss': float(loss),
                'updates': self.total_updates,
                'buffer_size': len(self.replay_buffer)
            }

            if self.on_update:
                self.on_update(result)

            return result

    def predict(self, state: np.ndarray) -> Tuple[int, float]:
        """Predict best action for a state"""
        q_values = self.q_network.forward(state)
        action = int(np.argmax(q_values))
        confidence = float(np.max(q_values))
        return action, confidence

    def get_q_values(self, state: np.ndarray) -> np.ndarray:
        """Get Q-values for all actions"""
        return self.q_network.forward(state)

    def save_checkpoint(self, path: str):
        """Save model checkpoint"""
        checkpoint = {
            'q_network': {
                'weights': {k: v.tolist() for k, v in self.q_network.weights.items()},
                'biases': {k: v.tolist() for k, v in self.q_network.biases.items()},
                'version': self.q_network.version
            },
            'total_updates': self.total_updates,
            'buffer_size': len(self.replay_buffer)
        }
        with open(path, 'w') as f:
            json.dump(checkpoint, f)

        if self.on_checkpoint:
            self.on_checkpoint(path)

        return path

    def load_checkpoint(self, path: str):
        """Load model checkpoint"""
        with open(path, 'r') as f:
            checkpoint = json.load(f)

        self.q_network.weights = {k: np.array(v) for k, v in checkpoint['q_network']['weights'].items()}
        self.q_network.biases = {k: np.array(v) for k, v in checkpoint['q_network']['biases'].items()}
        self.q_network.version = checkpoint['q_network']['version']
        self.total_updates = checkpoint['total_updates']
        self._sync_target()

        return checkpoint

    def get_stats(self) -> Dict:
        """Get learning statistics"""
        with self.lock:
            return {
                'total_updates': self.total_updates,
                'buffer_size': len(self.replay_buffer),
                'buffer_capacity': self.replay_buffer.maxlen,
                'network_version': self.q_network.version
            }


class AdaptiveLearningScheduler:
    """
    Adaptive learning rate scheduler based on performance.
    Reduces the learning rate when the loss plateaus for `patience`
    consecutive steps; the rate is left unchanged while progress continues.
    """

    def __init__(
        self,
        initial_lr: float = 0.001,
        min_lr: float = 0.00001,
        patience: int = 10,
        factor: float = 0.5
    ):
        self.current_lr = initial_lr
        self.min_lr = min_lr
        self.patience = patience
        self.factor = factor
        self.best_loss = float('inf')
        self.wait_count = 0
        self.history = []

    def step(self, loss: float) -> float:
        """Update learning rate based on loss"""
        self.history.append(loss)

        if len(self.history) < 2:
            return self.current_lr

        if loss < self.best_loss:
            self.best_loss = loss
            self.wait_count = 0
        else:
            self.wait_count += 1
            if self.wait_count >= self.patience and self.current_lr > self.min_lr:
                self.current_lr *= self.factor
                self.wait_count = 0

        return self.current_lr


# API Integration
class OnlineLearningAPI:
    """REST API wrapper for online learning"""

    def __init__(self, engine: OnlineLearningEngine):
        self.engine = engine

    def record_feedback(
        self,
        user_id: str,
        state: List[float],
        action: int,
        quality: int,  # 1-5 quality rating
        comment: Optional[str] = None
    ) -> Dict:
        """
        Record user feedback and trigger online update.

        Quality mapping:
        - 1: Very unhelpful (-1.0)
        - 2: Unhelpful (-0.5)
        - 3: Neutral (0.0)
        - 4: Helpful (0.5)
        - 5: Very helpful (1.0)
        """
        reward_map = {1: -1.0, 2: -0.5, 3: 0.0, 4: 0.5, 5: 1.0}
        reward = reward_map.get(quality, 0.0)

        state_arr = np.array(state)
        # Simulate next state (in a real implementation, this would come from the actual interaction)
        next_state = state_arr + np.random.randn(len(state_arr)) * 0.1

        self.engine.add_interaction(
            state=state_arr,
            action=action,
            reward=reward,
            next_state=next_state,
            done=False,
            user_id=user_id,
            confidence=reward
        )

        return {
            'status': 'recorded',
            'reward': reward,
            'total_updates': self.engine.total_updates
        }

    def get_prediction(self, state: List[float]) -> Dict:
        """Get prediction for a state"""
        state_arr = np.array(state)
        action, confidence = self.engine.predict(state_arr)
        q_values = self.engine.get_q_values(state_arr)
        return {
            'action': action,
            'confidence': confidence,
            'q_values': q_values.tolist()
        }

    def get_stats(self) -> Dict:
        """Get learning stats"""
        return self.engine.get_stats()


# Example usage
if __name__ == "__main__":
    engine = OnlineLearningEngine()
    api = OnlineLearningAPI(engine)

    print("Online Learning Engine initialized")
    print(f"State dim: {engine.state_dim}, Action dim: {engine.action_dim}")

    # Simulate some feedback
    for i in range(100):
        state = np.random.randn(64)
        action = np.random.randint(0, 10)
        quality = np.random.randint(1, 6)
        result = api.record_feedback(
            user_id='test_user',
            state=state.tolist(),
            action=action,
            quality=quality
        )

    print("\nAfter 100 interactions:")
    print(f"  Updates: {result['total_updates']}")
    print(f"  Stats: {api.get_stats()}")
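
    # Sketch: one way to combine AdaptiveLearningScheduler with the engine,
    # using only the hooks defined above (the on_update callback and the
    # engine.learning_rate attribute). The scheduler, callback, and variable
    # names here are illustrative, not part of the API surface above.
    scheduler = AdaptiveLearningScheduler(initial_lr=engine.learning_rate)

    def _adapt_lr(update_result: Dict) -> None:
        # Feed the latest loss to the scheduler and apply the adjusted rate.
        engine.learning_rate = scheduler.step(update_result['loss'])

    engine.on_update = _adapt_lr

    # A few more simulated interactions with the scheduler attached
    for i in range(50):
        api.record_feedback(
            user_id='test_user',
            state=np.random.randn(64).tolist(),
            action=int(np.random.randint(0, 10)),
            quality=int(np.random.randint(1, 6))
        )

    print(f"  Learning rate after adaptive scheduling: {engine.learning_rate}")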