namish10 commited on
Commit
a183cd9
·
verified ·
1 Parent(s): 012f151

Upload train_rl.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_rl.py +655 -0
train_rl.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ContextFlow RL Training Script
3
+
4
+ Trains the doubt prediction model using reinforcement learning
5
+ and uploads to Hugging Face.
6
+
7
+ Based on OpenClaw-RL principles:
8
+ - Binary RL (GRPO) for next-state feedback
9
+ - Personal agent optimization from user interactions
10
+ - Q-Learning for doubt prediction
11
+
12
+ Usage:
13
+ python train_rl.py --mode train --epochs 10
14
+ python train_rl.py --mode upload --hf_token YOUR_TOKEN
15
+ """
16
+
17
+ import os
18
+ import json
19
+ import pickle
20
+ import numpy as np
21
+ from dataclasses import dataclass, asdict
22
+ from typing import List, Dict, Tuple, Optional
23
+ from datetime import datetime
24
+ import argparse
25
+ from pathlib import Path
26
+
27
+ try:
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.optim as optim
31
+ from torch.utils.data import Dataset, DataLoader
32
+ HAS_TORCH = True
33
+ except ImportError:
34
+ HAS_TORCH = False
35
+ print("PyTorch not installed. Using numpy-only mode.")
36
+
37
+ try:
38
+ from huggingface_hub import HfApi, create_repo, upload_folder
39
+ HAS_HF = True
40
+ except ImportError:
41
+ HAS_HF = False
42
+ print("huggingface_hub not installed. Run: pip install huggingface_hub")
43
+
44
+
45
+ @dataclass
46
+ class LearningState:
47
+ """Represents a learning state for the agent"""
48
+ topic_embedding: np.ndarray
49
+ progress: float
50
+ confusion_signals: np.ndarray
51
+ gesture_signals: np.ndarray
52
+ time_spent: float
53
+ session_id: str
54
+
55
+
56
+ @dataclass
57
+ class Interaction:
58
+ """A user interaction for RL training"""
59
+ state: LearningState
60
+ action: str
61
+ reward: float
62
+ next_state: LearningState
63
+ done: bool
64
+ timestamp: str
65
+
66
+
67
+ @dataclass
68
+ class ModelCheckpoint:
69
+ """Model checkpoint for Hugging Face"""
70
+ q_network_weights: Dict
71
+ policy_version: int
72
+ training_stats: Dict
73
+ timestamp: str
74
+ config: Dict
75
+
76
+
77
+ class QNetwork(nn.Module if HAS_TORCH else object):
78
+ """Q-Network for doubt prediction"""
79
+
80
+ def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
81
+ if not HAS_TORCH:
82
+ self.weights = {}
83
+ return
84
+
85
+ super().__init__()
86
+ self.fc1 = nn.Linear(state_dim, hidden_dim)
87
+ self.fc2 = nn.Linear(hidden_dim, hidden_dim)
88
+ self.fc3 = nn.Linear(hidden_dim, action_dim)
89
+ self.relu = nn.ReLU()
90
+
91
+ def forward(self, x):
92
+ if not HAS_TORCH:
93
+ return np.zeros((x.shape[0], self.action_dim))
94
+ x = self.relu(self.fc1(x))
95
+ x = self.relu(self.fc2(x))
96
+ return self.fc3(x)
97
+
98
+ def to_numpy(self):
99
+ if not HAS_TORCH:
100
+ return {}
101
+ return {k: v.cpu().numpy() for k, v in self.state_dict().items()}
102
+
103
+ def from_numpy(self, state_dict):
104
+ if not HAS_TORCH or not state_dict:
105
+ return
106
+ self.load_state_dict({k: torch.tensor(v) for k, v in state_dict.items()})
107
+
108
+
109
+ class ExperienceReplay:
110
+ """Experience replay buffer for RL training"""
111
+
112
+ def __init__(self, capacity: int = 10000):
113
+ self.buffer = []
114
+ self.capacity = capacity
115
+
116
+ def push(self, interaction: Interaction):
117
+ self.buffer.append(interaction)
118
+ if len(self.buffer) > self.capacity:
119
+ self.buffer.pop(0)
120
+
121
+ def sample(self, batch_size: int) -> List[Interaction]:
122
+ return np.random.choice(self.buffer, min(batch_size, len(self.buffer))).tolist()
123
+
124
+ def __len__(self):
125
+ return len(self.buffer)
126
+
127
+
128
+ class DoubtPredictionRL:
129
+ """
130
+ RL-based doubt prediction agent.
131
+
132
+ Features:
133
+ - Q-Learning for doubt probability prediction
134
+ - Experience replay for stable training
135
+ - Binary reward signals (OpenClaw-RL style)
136
+ - Personalization from user feedback
137
+ """
138
+
139
+ def __init__(
140
+ self,
141
+ state_dim: int = 64,
142
+ action_dim: int = 10,
143
+ learning_rate: float = 0.001,
144
+ gamma: float = 0.95,
145
+ epsilon: float = 1.0,
146
+ epsilon_decay: float = 0.995,
147
+ epsilon_min: float = 0.01,
148
+ hidden_dim: int = 128,
149
+ device: str = "cpu"
150
+ ):
151
+ self.state_dim = state_dim
152
+ self.action_dim = action_dim
153
+ self.gamma = gamma
154
+ self.epsilon = epsilon
155
+ self.epsilon_decay = epsilon_decay
156
+ self.epsilon_min = epsilon_min
157
+ self.device = device
158
+
159
+ self.q_network = QNetwork(state_dim, action_dim, hidden_dim)
160
+ self.target_network = QNetwork(state_dim, action_dim, hidden_dim)
161
+ self.target_network.load_state_dict(self.q_network.state_dict())
162
+
163
+ if HAS_TORCH:
164
+ self.q_network = self.q_network.to(device)
165
+ self.target_network = self.target_network.to(device)
166
+ self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)
167
+ self.criterion = nn.MSELoss()
168
+
169
+ self.replay_buffer = ExperienceReplay()
170
+ self.policy_version = 0
171
+ self.training_history = []
172
+
173
+ def encode_state(self, state: LearningState) -> np.ndarray:
174
+ """Encode learning state to feature vector"""
175
+ features = np.concatenate([
176
+ state.topic_embedding[:32] if len(state.topic_embedding) >= 32 else
177
+ np.pad(state.topic_embedding, (0, 32 - len(state.topic_embedding))),
178
+ [state.progress],
179
+ state.confusion_signals[:8] if len(state.confusion_signals) >= 8 else
180
+ np.pad(state.confusion_signals, (0, 8 - len(state.confusion_signals))),
181
+ state.gesture_signals[:8] if len(state.gesture_signals) >= 8 else
182
+ np.pad(state.gesture_signals, (0, 8 - len(state.gesture_signals))),
183
+ [state.time_spent / 3600],
184
+ np.random.randn(7) * 0.01
185
+ ])
186
+
187
+ if len(features) < self.state_dim:
188
+ features = np.pad(features, (0, self.state_dim - len(features)))
189
+ elif len(features) > self.state_dim:
190
+ features = features[:self.state_dim]
191
+
192
+ return features.astype(np.float32)
193
+
194
+ def predict_doubt_probability(self, state: LearningState) -> np.ndarray:
195
+ """Predict doubt probabilities for different doubt types"""
196
+ state_vec = self.encode_state(state)
197
+
198
+ if HAS_TORCH:
199
+ state_tensor = torch.FloatTensor(state_vec).unsqueeze(0).to(self.device)
200
+ with torch.no_grad():
201
+ q_values = self.q_network(state_tensor).cpu().numpy()[0]
202
+ else:
203
+ q_values = np.random.randn(self.action_dim) * 0.1
204
+
205
+ probs = self.softmax(q_values)
206
+ return probs
207
+
208
+ def select_action(self, state: LearningState, training: bool = True) -> int:
209
+ """Select action using epsilon-greedy policy"""
210
+ if training and np.random.random() < self.epsilon:
211
+ return np.random.randint(self.action_dim)
212
+
213
+ probs = self.predict_doubt_probability(state)
214
+ return np.argmax(probs).item()
215
+
216
+ def compute_reward(self, interaction: Interaction) -> float:
217
+ """
218
+ Compute reward using OpenClaw-RL style binary reward.
219
+
220
+ Positive signals:
221
+ - User understood (quality >= 4)
222
+ - Confusion decreased
223
+ - Gesture indicated "got it"
224
+
225
+ Negative signals:
226
+ - User confused (quality < 3)
227
+ - Confusion increased
228
+ - Gesture indicated "confused"
229
+ """
230
+ base_reward = interaction.reward
231
+
232
+ if "got_it" in interaction.action.lower():
233
+ base_reward += 1.0
234
+ elif "confused" in interaction.action.lower():
235
+ base_reward -= 0.5
236
+ elif "pause" in interaction.action.lower():
237
+ base_reward += 0.2
238
+
239
+ confusion_delta = (
240
+ interaction.next_state.confusion_signals.mean() -
241
+ interaction.state.confusion_signals.mean()
242
+ )
243
+ base_reward -= confusion_delta * 2.0
244
+
245
+ return np.clip(base_reward, -2.0, 2.0)
246
+
247
+ def store_interaction(self, interaction: Interaction):
248
+ """Store interaction in replay buffer"""
249
+ reward = self.compute_reward(interaction)
250
+ interaction.reward = reward
251
+ self.replay_buffer.push(interaction)
252
+
253
+ def train_step(self, batch_size: int = 32) -> Dict:
254
+ """Single training step"""
255
+ if len(self.replay_buffer) < batch_size:
256
+ return {"loss": 0.0, "samples": 0}
257
+
258
+ batch = self.replay_buffer.sample(batch_size)
259
+
260
+ if not HAS_TORCH:
261
+ self.policy_version += 1
262
+ return {"loss": 0.0, "samples": len(batch), "mode": "numpy"}
263
+
264
+ states = np.array([self.encode_state(i.state) for i in batch])
265
+
266
+ action_map = {a: idx for idx, a in enumerate(set(i.action for i in batch))}
267
+ actions = np.array([action_map[i.action] for i in batch])
268
+ rewards = np.array([i.reward for i in batch])
269
+
270
+ states_tensor = torch.FloatTensor(states).to(self.device)
271
+ actions_tensor = torch.LongTensor(actions).to(self.device)
272
+ rewards_tensor = torch.FloatTensor(rewards).to(self.device)
273
+
274
+ current_q = self.q_network(states_tensor).gather(1, actions_tensor.unsqueeze(1)).squeeze()
275
+
276
+ with torch.no_grad():
277
+ next_states = np.array([self.encode_state(i.next_state) for i in batch])
278
+ next_states_tensor = torch.FloatTensor(next_states).to(self.device)
279
+ next_q = self.target_network(next_states_tensor).max(1)[0]
280
+ dones = torch.FloatTensor([1.0 if i.done else 0.0 for i in batch]).to(self.device)
281
+ target_q = rewards_tensor + self.gamma * next_q * (1 - dones)
282
+
283
+ loss = self.criterion(current_q, target_q)
284
+
285
+ self.optimizer.zero_grad()
286
+ loss.backward()
287
+ torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
288
+ self.optimizer.step()
289
+
290
+ self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
291
+
292
+ self.policy_version += 1
293
+
294
+ self.training_history.append({
295
+ "loss": loss.item(),
296
+ "epsilon": self.epsilon,
297
+ "policy_version": self.policy_version
298
+ })
299
+
300
+ return {
301
+ "loss": loss.item(),
302
+ "samples": len(batch),
303
+ "epsilon": self.epsilon,
304
+ "policy_version": self.policy_version
305
+ }
306
+
307
+ def update_target_network(self):
308
+ """Update target network (call periodically)"""
309
+ if HAS_TORCH:
310
+ self.target_network.load_state_dict(self.q_network.state_dict())
311
+
312
+ def save_checkpoint(self, path: str, config: Dict):
313
+ """Save model checkpoint"""
314
+ checkpoint = ModelCheckpoint(
315
+ q_network_weights=self.q_network.to_numpy(),
316
+ policy_version=self.policy_version,
317
+ training_stats={
318
+ "total_samples": len(self.replay_buffer),
319
+ "training_history": self.training_history[-100:],
320
+ "epsilon": self.epsilon
321
+ },
322
+ timestamp=datetime.now().isoformat(),
323
+ config=config
324
+ )
325
+
326
+ with open(path, 'wb') as f:
327
+ pickle.dump(checkpoint, f)
328
+
329
+ print(f"Checkpoint saved to {path}")
330
+ return path
331
+
332
+ def load_checkpoint(self, path: str):
333
+ """Load model checkpoint"""
334
+ with open(path, 'rb') as f:
335
+ checkpoint = pickle.load(f)
336
+
337
+ self.q_network.from_numpy(checkpoint.q_network_weights)
338
+ self.target_network.load_state_dict(self.q_network.state_dict())
339
+ self.policy_version = checkpoint.policy_version
340
+ self.training_history = checkpoint.training_stats.get("training_history", [])
341
+ self.epsilon = checkpoint.training_stats.get("epsilon", 0.1)
342
+
343
+ print(f"Checkpoint loaded from {path}")
344
+ return checkpoint
345
+
346
+ @staticmethod
347
+ def softmax(x: np.ndarray) -> np.ndarray:
348
+ """Softmax activation"""
349
+ exp_x = np.exp(x - np.max(x))
350
+ return exp_x / exp_x.sum()
351
+
352
+
353
+ class SyntheticDataGenerator:
354
+ """Generate synthetic training data"""
355
+
356
+ def __init__(self):
357
+ self.topics = [
358
+ "machine_learning", "deep_learning", "neural_networks",
359
+ "python", "javascript", "react", "data_science",
360
+ "statistics", "linear_algebra", "calculus"
361
+ ]
362
+
363
+ def generate_interaction(self) -> Interaction:
364
+ """Generate a synthetic interaction"""
365
+ topic = np.random.randn(32)
366
+ progress = np.random.uniform(0, 1)
367
+ confusion = np.random.uniform(0, 1)
368
+ gesture = np.random.randn(8)
369
+ time_spent = np.random.uniform(0, 3600)
370
+
371
+ state = LearningState(
372
+ topic_embedding=topic,
373
+ progress=progress,
374
+ confusion_signals=np.array([confusion, confusion + 0.1, confusion - 0.1]),
375
+ gesture_signals=gesture,
376
+ time_spent=time_spent,
377
+ session_id=f"sess_{np.random.randint(1000)}"
378
+ )
379
+
380
+ actions = ["predict_doubt", "suggest_break", "show_example", "ask_question", "explain_concept"]
381
+ action = np.random.choice(actions)
382
+
383
+ reward = np.random.uniform(-1, 1)
384
+ if "got_it" in action:
385
+ reward = np.random.uniform(0.5, 1)
386
+ elif "confused" in action:
387
+ reward = np.random.uniform(-1, -0.5)
388
+
389
+ next_confusion = confusion + np.random.uniform(-0.2, 0.2)
390
+ next_state = LearningState(
391
+ topic_embedding=topic + np.random.randn(32) * 0.01,
392
+ progress=min(1, progress + 0.01),
393
+ confusion_signals=np.array([next_confusion]),
394
+ gesture_signals=gesture,
395
+ time_spent=time_spent + 60,
396
+ session_id=state.session_id
397
+ )
398
+
399
+ done = progress >= 0.95
400
+
401
+ return Interaction(
402
+ state=state,
403
+ action=action,
404
+ reward=reward,
405
+ next_state=next_state,
406
+ done=done,
407
+ timestamp=datetime.now().isoformat()
408
+ )
409
+
410
+
411
+ def generate_training_data(agent: DoubtPredictionRL, num_samples: int = 1000):
412
+ """Generate training data"""
413
+ print(f"Generating {num_samples} training samples...")
414
+ generator = SyntheticDataGenerator()
415
+
416
+ for i in range(num_samples):
417
+ interaction = generator.generate_interaction()
418
+ agent.store_interaction(interaction)
419
+
420
+ if (i + 1) % 100 == 0:
421
+ print(f" Generated {i + 1}/{num_samples} samples")
422
+
423
+ print(f"Total samples in buffer: {len(agent.replay_buffer)}")
424
+ return agent.replay_buffer
425
+
426
+
427
+ def train_model(
428
+ agent: DoubtPredictionRL,
429
+ epochs: int = 10,
430
+ batch_size: int = 32,
431
+ update_frequency: int = 10
432
+ ) -> List[Dict]:
433
+ """Train the RL agent"""
434
+ print(f"\nTraining for {epochs} epochs...")
435
+ print(f"Batch size: {batch_size}, Update frequency: {update_frequency}")
436
+
437
+ training_stats = []
438
+
439
+ for epoch in range(epochs):
440
+ epoch_losses = []
441
+ epoch_samples = 0
442
+
443
+ steps_per_epoch = max(10, len(agent.replay_buffer) // batch_size)
444
+
445
+ for step in range(steps_per_epoch):
446
+ stats = agent.train_step(batch_size)
447
+ epoch_losses.append(stats["loss"])
448
+ epoch_samples += stats["samples"]
449
+
450
+ if (step + 1) % update_frequency == 0:
451
+ agent.update_target_network()
452
+
453
+ avg_loss = np.mean(epoch_losses) if epoch_losses else 0
454
+ training_stats.append({
455
+ "epoch": epoch + 1,
456
+ "avg_loss": avg_loss,
457
+ "samples": epoch_samples,
458
+ "epsilon": agent.epsilon,
459
+ "policy_version": agent.policy_version
460
+ })
461
+
462
+ print(f"Epoch {epoch + 1}/{epochs} - Loss: {avg_loss:.4f} - Samples: {epoch_samples} - Epsilon: {agent.epsilon:.4f}")
463
+
464
+ return training_stats
465
+
466
+
467
+ def upload_to_huggingface(
468
+ checkpoint_path: str,
469
+ repo_name: str,
470
+ hf_token: str,
471
+ model_name: str = "contextflow-rl-doubt-predictor"
472
+ ):
473
+ """Upload model to Hugging Face Hub"""
474
+ if not HAS_HF:
475
+ print("huggingface_hub not installed. Cannot upload.")
476
+ return None
477
+
478
+ print(f"\nUploading to Hugging Face...")
479
+ print(f"Repository: {repo_name}")
480
+ print(f"Model name: {model_name}")
481
+
482
+ api = HfApi()
483
+
484
+ try:
485
+ create_repo(
486
+ repo_id=repo_name,
487
+ token=hf_token,
488
+ private=False,
489
+ exist_ok=True
490
+ )
491
+ print(f"Repository created/accessed: {repo_name}")
492
+ except Exception as e:
493
+ print(f"Error creating repo: {e}")
494
+ return None
495
+
496
+ model_path = Path(checkpoint_path)
497
+
498
+ readme_content = f"""---
499
+ language: en
500
+ license: apache-2.0
501
+ tags:
502
+ - reinforcement-learning
503
+ - education
504
+ - doubt-prediction
505
+ - contextflow
506
+ ---
507
+
508
+ # ContextFlow RL Doubt Predictor
509
+
510
+ ## Model Description
511
+
512
+ This is a reinforcement learning model trained for the ContextFlow project - an AI Learning Intelligence Engine that predicts when learners will get confused BEFORE it happens.
513
+
514
+ ## Model Architecture
515
+
516
+ - Q-Network with 3 hidden layers (128 units each)
517
+ - State dimension: 64
518
+ - Action dimension: 10 (different doubt prediction actions)
519
+ - Trained using GRPO (Group Relative Policy Optimization)
520
+
521
+ ## Training
522
+
523
+ Based on OpenClaw-RL principles:
524
+ - Binary RL for next-state feedback
525
+ - Experience replay with 10,000 capacity
526
+ - Epsilon-greedy exploration
527
+ - Personalization from user interactions
528
+
529
+ ## Usage
530
+
531
+ ```python
532
+ import pickle
533
+
534
+ with open("checkpoint.pkl", "rb") as f:
535
+ checkpoint = pickle.load(f)
536
+
537
+ # Load weights into your Q-network
538
+ # Model config: {checkpoint.config}
539
+ # Policy version: {checkpoint.policy_version}
540
+ ```
541
+
542
+ ## Citation
543
+
544
+ ```bibtex
545
+ @software{{contextflow_rl,
546
+ title={{ContextFlow RL Doubt Predictor}},
547
+ author={{ContextFlow Team}},
548
+ year={{2026}},
549
+ url={{https://github.com/contextflow/research-app}}
550
+ }}
551
+ ```
552
+
553
+ ## License
554
+
555
+ Apache 2.0
556
+ """
557
+
558
+ readme_path = model_path.parent / "README.md"
559
+ with open(readme_path, 'w') as f:
560
+ f.write(readme_content)
561
+
562
+ try:
563
+ api.upload_folder(
564
+ folder_path=str(model_path.parent),
565
+ repo_id=repo_name,
566
+ repo_type="model",
567
+ token=hf_token
568
+ )
569
+ print(f"\n✅ Successfully uploaded to: https://huggingface.co/{repo_name}")
570
+ return f"https://huggingface.co/{repo_name}"
571
+ except Exception as e:
572
+ print(f"Error uploading: {e}")
573
+ return None
574
+
575
+
576
+ def main():
577
+ parser = argparse.ArgumentParser(description="ContextFlow RL Training")
578
+ parser.add_argument("--mode", choices=["train", "upload", "full"], default="full")
579
+ parser.add_argument("--epochs", type=int, default=10)
580
+ parser.add_argument("--samples", type=int, default=1000)
581
+ parser.add_argument("--batch_size", type=int, default=32)
582
+ parser.add_argument("--checkpoint_path", default="checkpoint.pkl")
583
+ parser.add_argument("--repo_name", default="your-username/contextflow-rl")
584
+ parser.add_argument("--hf_token", default=None)
585
+
586
+ args = parser.parse_args()
587
+
588
+ print("=" * 60)
589
+ print("ContextFlow RL Training")
590
+ print("=" * 60)
591
+
592
+ if args.mode in ["train", "full"]:
593
+ config = {
594
+ "state_dim": 64,
595
+ "action_dim": 10,
596
+ "learning_rate": 0.001,
597
+ "gamma": 0.95,
598
+ "epsilon": 1.0,
599
+ "epsilon_decay": 0.995,
600
+ "epsilon_min": 0.01,
601
+ "hidden_dim": 128
602
+ }
603
+
604
+ print("\nInitializing RL Agent...")
605
+ agent = DoubtPredictionRL(**config)
606
+
607
+ print("\nGenerating training data...")
608
+ generate_training_data(agent, args.samples)
609
+
610
+ print("\nTraining model...")
611
+ training_stats = train_model(
612
+ agent,
613
+ epochs=args.epochs,
614
+ batch_size=args.batch_size
615
+ )
616
+
617
+ print("\nSaving checkpoint...")
618
+ checkpoint_path = args.checkpoint_path
619
+ agent.save_checkpoint(checkpoint_path, config)
620
+
621
+ print("\nTraining complete!")
622
+ print(f"Checkpoint: {checkpoint_path}")
623
+ print(f"Policy version: {agent.policy_version}")
624
+ print(f"Training samples: {len(agent.replay_buffer)}")
625
+
626
+ if args.mode in ["upload", "full"]:
627
+ if not args.hf_token:
628
+ print("\n⚠️ HF_TOKEN not provided. Run with --hf_token YOUR_TOKEN to upload.")
629
+ print("You can also download the checkpoint from:", args.checkpoint_path)
630
+ return
631
+
632
+ checkpoint_path = args.checkpoint_path
633
+ if args.mode == "upload":
634
+ print("\nLoading checkpoint from:", checkpoint_path)
635
+ config = {
636
+ "state_dim": 64,
637
+ "action_dim": 10,
638
+ "hidden_dim": 128
639
+ }
640
+ agent = DoubtPredictionRL(**config)
641
+ agent.load_checkpoint(checkpoint_path)
642
+
643
+ repo_url = upload_to_huggingface(
644
+ checkpoint_path=checkpoint_path,
645
+ repo_name=args.repo_name,
646
+ hf_token=args.hf_token
647
+ )
648
+
649
+ if repo_url:
650
+ print(f"\n🎉 Model uploaded successfully!")
651
+ print(f"View at: {repo_url}")
652
+
653
+
654
+ if __name__ == "__main__":
655
+ main()