| """ |
| Online Learning Module for ContextFlow |
| |
| Implements continuous model improvement from real user interactions. |
| Addresses: Online Learning requirement |
| """ |
|
|
| import numpy as np |
| import pickle |
| from typing import Dict, List, Optional, Any, Tuple |
| from dataclasses import dataclass, field |
| from collections import deque |
| import threading |
| import time |
| import json |
|
|
|
|
@dataclass
class InteractionSample:
    """One (state, action, reward, next_state) transition captured from a
    real user interaction, ready for replay-buffer storage."""
    state: np.ndarray
    action: int
    reward: float
    next_state: np.ndarray
    done: bool
    timestamp: float
    user_id: str
    confidence: float = 0.0

    def to_dict(self) -> Dict:
        """Return a JSON-serializable dict; numpy arrays become plain lists."""
        serialized: Dict = {}
        serialized['state'] = self.state.tolist()
        serialized['action'] = self.action
        serialized['reward'] = self.reward
        serialized['next_state'] = self.next_state.tolist()
        serialized['done'] = self.done
        serialized['timestamp'] = self.timestamp
        serialized['user_id'] = self.user_id
        serialized['confidence'] = self.confidence
        return serialized
|
|
|
|
@dataclass
class OnlineQNetwork:
    """A small two-hidden-layer MLP mapping states to per-action Q-values.

    Weights/biases live in plain dicts keyed 'l1'/'l2'/'l3' and
    'b1'/'b2'/'b3'; ``version`` counts how many times the parameters have
    been cloned in.
    """
    weights: Dict[str, np.ndarray]
    biases: Dict[str, np.ndarray]
    version: int = 1

    def forward(self, state: np.ndarray) -> np.ndarray:
        """Compute Q-values: two ReLU layers followed by a linear output."""
        w, b = self.weights, self.biases
        hidden = np.maximum(state @ w['l1'] + b['b1'], 0)
        hidden = np.maximum(hidden @ w['l2'] + b['b2'], 0)
        return hidden @ w['l3'] + b['b3']

    def clone_from(self, source: 'OnlineQNetwork'):
        """Deep-copy all parameters from *source* and bump the version."""
        self.weights = {name: arr.copy() for name, arr in source.weights.items()}
        self.biases = {name: arr.copy() for name, arr in source.biases.items()}
        self.version = source.version + 1
|
|
|
|
class OnlineLearningEngine:
    """
    Online learning engine for continuous model improvement.

    Features:
    - Incremental updates from user feedback
    - Experience replay buffer
    - Target network for stability
    - Periodic checkpointing

    Thread safety: the replay buffer, networks, and counters are guarded
    by ``self.lock``. Observer callbacks (``on_update``/``on_checkpoint``)
    are invoked *outside* the lock so they may safely call back into the
    engine.
    """

    def __init__(
        self,
        state_dim: int = 64,
        action_dim: int = 10,
        hidden_dim: int = 128,
        learning_rate: float = 0.001,
        gamma: float = 0.95,
        batch_size: int = 32,
        buffer_size: int = 10000,
        target_update_freq: int = 100
    ):
        """
        Args:
            state_dim: Dimensionality of state vectors.
            action_dim: Number of discrete actions.
            hidden_dim: Width of both hidden layers.
            learning_rate: Step size for the output-layer gradient update.
            gamma: Discount factor for TD targets.
            batch_size: Mini-batch size sampled from the replay buffer.
            buffer_size: Replay buffer capacity (oldest samples evicted).
            target_update_freq: Number of updates between target-network syncs.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        # BUG FIX: hidden_dim was never stored, so _init_network() raised
        # AttributeError on every construction.
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq

        # Online network plus a delayed copy used for stable TD targets.
        self.q_network = self._init_network()
        self.target_network = self._init_network()
        self._sync_target()

        # Experience replay buffer; deque evicts oldest samples at capacity.
        self.replay_buffer = deque(maxlen=buffer_size)

        # Statistics.
        self.total_updates = 0
        self.update_count = 0

        # Guards buffer/network/counter mutation across threads.
        self.lock = threading.Lock()

        # Optional observer callbacks: on_update(result_dict), on_checkpoint(path).
        self.on_checkpoint = None
        self.on_update = None

    def _init_network(self) -> OnlineQNetwork:
        """Create a network with small random weights.

        Uses a local fixed-seed Generator so initialization is deterministic
        (online and target nets start identical) WITHOUT the previous side
        effect of reseeding numpy's global RNG via np.random.seed(42).
        """
        rng = np.random.default_rng(42)
        return OnlineQNetwork(
            weights={
                'l1': rng.standard_normal((self.state_dim, self.hidden_dim)) * 0.1,
                'l2': rng.standard_normal((self.hidden_dim, self.hidden_dim)) * 0.1,
                'l3': rng.standard_normal((self.hidden_dim, self.action_dim)) * 0.1
            },
            biases={
                'b1': np.zeros(self.hidden_dim),
                'b2': np.zeros(self.hidden_dim),
                'b3': np.zeros(self.action_dim)
            },
            version=1
        )

    def _sync_target(self):
        """Copy the online Q-network parameters into the target network."""
        self.target_network.clone_from(self.q_network)

    def add_interaction(
        self,
        state: np.ndarray,
        action: int,
        reward: float,
        next_state: np.ndarray,
        done: bool,
        user_id: str = 'anonymous',
        confidence: float = 0.0
    ):
        """Append one interaction to the replay buffer and trigger an
        online update once at least one full batch is buffered."""
        sample = InteractionSample(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
            timestamp=time.time(),
            user_id=user_id,
            confidence=confidence
        )

        with self.lock:
            self.replay_buffer.append(sample)

        # Lock is released before update(); update() re-acquires it itself.
        if len(self.replay_buffer) >= self.batch_size:
            self.update()

    def update(self) -> Optional[Dict]:
        """Perform a single online Q-learning update.

        Returns:
            A stats dict (loss, total updates, buffer size), or None when
            the buffer does not yet hold a full batch.
        """
        with self.lock:
            if len(self.replay_buffer) < self.batch_size:
                return None

            # Sample a mini-batch uniformly without replacement.
            indices = np.random.choice(len(self.replay_buffer), self.batch_size, replace=False)
            batch = [self.replay_buffer[i] for i in indices]

            states = np.array([s.state for s in batch])
            actions = np.array([s.action for s in batch])
            rewards = np.array([s.reward for s in batch])
            next_states = np.array([s.next_state for s in batch])
            dones = np.array([s.done for s in batch])

            # TD targets: r for terminal samples, r + gamma * max_a' Q_target
            # otherwise. Actions not taken keep their current Q so their
            # error — and thus their gradient contribution — is zero.
            current_q = self.q_network.forward(states)
            next_q = self.target_network.forward(next_states)

            targets = current_q.copy()
            max_next_q = np.max(next_q, axis=1)

            for i in range(self.batch_size):
                if dones[i]:
                    targets[i, actions[i]] = rewards[i]
                else:
                    targets[i, actions[i]] = rewards[i] + self.gamma * max_next_q[i]

            errors = targets - current_q

            # Gradient step on the OUTPUT layer only (hidden layers stay
            # fixed) — a deliberately cheap approximation to full backprop.
            h1 = np.maximum(np.dot(states, self.q_network.weights['l1']) + self.q_network.biases['b1'], 0)
            h2 = np.maximum(np.dot(h1, self.q_network.weights['l2']) + self.q_network.biases['b2'], 0)

            for i in range(self.batch_size):
                self.q_network.weights['l3'] += self.learning_rate * np.outer(h2[i], errors[i])
                self.q_network.biases['b3'] += self.learning_rate * errors[i]

            # Periodically refresh the target network for stability.
            self.update_count += 1
            if self.update_count % self.target_update_freq == 0:
                self._sync_target()

            self.total_updates += 1

            result = {
                'loss': float(np.mean(errors ** 2)),
                'updates': self.total_updates,
                'buffer_size': len(self.replay_buffer)
            }

        # BUG FIX: invoke the callback AFTER releasing the lock so a callback
        # that calls back into the engine (e.g. get_stats) cannot deadlock on
        # the non-reentrant threading.Lock.
        if self.on_update:
            self.on_update(result)

        return result

    def predict(self, state: np.ndarray) -> Tuple[int, float]:
        """Return (best action index, its raw Q-value) for *state*."""
        q_values = self.q_network.forward(state)
        action = int(np.argmax(q_values))
        confidence = float(np.max(q_values))
        return action, confidence

    def get_q_values(self, state: np.ndarray) -> np.ndarray:
        """Return the Q-values for all actions in *state*."""
        return self.q_network.forward(state)

    def save_checkpoint(self, path: str):
        """Serialize the online network and counters to a JSON file at *path*.

        The snapshot is taken under the lock so a concurrent update cannot
        produce a torn (half-updated) checkpoint; file I/O happens outside.
        """
        with self.lock:
            checkpoint = {
                'q_network': {
                    'weights': {k: v.tolist() for k, v in self.q_network.weights.items()},
                    'biases': {k: v.tolist() for k, v in self.q_network.biases.items()},
                    'version': self.q_network.version
                },
                'total_updates': self.total_updates,
                'buffer_size': len(self.replay_buffer)
            }

        with open(path, 'w') as f:
            json.dump(checkpoint, f)

        if self.on_checkpoint:
            self.on_checkpoint(path)

        return path

    def load_checkpoint(self, path: str):
        """Restore network parameters and counters from a JSON checkpoint.

        Returns the raw checkpoint dict that was loaded.
        """
        with open(path, 'r') as f:
            checkpoint = json.load(f)

        # Mutate shared state under the lock so concurrent updates cannot
        # observe a partially-restored network.
        with self.lock:
            self.q_network.weights = {k: np.array(v) for k, v in checkpoint['q_network']['weights'].items()}
            self.q_network.biases = {k: np.array(v) for k, v in checkpoint['q_network']['biases'].items()}
            self.q_network.version = checkpoint['q_network']['version']
            self.total_updates = checkpoint['total_updates']
            self._sync_target()

        return checkpoint

    def get_stats(self) -> Dict:
        """Return a snapshot of learning statistics."""
        with self.lock:
            return {
                'total_updates': self.total_updates,
                'buffer_size': len(self.replay_buffer),
                'buffer_capacity': self.replay_buffer.maxlen,
                'network_version': self.q_network.version
            }
|
|
|
|
class AdaptiveLearningScheduler:
    """
    Adaptive learning rate scheduler based on performance.

    Multiplies the learning rate by ``factor`` after ``patience``
    consecutive steps without a new best loss, never dropping below
    ``min_lr``.
    """

    def __init__(
        self,
        initial_lr: float = 0.001,
        min_lr: float = 0.00001,
        patience: int = 10,
        factor: float = 0.5
    ):
        """
        Args:
            initial_lr: Starting learning rate.
            min_lr: Floor below which the rate is never reduced.
            patience: Non-improving steps tolerated before a reduction.
            factor: Multiplier applied when reducing (0 < factor < 1).
        """
        self.current_lr = initial_lr
        self.min_lr = min_lr
        self.patience = patience
        self.factor = factor

        self.best_loss = float('inf')  # lowest loss observed so far
        self.wait_count = 0            # steps since last improvement
        self.history = []              # every loss ever reported

    def step(self, loss: float) -> float:
        """Record *loss* and return the (possibly reduced) learning rate.

        BUG FIX: the previous early return for the very first observation
        skipped best-loss tracking entirely, so the first loss was never
        recorded as the best and patience counting started one step late.
        """
        self.history.append(loss)

        if loss < self.best_loss:
            self.best_loss = loss
            self.wait_count = 0
        else:
            self.wait_count += 1

        # Reduce only while above the floor; reset patience after reducing.
        if self.wait_count >= self.patience and self.current_lr > self.min_lr:
            self.current_lr *= self.factor
            self.wait_count = 0

        return self.current_lr
|
|
|
|
| |
class OnlineLearningAPI:
    """Thin REST-style facade over an OnlineLearningEngine."""

    def __init__(self, engine: OnlineLearningEngine):
        self.engine = engine

    def record_feedback(
        self,
        user_id: str,
        state: List[float],
        action: int,
        quality: int,
        comment: Optional[str] = None
    ) -> Dict:
        """
        Record user feedback and trigger online update.

        Quality mapping:
        - 1: Very unhelpful (-1.0)
        - 2: Unhelpful (-0.5)
        - 3: Neutral (0.0)
        - 4: Helpful (0.5)
        - 5: Very helpful (1.0)

        Unknown quality values map to a neutral reward of 0.0.
        """
        reward = {1: -1.0, 2: -0.5, 3: 0.0, 4: 0.5, 5: 1.0}.get(quality, 0.0)

        vec = np.array(state)

        # Synthetic successor state: small Gaussian perturbation of the input.
        successor = vec + np.random.randn(len(vec)) * 0.1

        self.engine.add_interaction(
            state=vec,
            action=action,
            reward=reward,
            next_state=successor,
            done=False,
            user_id=user_id,
            confidence=reward
        )

        return {
            'status': 'recorded',
            'reward': reward,
            'total_updates': self.engine.total_updates
        }

    def get_prediction(self, state: List[float]) -> Dict:
        """Return the engine's best action, confidence, and full Q-vector."""
        vec = np.array(state)
        best_action, score = self.engine.predict(vec)

        return {
            'action': best_action,
            'confidence': score,
            'q_values': self.engine.get_q_values(vec).tolist()
        }

    def get_stats(self) -> Dict:
        """Expose the engine's learning statistics."""
        return self.engine.get_stats()
|
|
|
|
| |
| if __name__ == "__main__": |
| engine = OnlineLearningEngine() |
| api = OnlineLearningAPI(engine) |
| |
| print("Online Learning Engine initialized") |
| print(f"State dim: {engine.state_dim}, Action dim: {engine.action_dim}") |
| |
| |
| for i in range(100): |
| state = np.random.randn(64) |
| action = np.random.randint(0, 10) |
| quality = np.random.randint(1, 6) |
| |
| result = api.record_feedback( |
| user_id='test_user', |
| state=state.tolist(), |
| action=action, |
| quality=quality |
| ) |
| |
| print(f"\\nAfter 100 interactions:") |
| print(f" Updates: {result['total_updates']}") |
| print(f" Stats: {api.get_stats()}") |
|
|