namish10 committed on
Commit
72558bb
·
verified ·
1 Parent(s): 18a270a

Upload online_learning.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. online_learning.py +419 -0
online_learning.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Online Learning Module for ContextFlow
3
+
4
+ Implements continuous model improvement from real user interactions.
5
+ Addresses: Online Learning requirement
6
+ """
7
+
8
+ import numpy as np
9
+ import pickle
10
+ from typing import Dict, List, Optional, Any, Tuple
11
+ from dataclasses import dataclass, field
12
+ from collections import deque
13
+ import threading
14
+ import time
15
+ import json
16
+
17
+
18
@dataclass
class InteractionSample:
    """One recorded transition (state, action, reward, next state) plus metadata."""
    state: np.ndarray
    action: int
    reward: float
    next_state: np.ndarray
    done: bool
    timestamp: float
    user_id: str
    confidence: float = 0.0

    def to_dict(self) -> Dict:
        """Return the sample as a plain dict; both state arrays become nested lists."""
        # Convert the arrays up front so the dict below only holds plain values.
        state_as_list = self.state.tolist()
        next_state_as_list = self.next_state.tolist()
        return dict(
            state=state_as_list,
            action=self.action,
            reward=self.reward,
            next_state=next_state_as_list,
            done=self.done,
            timestamp=self.timestamp,
            user_id=self.user_id,
            confidence=self.confidence,
        )
41
+
42
+
43
@dataclass
class OnlineQNetwork:
    """Three-layer fully connected Q-network stored as plain numpy arrays."""
    weights: Dict[str, np.ndarray]
    biases: Dict[str, np.ndarray]
    version: int = 1

    def forward(self, state: np.ndarray) -> np.ndarray:
        """Run the network on `state` and return the output-layer Q-values."""
        activation = state
        # Two hidden layers, each an affine map followed by ReLU.
        for weight_key, bias_key in (('l1', 'b1'), ('l2', 'b2')):
            pre_activation = np.dot(activation, self.weights[weight_key]) + self.biases[bias_key]
            activation = np.maximum(pre_activation, 0)
        # Linear output layer (no activation).
        return np.dot(activation, self.weights['l3']) + self.biases['b3']

    def clone_from(self, source: 'OnlineQNetwork'):
        """Copy `source`'s parameters into this network and set version to source.version + 1."""
        copied_weights = {}
        for name, array in source.weights.items():
            copied_weights[name] = array.copy()
        copied_biases = {}
        for name, array in source.biases.items():
            copied_biases[name] = array.copy()
        self.weights = copied_weights
        self.biases = copied_biases
        self.version = 1 + source.version
65
+
66
+
67
class OnlineLearningEngine:
    """
    Online learning engine for continuous model improvement.

    Keeps a bounded replay buffer of interaction samples and, once the buffer
    holds at least ``batch_size`` samples, runs one Q-learning step for every
    new interaction. Only the output layer ('l3' / 'b3') of the Q-network is
    trained; the two hidden layers keep their initial weights.

    Features:
    - Incremental updates from user feedback
    - Experience replay buffer
    - Target network for stability
    - JSON checkpoint save/load (on demand via save_checkpoint/load_checkpoint)
    """

    def __init__(
        self,
        state_dim: int = 64,
        action_dim: int = 10,
        hidden_dim: int = 128,
        learning_rate: float = 0.001,
        gamma: float = 0.95,
        batch_size: int = 32,
        buffer_size: int = 10000,
        target_update_freq: int = 100
    ):
        """
        Build the Q-network, its target copy and an empty replay buffer.

        Args:
            state_dim: Length of a single state vector.
            action_dim: Number of discrete actions (size of the Q-value output).
            hidden_dim: Width of both hidden layers.
            learning_rate: Step size of the output-layer update.
            gamma: Discount factor applied to the bootstrapped next-state value.
            batch_size: Number of samples drawn from the buffer per update.
            buffer_size: Maximum number of samples kept; oldest are dropped first.
            target_update_freq: Number of updates between target-network syncs.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        # _init_network() reads self.hidden_dim, so it has to be stored before
        # the networks are built (it used to be dropped -> AttributeError).
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq

        # Initialize networks
        self.q_network = self._init_network()
        self.target_network = self._init_network()
        self._sync_target()

        # Experience replay buffer
        self.replay_buffer = deque(maxlen=buffer_size)

        # Training stats
        self.total_updates = 0
        self.update_count = 0

        # Lock for thread safety (plain Lock: NOT reentrant)
        self.lock = threading.Lock()

        # Optional callbacks: on_checkpoint(path) and on_update(result_dict)
        self.on_checkpoint = None
        self.on_update = None

    def _init_network(self) -> OnlineQNetwork:
        """Create a network with fixed-seed Gaussian weights (scaled by 0.1) and zero biases."""
        # A private RandomState seeded with 42 yields the same numbers as
        # np.random.seed(42) + np.random.randn, without reseeding the
        # process-wide generator that other code may rely on.
        rng = np.random.RandomState(42)
        return OnlineQNetwork(
            weights={
                'l1': rng.randn(self.state_dim, self.hidden_dim) * 0.1,
                'l2': rng.randn(self.hidden_dim, self.hidden_dim) * 0.1,
                'l3': rng.randn(self.hidden_dim, self.action_dim) * 0.1
            },
            biases={
                'b1': np.zeros(self.hidden_dim),
                'b2': np.zeros(self.hidden_dim),
                'b3': np.zeros(self.action_dim)
            },
            version=1
        )

    def _sync_target(self):
        """Copy Q-network to target network"""
        self.target_network.clone_from(self.q_network)

    def add_interaction(
        self,
        state: np.ndarray,
        action: int,
        reward: float,
        next_state: np.ndarray,
        done: bool,
        user_id: str = 'anonymous',
        confidence: float = 0.0
    ):
        """
        Add a new interaction to the replay buffer and run an update if possible.

        Raises:
            ValueError: If `state` / `next_state` are not vectors of length
                ``state_dim`` or `action` is outside ``[0, action_dim)``.
                Such samples would otherwise sit in the buffer and break
                every later update that happens to draw them.
        """
        state = np.asarray(state, dtype=float)
        next_state = np.asarray(next_state, dtype=float)
        expected_shape = (self.state_dim,)
        if state.shape != expected_shape or next_state.shape != expected_shape:
            raise ValueError(
                f"state and next_state must have shape {expected_shape}, "
                f"got {state.shape} and {next_state.shape}"
            )
        if not 0 <= action < self.action_dim:
            raise ValueError(
                f"action must be in [0, {self.action_dim}), got {action}"
            )

        sample = InteractionSample(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
            timestamp=time.time(),
            user_id=user_id,
            confidence=confidence
        )

        with self.lock:
            self.replay_buffer.append(sample)
            ready = len(self.replay_buffer) >= self.batch_size

        # update() acquires self.lock itself; the lock is not reentrant, so
        # the update must be triggered only after the lock is released.
        if ready:
            self.update()

    def update(self) -> Optional[Dict]:
        """
        Perform a single online update of the output layer.

        Returns:
            None when the buffer holds fewer than ``batch_size`` samples,
            otherwise a dict with 'loss' (mean squared TD error of the batch
            before the step), 'updates' and 'buffer_size'.
        """
        with self.lock:
            if len(self.replay_buffer) < self.batch_size:
                return None

            # Sample batch
            indices = np.random.choice(len(self.replay_buffer), self.batch_size, replace=False)
            batch = [self.replay_buffer[i] for i in indices]

            # Extract batch arrays
            states = np.array([s.state for s in batch])
            actions = np.array([s.action for s in batch])
            rewards = np.array([s.reward for s in batch])
            next_states = np.array([s.next_state for s in batch])
            dones = np.array([s.done for s in batch])

            # Compute targets
            current_q = self.q_network.forward(states)
            next_q = self.target_network.forward(next_states)

            # Targets equal the current prediction everywhere except at the
            # taken action, so the error is non-zero only for that action.
            targets = current_q.copy()
            bootstrap = np.where(dones, 0.0, self.gamma * np.max(next_q, axis=1))
            targets[np.arange(self.batch_size), actions] = rewards + bootstrap

            # Simplified SGD on the output layer only.
            # In production, would use PyTorch autograd
            errors = targets - current_q

            # Activations feeding the output layer (same two ReLU layers that
            # OnlineQNetwork.forward applies).
            net = self.q_network
            hidden1 = np.maximum(np.dot(states, net.weights['l1']) + net.biases['b1'], 0)
            hidden2 = np.maximum(np.dot(hidden1, net.weights['l2']) + net.biases['b2'], 0)

            # Sum of the per-sample outer products h_i (x) e_i, in one matmul.
            net.weights['l3'] += self.learning_rate * np.dot(hidden2.T, errors)
            net.biases['b3'] += self.learning_rate * errors.sum(axis=0)

            # Update target network periodically
            self.update_count += 1
            if self.update_count % self.target_update_freq == 0:
                self._sync_target()

            self.total_updates += 1

            loss = np.mean(errors ** 2)

            result = {
                'loss': float(loss),
                'updates': self.total_updates,
                'buffer_size': len(self.replay_buffer)
            }

        # Invoke the callback after releasing the lock so that it may call
        # get_stats()/update() without deadlocking on the non-reentrant lock.
        callback = self.on_update
        if callback:
            callback(result)

        return result

    def predict(self, state: np.ndarray) -> Tuple[int, float]:
        """Return (index of the largest Q-value, that Q-value) for `state`."""
        q_values = self.q_network.forward(state)
        action = int(np.argmax(q_values))
        confidence = float(np.max(q_values))
        return action, confidence

    def get_q_values(self, state: np.ndarray) -> np.ndarray:
        """Get Q-values for all actions"""
        return self.q_network.forward(state)

    def save_checkpoint(self, path: str) -> str:
        """
        Write the Q-network parameters and counters to `path` as JSON.

        Calls ``on_checkpoint(path)`` afterwards when that callback is set.
        The replay buffer contents are not saved, only its current length.

        Returns:
            The `path` that was written.
        """
        checkpoint = {
            'q_network': {
                'weights': {k: v.tolist() for k, v in self.q_network.weights.items()},
                'biases': {k: v.tolist() for k, v in self.q_network.biases.items()},
                'version': self.q_network.version
            },
            'total_updates': self.total_updates,
            'buffer_size': len(self.replay_buffer)
        }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(checkpoint, f)

        if self.on_checkpoint:
            self.on_checkpoint(path)

        return path

    def load_checkpoint(self, path: str) -> Dict:
        """
        Restore Q-network parameters and `total_updates` from a JSON checkpoint.

        The target network is re-synced from the restored Q-network.

        Returns:
            The decoded checkpoint dict.
        """
        with open(path, 'r', encoding='utf-8') as f:
            checkpoint = json.load(f)

        saved_network = checkpoint['q_network']
        self.q_network.weights = {k: np.array(v) for k, v in saved_network['weights'].items()}
        self.q_network.biases = {k: np.array(v) for k, v in saved_network['biases'].items()}
        self.q_network.version = saved_network['version']
        self.total_updates = checkpoint['total_updates']

        self._sync_target()

        return checkpoint

    def get_stats(self) -> Dict:
        """Return update count, buffer fill/capacity and network version (takes the lock)."""
        with self.lock:
            return {
                'total_updates': self.total_updates,
                'buffer_size': len(self.replay_buffer),
                'buffer_capacity': self.replay_buffer.maxlen,
                'network_version': self.q_network.version
            }
284
+
285
+
286
class AdaptiveLearningScheduler:
    """
    Reduce-on-plateau learning rate scheduler.

    Tracks the best (lowest) loss seen so far. After `patience` consecutive
    steps without a new best, the learning rate is multiplied by `factor`,
    but never drops below `min_lr`. The rate is never increased.
    """

    def __init__(
        self,
        initial_lr: float = 0.001,
        min_lr: float = 0.00001,
        patience: int = 10,
        factor: float = 0.5
    ):
        """
        Args:
            initial_lr: Starting learning rate.
            min_lr: Lower bound for the learning rate.
            patience: Non-improving steps tolerated before a reduction.
            factor: Multiplier applied to the learning rate on a reduction.
        """
        self.current_lr = initial_lr
        self.min_lr = min_lr
        self.patience = patience
        self.factor = factor

        self.best_loss = float('inf')
        self.wait_count = 0
        # NOTE: every loss passed to step() is appended here and never trimmed.
        self.history = []

    def step(self, loss: float) -> float:
        """
        Record `loss` and return the (possibly reduced) learning rate.

        The first call only establishes the baseline and never changes the rate.
        """
        self.history.append(loss)

        if len(self.history) < 2:
            # The first observation is the baseline; it used to be skipped,
            # so an early best loss was never remembered.
            if loss < self.best_loss:
                self.best_loss = loss
            return self.current_lr

        if loss < self.best_loss:
            self.best_loss = loss
            self.wait_count = 0
        else:
            self.wait_count += 1

        if self.wait_count >= self.patience and self.current_lr > self.min_lr:
            # Clamp so a reduction can never push the rate below min_lr.
            self.current_lr = max(self.current_lr * self.factor, self.min_lr)
            self.wait_count = 0

        return self.current_lr
328
+
329
+
330
+ # API Integration
331
class OnlineLearningAPI:
    """Thin wrapper exposing an OnlineLearningEngine through list/dict-based calls."""

    def __init__(self, engine: OnlineLearningEngine):
        """Store the engine that all calls are forwarded to."""
        self.engine = engine

    def record_feedback(
        self,
        user_id: str,
        state: List[float],
        action: int,
        quality: int,  # 1-5 quality rating
        comment: Optional[str] = None
    ) -> Dict:
        """
        Record user feedback and trigger online update.

        Quality mapping:
        - 1: Very unhelpful (-1.0)
        - 2: Unhelpful (-0.5)
        - 3: Neutral (0.0)
        - 4: Helpful (0.5)
        - 5: Very helpful (1.0)

        Any other `quality` value is treated as neutral (0.0).
        `comment` is accepted for API compatibility but is currently not
        stored or used.

        Returns:
            Dict with 'status', the mapped 'reward' and the engine's
            'total_updates' after the interaction was added.
        """
        reward_map = {1: -1.0, 2: -0.5, 3: 0.0, 4: 0.5, 5: 1.0}
        reward = reward_map.get(quality, 0.0)

        state_arr = np.asarray(state, dtype=float)

        # Simulate next state (in real impl, would come from actual interaction)
        next_state = state_arr + np.random.randn(len(state_arr)) * 0.1

        self.engine.add_interaction(
            state=state_arr,
            action=action,
            reward=reward,
            next_state=next_state,
            done=False,
            user_id=user_id,
            # The mapped reward doubles as the sample's confidence value.
            confidence=reward
        )

        return {
            'status': 'recorded',
            'reward': reward,
            'total_updates': self.engine.total_updates
        }

    def get_prediction(self, state: List[float]) -> Dict:
        """
        Return the greedy action, its Q-value ('confidence') and all Q-values.

        The network is evaluated once and the action is derived from that
        single result, so the reported action always matches the reported
        Q-values even if an update lands between two calls.
        """
        state_arr = np.asarray(state, dtype=float)
        q_values = self.engine.get_q_values(state_arr)

        return {
            'action': int(np.argmax(q_values)),
            'confidence': float(np.max(q_values)),
            'q_values': q_values.tolist()
        }

    def get_stats(self) -> Dict:
        """Get learning stats"""
        return self.engine.get_stats()
394
+
395
+
396
+ # Example usage
397
if __name__ == "__main__":
    # Demo: feed random feedback through the API and print the resulting stats.
    engine = OnlineLearningEngine()
    api = OnlineLearningAPI(engine)

    print("Online Learning Engine initialized")
    print(f"State dim: {engine.state_dim}, Action dim: {engine.action_dim}")

    # Simulate some feedback, sized from the engine instead of hard-coded 64/10.
    num_interactions = 100
    result = None
    for _ in range(num_interactions):
        state = np.random.randn(engine.state_dim)
        action = int(np.random.randint(0, engine.action_dim))
        quality = int(np.random.randint(1, 6))

        result = api.record_feedback(
            user_id='test_user',
            state=state.tolist(),
            action=action,
            quality=quality
        )

    # "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n.
    print(f"\nAfter {num_interactions} interactions:")
    print(f"  Updates: {result['total_updates']}")
    print(f"  Stats: {api.get_stats()}")