# contextflow-rl / online_learning.py
# Uploaded with huggingface_hub (commit 72558bb, verified)
"""
Online Learning Module for ContextFlow
Implements continuous model improvement from real user interactions.
Addresses: Online Learning requirement
"""
import numpy as np
import pickle
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass, field
from collections import deque
import threading
import time
import json
@dataclass
class InteractionSample:
    """One (s, a, r, s', done) transition captured from a user interaction."""
    state: np.ndarray
    action: int
    reward: float
    next_state: np.ndarray
    done: bool
    timestamp: float
    user_id: str
    confidence: float = 0.0

    def to_dict(self) -> Dict:
        """Serialize to a JSON-friendly dict (ndarrays become plain lists)."""
        field_order = ('state', 'action', 'reward', 'next_state',
                       'done', 'timestamp', 'user_id', 'confidence')
        payload = {name: getattr(self, name) for name in field_order}
        # Only the two array fields need conversion for JSON serialization.
        payload['state'] = payload['state'].tolist()
        payload['next_state'] = payload['next_state'].tolist()
        return payload
@dataclass
class OnlineQNetwork:
    """Small MLP (two ReLU hidden layers, linear output) producing Q-values."""
    weights: Dict[str, np.ndarray]
    biases: Dict[str, np.ndarray]
    version: int = 1

    def forward(self, state: np.ndarray) -> np.ndarray:
        """Map a state (or batch of states) to one Q-value per action."""
        hidden = np.maximum(state @ self.weights['l1'] + self.biases['b1'], 0)
        hidden = np.maximum(hidden @ self.weights['l2'] + self.biases['b2'], 0)
        # Output layer is linear: raw Q-values, no activation.
        return hidden @ self.weights['l3'] + self.biases['b3']

    def clone_from(self, source: 'OnlineQNetwork'):
        """Deep-copy parameters from `source` and bump the version counter."""
        self.weights = {name: arr.copy() for name, arr in source.weights.items()}
        self.biases = {name: arr.copy() for name, arr in source.biases.items()}
        self.version = source.version + 1
class OnlineLearningEngine:
    """
    Online learning engine for continuous model improvement.

    Features:
    - Incremental updates from user feedback
    - Experience replay buffer
    - Target network for stability
    - Periodic checkpointing

    Thread safety: a single non-reentrant lock guards the replay buffer and
    network parameters. Public methods never call each other while holding
    the lock (doing so previously deadlocked: `add_interaction` invoked
    `update()` inside the `with self.lock:` block, and `update()` tries to
    re-acquire the same lock).
    """

    def __init__(
        self,
        state_dim: int = 64,
        action_dim: int = 10,
        hidden_dim: int = 128,
        learning_rate: float = 0.001,
        gamma: float = 0.95,
        batch_size: int = 32,
        buffer_size: int = 10000,
        target_update_freq: int = 100
    ):
        self.state_dim = state_dim
        self.action_dim = action_dim
        # BUG FIX: hidden_dim was never stored, so _init_network() crashed
        # with AttributeError on the very first construction.
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        # Initialize online and target networks (identical after sync).
        self.q_network = self._init_network()
        self.target_network = self._init_network()
        self._sync_target()
        # Experience replay buffer; oldest samples are evicted at capacity.
        self.replay_buffer = deque(maxlen=buffer_size)
        # Training stats
        self.total_updates = 0
        self.update_count = 0
        # Non-reentrant lock guarding buffer + parameters.
        self.lock = threading.Lock()
        # Optional event callbacks: on_checkpoint(path), on_update(result).
        self.on_checkpoint = None
        self.on_update = None

    def _init_network(self) -> OnlineQNetwork:
        """Initialize network weights deterministically.

        Uses a private RandomState (same stream as the old np.random.seed(42)
        call, so the initial weights are unchanged) instead of reseeding the
        process-global NumPy RNG as a side effect.
        """
        rng = np.random.RandomState(42)
        return OnlineQNetwork(
            weights={
                'l1': rng.randn(self.state_dim, self.hidden_dim) * 0.1,
                'l2': rng.randn(self.hidden_dim, self.hidden_dim) * 0.1,
                'l3': rng.randn(self.hidden_dim, self.action_dim) * 0.1
            },
            biases={
                'b1': np.zeros(self.hidden_dim),
                'b2': np.zeros(self.hidden_dim),
                'b3': np.zeros(self.action_dim)
            },
            version=1
        )

    def _sync_target(self):
        """Copy Q-network parameters into the target network."""
        self.target_network.clone_from(self.q_network)

    def add_interaction(
        self,
        state: np.ndarray,
        action: int,
        reward: float,
        next_state: np.ndarray,
        done: bool,
        user_id: str = 'anonymous',
        confidence: float = 0.0
    ):
        """Append a new interaction to the replay buffer and, once the buffer
        holds at least one batch, trigger an online update.

        The update runs AFTER the lock is released: update() re-acquires the
        same non-reentrant lock, so calling it under the lock would deadlock.
        """
        sample = InteractionSample(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
            timestamp=time.time(),
            user_id=user_id,
            confidence=confidence
        )
        with self.lock:
            self.replay_buffer.append(sample)
            ready = len(self.replay_buffer) >= self.batch_size
        if ready:
            self.update()

    def update(self) -> Optional[Dict]:
        """Perform one online SGD step from a replayed minibatch.

        TD targets are computed with the target network. Only the output
        layer ('l3'/'b3') is trained -- the hidden layers stay frozen (a
        deliberate simplification; full backprop would need autograd).

        Returns a stats dict, or None if the buffer is smaller than a batch.
        """
        with self.lock:
            if len(self.replay_buffer) < self.batch_size:
                return None
            # Sample a batch without replacement.
            indices = np.random.choice(len(self.replay_buffer), self.batch_size, replace=False)
            batch = [self.replay_buffer[i] for i in indices]
            states = np.array([s.state for s in batch])
            actions = np.array([s.action for s in batch])
            rewards = np.array([s.reward for s in batch])
            next_states = np.array([s.next_state for s in batch])
            dones = np.array([s.done for s in batch])
            # TD targets: r  (+ gamma * max_a' Q_target(s', a') when not done).
            current_q = self.q_network.forward(states)
            next_q = self.target_network.forward(next_states)
            targets = current_q.copy()
            max_next_q = np.max(next_q, axis=1)
            rows = np.arange(self.batch_size)
            targets[rows, actions] = rewards + self.gamma * max_next_q * (~dones)
            # errors is zero except at the taken actions, so the SGD step
            # below only moves the corresponding output-layer columns.
            errors = targets - current_q
            # Recompute hidden activations (forward pass up to layer 2).
            h1 = np.maximum(states @ self.q_network.weights['l1'] + self.q_network.biases['b1'], 0)
            h2 = np.maximum(h1 @ self.q_network.weights['l2'] + self.q_network.biases['b2'], 0)
            # Vectorized SGD on the output layer: the sum of per-sample
            # outer products equals h2^T @ errors (replaces a Python loop).
            self.q_network.weights['l3'] += self.learning_rate * (h2.T @ errors)
            self.q_network.biases['b3'] += self.learning_rate * errors.sum(axis=0)
            # Refresh the target network periodically for stability.
            self.update_count += 1
            if self.update_count % self.target_update_freq == 0:
                self._sync_target()
            self.total_updates += 1
            loss = float(np.mean(errors ** 2))
            result = {
                'loss': loss,
                'updates': self.total_updates,
                'buffer_size': len(self.replay_buffer)
            }
        # Invoke the callback outside the lock so it may call back in.
        if self.on_update:
            self.on_update(result)
        return result

    def predict(self, state: np.ndarray) -> Tuple[int, float]:
        """Return (greedy action, its Q-value) for a single state."""
        q_values = self.q_network.forward(state)
        action = int(np.argmax(q_values))
        confidence = float(np.max(q_values))
        return action, confidence

    def get_q_values(self, state: np.ndarray) -> np.ndarray:
        """Return Q-values for all actions at the given state."""
        return self.q_network.forward(state)

    def save_checkpoint(self, path: str):
        """Serialize the Q-network and stats to a JSON file at `path`.

        The snapshot is taken under the lock so a concurrent update cannot
        produce a half-written set of weights; the file write happens after
        the lock is released.
        """
        with self.lock:
            checkpoint = {
                'q_network': {
                    'weights': {k: v.tolist() for k, v in self.q_network.weights.items()},
                    'biases': {k: v.tolist() for k, v in self.q_network.biases.items()},
                    'version': self.q_network.version
                },
                'total_updates': self.total_updates,
                'buffer_size': len(self.replay_buffer)
            }
        with open(path, 'w') as f:
            json.dump(checkpoint, f)
        if self.on_checkpoint:
            self.on_checkpoint(path)
        return path

    def load_checkpoint(self, path: str):
        """Restore Q-network parameters and stats from a JSON checkpoint."""
        with open(path, 'r') as f:
            checkpoint = json.load(f)
        with self.lock:
            net = checkpoint['q_network']
            self.q_network.weights = {k: np.array(v) for k, v in net['weights'].items()}
            self.q_network.biases = {k: np.array(v) for k, v in net['biases'].items()}
            self.q_network.version = net['version']
            self.total_updates = checkpoint['total_updates']
            self._sync_target()
        return checkpoint

    def get_stats(self) -> Dict:
        """Return a consistent snapshot of learning statistics."""
        with self.lock:
            return {
                'total_updates': self.total_updates,
                'buffer_size': len(self.replay_buffer),
                'buffer_capacity': self.replay_buffer.maxlen,
                'network_version': self.q_network.version
            }
class AdaptiveLearningScheduler:
    """
    Loss-driven adaptive learning-rate scheduler.

    After `patience` consecutive non-improving steps the rate is multiplied
    by `factor`; it is never reduced once it reaches `min_lr`.
    """

    def __init__(
        self,
        initial_lr: float = 0.001,
        min_lr: float = 0.00001,
        patience: int = 10,
        factor: float = 0.5
    ):
        self.current_lr = initial_lr
        self.min_lr = min_lr
        self.patience = patience
        self.factor = factor
        self.best_loss = float('inf')
        self.wait_count = 0
        self.history = []

    def step(self, loss: float) -> float:
        """Record `loss` and return the (possibly reduced) learning rate.

        Note the warm-up quirk preserved from the original: the very first
        call only records history, so best_loss is not updated until at
        least two losses have been observed.
        """
        self.history.append(loss)
        if len(self.history) < 2:
            return self.current_lr
        if loss < self.best_loss:
            # New best: reset the plateau counter.
            self.best_loss = loss
            self.wait_count = 0
            return self.current_lr
        self.wait_count += 1
        plateaued = self.wait_count >= self.patience
        if plateaued and self.current_lr > self.min_lr:
            self.current_lr *= self.factor
            self.wait_count = 0
        return self.current_lr
# API Integration
class OnlineLearningAPI:
    """REST API wrapper for online learning"""

    def __init__(self, engine: OnlineLearningEngine):
        self.engine = engine

    def record_feedback(
        self,
        user_id: str,
        state: List[float],
        action: int,
        quality: int,  # 1-5 quality rating
        comment: Optional[str] = None
    ) -> Dict:
        """
        Record user feedback and trigger online update.
        Quality mapping:
        - 1: Very unhelpful (-1.0)
        - 2: Unhelpful (-0.5)
        - 3: Neutral (0.0)
        - 4: Helpful (0.5)
        - 5: Very helpful (1.0)
        """
        # Map the 1-5 star rating onto a symmetric reward in [-1, 1];
        # unknown ratings fall back to neutral (0.0).
        reward = {1: -1.0, 2: -0.5, 3: 0.0, 4: 0.5, 5: 1.0}.get(quality, 0.0)
        state_arr = np.array(state)
        # Simulate next state (in real impl, would come from actual interaction)
        next_state = state_arr + np.random.randn(len(state_arr)) * 0.1
        self.engine.add_interaction(
            state=state_arr,
            action=action,
            reward=reward,
            next_state=next_state,
            done=False,
            user_id=user_id,
            confidence=reward
        )
        return {
            'status': 'recorded',
            'reward': reward,
            'total_updates': self.engine.total_updates
        }

    def get_prediction(self, state: List[float]) -> Dict:
        """Get prediction for a state"""
        state_arr = np.array(state)
        best_action, confidence = self.engine.predict(state_arr)
        return {
            'action': best_action,
            'confidence': confidence,
            'q_values': self.engine.get_q_values(state_arr).tolist()
        }

    def get_stats(self) -> Dict:
        """Get learning stats"""
        return self.engine.get_stats()
# Example usage
if __name__ == "__main__":
    engine = OnlineLearningEngine()
    api = OnlineLearningAPI(engine)
    print("Online Learning Engine initialized")
    print(f"State dim: {engine.state_dim}, Action dim: {engine.action_dim}")
    # Simulate some feedback, sized from the engine's own dimensions
    # instead of hard-coded 64/10 so the demo tracks the configuration.
    result = {'total_updates': 0}
    for i in range(100):
        state = np.random.randn(engine.state_dim)
        action = np.random.randint(0, engine.action_dim)
        quality = np.random.randint(1, 6)
        result = api.record_feedback(
            user_id='test_user',
            state=state.tolist(),
            action=action,
            quality=quality
        )
    # BUG FIX: the original "\\n" printed a literal backslash-n;
    # emit a real blank line instead.
    print("\nAfter 100 interactions:")
    print(f" Updates: {result['total_updates']}")
    print(f" Stats: {api.get_stats()}")