inv0krr commited on
Commit
9c21ddc
·
verified ·
1 Parent(s): 52bd035

Add 3-phase training pipeline, data gen, evaluation

Browse files
Files changed (1) hide show
  1. leworld_training.py +820 -0
leworld_training.py ADDED
@@ -0,0 +1,820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LeWorld Training System
3
+ =======================
4
+ 3-Phase training procedure:
5
+ Phase 1: Pre-train components separately
6
+ Phase 2: End-to-end joint training
7
+ Phase 3: Cooperative refinement with info-request loop
8
+
9
+ Plus: Memory population strategies, data generation, evaluation.
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.optim as optim
16
+ from torch.utils.data import Dataset, DataLoader
17
+ import math
18
+ import random
19
+ from typing import Dict, List, Optional, Tuple
20
+ from dataclasses import dataclass
21
+
22
+ from leworld_architecture import (
23
+ LeWorldSystem, MemoryConfig, SLMConfig, BLMConfig,
24
+ ArtificialMemory, SmallLeWorldModel, BigLeWorldModel,
25
+ count_params
26
+ )
27
+
28
+
29
+ # =============================================================================
30
+ # Training Configuration
31
+ # =============================================================================
32
+
33
+ @dataclass
34
+ class TrainingConfig:
35
+ """Full training configuration."""
36
+ # Phase 1: Pre-training
37
+ phase1_lr: float = 1e-3
38
+ phase1_epochs: int = 50
39
+ phase1_batch_size: int = 32
40
+
41
+ # Phase 2: Joint training
42
+ phase2_lr: float = 3e-4
43
+ phase2_epochs: int = 100
44
+ phase2_batch_size: int = 16
45
+ phase2_warmup_steps: int = 500
46
+
47
+ # Phase 3: Refinement
48
+ phase3_lr: float = 1e-4
49
+ phase3_epochs: int = 50
50
+ phase3_batch_size: int = 16
51
+
52
+ # General
53
+ weight_decay: float = 0.01
54
+ grad_clip: float = 1.0
55
+ state_dim: int = 64
56
+ char_dim: int = 32
57
+ sequence_length: int = 20 # timesteps per sequence
58
+
59
+ # Loss weights
60
+ lambda_balance: float = 0.01 # routing balance
61
+ lambda_diversity: float = 0.001 # address diversity
62
+ lambda_entropy: float = 0.01 # routing entropy
63
+ lambda_info_util: float = 0.1 # info request utility
64
+
65
+ # Temperature annealing
66
+ temp_anneal_rate: float = 3e-5
67
+ temp_min: float = 0.1
68
+
69
+
70
+ # =============================================================================
71
+ # Synthetic Data Generation
72
+ # =============================================================================
73
+
74
+ class StateTransitionDataset(Dataset):
75
+ """
76
+ Generates synthetic state transition sequences for training.
77
+
78
+ Each sequence has:
79
+ - States that evolve according to learnable dynamics
80
+ - Characteristics that stay fixed per sequence
81
+ - Ground-truth "useful memory" labels (for Phase 1 SLM pre-training)
82
+
83
+ The key insight: we embed patterns into memory, and the state transitions
84
+ DEPEND on what's in specific memory regions. This creates a genuine need
85
+ for memory retrieval — the model can't predict next state without reading
86
+ the right memory.
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ num_sequences: int,
92
+ seq_length: int,
93
+ state_dim: int,
94
+ char_dim: int,
95
+ memory: ArtificialMemory,
96
+ difficulty: str = "easy", # easy, medium, hard
97
+ ):
98
+ self.num_sequences = num_sequences
99
+ self.seq_length = seq_length
100
+ self.state_dim = state_dim
101
+ self.char_dim = char_dim
102
+ self.memory = memory
103
+
104
+ # Generate all sequences upfront
105
+ self.data = self._generate_sequences(difficulty)
106
+
107
+ def _generate_sequences(self, difficulty: str) -> List[Dict]:
108
+ """Generate synthetic state-transition sequences."""
109
+ data = []
110
+ mem_size = self.memory.config.num_words
111
+
112
+ for _ in range(self.num_sequences):
113
+ # Static characteristics for this sequence
114
+ characteristics = torch.randn(self.char_dim)
115
+
116
+ # Choose "relevant" memory regions (ground truth for SLM training)
117
+ if difficulty == "easy":
118
+ n_relevant = 1 # only one memory region matters
119
+ elif difficulty == "medium":
120
+ n_relevant = 2
121
+ else:
122
+ n_relevant = 3
123
+
124
+ relevant_addrs = []
125
+ for _ in range(n_relevant):
126
+ start = random.randint(0, mem_size - 256)
127
+ length = random.randint(16, 128)
128
+ relevant_addrs.append((start, start + length))
129
+
130
+ # Generate state sequence where transitions depend on memory content
131
+ states = torch.zeros(self.seq_length, self.state_dim)
132
+ states[0] = torch.randn(self.state_dim)
133
+
134
+ # The transition rule: next_state = f(current_state, memory_content)
135
+ # We use a simple linear rule seeded by the memory content
136
+ with torch.no_grad():
137
+ for addr_start, addr_end in relevant_addrs:
138
+ mem_bits = self.memory.memory[addr_start:addr_end].mean(dim=0)
139
+ # Memory content influences the transition dynamics
140
+ # Pad/tile mem_bits to state_dim
141
+ transition_seed_raw = mem_bits * 2 - 1 # map 0,1 → -1,1
142
+ transition_seed = transition_seed_raw.repeat(
143
+ math.ceil(self.state_dim / len(transition_seed_raw))
144
+ )[:self.state_dim]
145
+
146
+ # Pad/tile characteristics to state_dim
147
+ char_padded = characteristics.repeat(
148
+ math.ceil(self.state_dim / len(characteristics))
149
+ )[:self.state_dim]
150
+
151
+ for t in range(1, self.seq_length):
152
+ noise = torch.randn(self.state_dim) * 0.1
153
+ # State evolves based on current state + memory influence
154
+ states[t] = (
155
+ 0.8 * states[t-1]
156
+ + 0.15 * transition_seed
157
+ + 0.05 * char_padded
158
+ + noise
159
+ )
160
+
161
+ data.append({
162
+ 'states': states, # (seq_length, state_dim)
163
+ 'characteristics': characteristics, # (char_dim,)
164
+ 'relevant_addrs': relevant_addrs, # list of (start, end) tuples
165
+ 'n_relevant': n_relevant,
166
+ })
167
+
168
+ return data
169
+
170
+ def __len__(self):
171
+ return self.num_sequences
172
+
173
+ def __getitem__(self, idx):
174
+ item = self.data[idx]
175
+
176
+ # Pad relevant addresses to fixed length (3 = max n_slms)
177
+ padded_starts = torch.zeros(3, dtype=torch.long)
178
+ padded_ends = torch.zeros(3, dtype=torch.long)
179
+ for i, (s, e) in enumerate(item['relevant_addrs']):
180
+ padded_starts[i] = s
181
+ padded_ends[i] = e
182
+
183
+ return {
184
+ 'states': item['states'],
185
+ 'characteristics': item['characteristics'],
186
+ 'relevant_starts': padded_starts,
187
+ 'relevant_ends': padded_ends,
188
+ 'n_relevant': item['n_relevant'],
189
+ }
190
+
191
+
192
+ # =============================================================================
193
+ # Phase 1: Pre-training (Components Separately)
194
+ # =============================================================================
195
+
196
+ class Phase1Trainer:
197
+ """
198
+ Pre-train SLMs and BLM separately.
199
+
200
+ SLMs: Given (past_state, current_state, characteristics), learn to output
201
+ address ranges that point to "relevant" memory regions.
202
+ Loss: distance between predicted address range and ground-truth relevant region.
203
+
204
+ BLM: Given perfect memory reads, learn to predict next state.
205
+ Loss: MSE between predicted and actual next state.
206
+ """
207
+
208
+ def __init__(self, system: LeWorldSystem, config: TrainingConfig):
209
+ self.system = system
210
+ self.config = config
211
+
212
+ # Separate optimizers for SLMs and BLM
213
+ self.slm_optimizer = optim.AdamW(
214
+ system.slms.parameters(),
215
+ lr=config.phase1_lr,
216
+ weight_decay=config.weight_decay
217
+ )
218
+ self.blm_optimizer = optim.AdamW(
219
+ list(system.blm.parameters()) + list(system.memory.parameters()),
220
+ lr=config.phase1_lr,
221
+ weight_decay=config.weight_decay
222
+ )
223
+
224
+ def train_slms_step(self, batch: Dict) -> Dict:
225
+ """
226
+ Train SLMs to find relevant memory regions.
227
+
228
+ Loss: |predicted_addr - target_addr| normalized by address space.
229
+ """
230
+ self.slm_optimizer.zero_grad()
231
+
232
+ states = batch['states'] # (B, T, state_dim)
233
+ chars = batch['characteristics'] # (B, char_dim)
234
+ target_starts = batch['relevant_starts'] # (B, 3)
235
+ target_ends = batch['relevant_ends'] # (B, 3)
236
+
237
+ total_loss = None
238
+
239
+ # For each SLM, train to find the corresponding relevant region
240
+ for i, slm in enumerate(self.system.slms):
241
+ # Use first two timesteps as past/current
242
+ past_state = states[:, 0, :]
243
+ current_state = states[:, 1, :]
244
+
245
+ output = slm(past_state, current_state, chars)
246
+
247
+ # Use logits (differentiable) instead of hard addresses
248
+ # Target: which high/low byte corresponds to the target address
249
+ tgt_start = target_starts[:, i].long()
250
+
251
+ half_space = slm.address_head.half_space # 256
252
+ tgt_high = tgt_start // half_space # high byte
253
+ tgt_low = tgt_start % half_space # low byte
254
+
255
+ # Cross-entropy over address components (differentiable!)
256
+ addr_loss = (
257
+ F.cross_entropy(output['start_logits_high'], tgt_high) +
258
+ F.cross_entropy(output['start_logits_low'], tgt_low)
259
+ )
260
+
261
+ # Range length loss
262
+ tgt_range = (target_ends[:, i] - target_starts[:, i]).clamp(1, self.system.memory.config.max_read_range) - 1
263
+ range_loss = F.cross_entropy(output['range_logits'], tgt_range.long())
264
+
265
+ slm_loss = addr_loss + 0.5 * range_loss
266
+
267
+ if total_loss is None:
268
+ total_loss = slm_loss
269
+ else:
270
+ total_loss = total_loss + slm_loss
271
+
272
+ total_loss = total_loss / len(self.system.slms)
273
+ total_loss.backward()
274
+ torch.nn.utils.clip_grad_norm_(self.system.slms.parameters(), self.config.grad_clip)
275
+ self.slm_optimizer.step()
276
+
277
+ return {'slm_loss': total_loss.item()}
278
+
279
+ def train_blm_step(self, batch: Dict) -> Dict:
280
+ """
281
+ Train BLM to predict next state given oracle memory reads.
282
+
283
+ Oracle: we read from the KNOWN relevant memory regions (ground truth).
284
+ """
285
+ self.blm_optimizer.zero_grad()
286
+
287
+ states = batch['states']
288
+ chars = batch['characteristics']
289
+ target_starts = batch['relevant_starts']
290
+ target_ends = batch['relevant_ends']
291
+
292
+ batch_size = states.shape[0]
293
+
294
+ # Read oracle memory
295
+ oracle_reads = []
296
+ slm_fake_outputs = []
297
+ for i in range(3):
298
+ _, encoded, _ = self.system.memory.read(
299
+ target_starts[:, i], target_ends[:, i]
300
+ )
301
+ oracle_reads.append(encoded)
302
+ # Create fake SLM output (just need hidden state)
303
+ fake_hidden = torch.zeros(batch_size, 128) # SLM d_model = 128
304
+ slm_fake_outputs.append({
305
+ 'hidden': fake_hidden,
306
+ 'start_addr': target_starts[:, i],
307
+ 'end_addr': target_ends[:, i],
308
+ 'confidence': torch.ones(batch_size),
309
+ })
310
+
311
+ # BLM forward with oracle reads
312
+ total_loss = None
313
+ for t in range(states.shape[1] - 1):
314
+ past_state = states[:, max(0, t-1), :]
315
+ current_state = states[:, t, :]
316
+ next_state = states[:, t+1, :]
317
+
318
+ blm_out = self.system.blm(
319
+ past_state, current_state,
320
+ slm_fake_outputs, oracle_reads
321
+ )
322
+
323
+ loss = F.mse_loss(blm_out['next_state'], next_state)
324
+ if total_loss is None:
325
+ total_loss = loss
326
+ else:
327
+ total_loss = total_loss + loss
328
+
329
+ total_loss = total_loss / (states.shape[1] - 1)
330
+ total_loss.backward()
331
+ torch.nn.utils.clip_grad_norm_(
332
+ list(self.system.blm.parameters()) + list(self.system.memory.parameters()),
333
+ self.config.grad_clip
334
+ )
335
+ self.blm_optimizer.step()
336
+
337
+ return {'blm_loss': total_loss.item()}
338
+
339
+
340
+ # =============================================================================
341
+ # Phase 2: End-to-End Joint Training
342
+ # =============================================================================
343
+
344
+ class Phase2Trainer:
345
+ """
346
+ Joint training of the entire system end-to-end.
347
+
348
+ The full pipeline runs: SLMs → Memory Read → BLM → Next State
349
+
350
+ Key challenge: gradient flow through discrete decisions
351
+ - SLM address selection: use soft attention + hard address (ST trick)
352
+ - BLM routing: use straight-through sigmoid
353
+
354
+ Losses:
355
+ 1. next_state_loss: primary prediction accuracy
356
+ 2. balance_loss: balanced SLM routing
357
+ 3. diversity_loss: SLMs read different memory regions
358
+ 4. info_utility_loss: BLM's info request improves future predictions
359
+ """
360
+
361
+ def __init__(self, system: LeWorldSystem, config: TrainingConfig):
362
+ self.system = system
363
+ self.config = config
364
+
365
+ # Single optimizer for everything
366
+ self.optimizer = optim.AdamW(
367
+ system.parameters(),
368
+ lr=config.phase2_lr,
369
+ weight_decay=config.weight_decay
370
+ )
371
+
372
+ # Learning rate scheduler
373
+ self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
374
+ self.optimizer, T_0=config.phase2_epochs // 3, T_mult=2
375
+ )
376
+
377
+ self.global_step = 0
378
+
379
+ def train_step(self, batch: Dict) -> Dict:
380
+ """Full end-to-end training step."""
381
+ self.optimizer.zero_grad()
382
+
383
+ states = batch['states']
384
+ chars = batch['characteristics']
385
+
386
+ # Multi-step forward
387
+ output = self.system.multi_step_forward(states, chars)
388
+
389
+ loss = output['total_loss']
390
+ loss.backward()
391
+
392
+ # Gradient clipping
393
+ torch.nn.utils.clip_grad_norm_(
394
+ self.system.parameters(), self.config.grad_clip
395
+ )
396
+
397
+ self.optimizer.step()
398
+
399
+ # Temperature annealing for router
400
+ self.global_step += 1
401
+ self.system.blm.router.anneal_temperature(
402
+ self.global_step,
403
+ self.config.temp_anneal_rate,
404
+ self.config.temp_min
405
+ )
406
+
407
+ return {
408
+ 'total_loss': loss.item(),
409
+ 'temperature': self.system.blm.router.temperature.item(),
410
+ 'step': self.global_step,
411
+ }
412
+
413
+
414
+ # =============================================================================
415
+ # Phase 3: Cooperative Refinement with Info-Request Loop
416
+ # =============================================================================
417
+
418
+ class Phase3Trainer:
419
+ """
420
+ Refinement phase: train the info-request mechanism.
421
+
422
+ The BLM learns to generate useful "what info do I need?" queries that
423
+ improve the SLMs' memory retrieval in the NEXT timestep.
424
+
425
+ Training signal: compare prediction quality WITH vs WITHOUT info-request
426
+ modulation. If info-request helped → reward; if not → penalize.
427
+
428
+ This is inspired by ProactAgent (arxiv:2604.20572) paired-branch reward.
429
+ """
430
+
431
+ def __init__(self, system: LeWorldSystem, config: TrainingConfig):
432
+ self.system = system
433
+ self.config = config
434
+
435
+ # Optimizer: higher LR for info-request modules, lower for rest
436
+ info_params = set(id(p) for p in system.blm.info_request.parameters())
437
+ info_params.update(id(p) for p in system.info_to_slm.parameters())
438
+
439
+ other_blm_params = [p for p in system.blm.parameters() if id(p) not in info_params]
440
+
441
+ self.optimizer = optim.AdamW([
442
+ {'params': list(system.blm.info_request.parameters()) + list(system.info_to_slm.parameters()), 'lr': config.phase3_lr},
443
+ {'params': list(system.slms.parameters()), 'lr': config.phase3_lr * 0.1},
444
+ {'params': other_blm_params, 'lr': config.phase3_lr * 0.1},
445
+ {'params': list(system.memory.parameters()), 'lr': config.phase3_lr * 0.01},
446
+ ], weight_decay=config.weight_decay)
447
+
448
+ def train_step(self, batch: Dict) -> Dict:
449
+ """
450
+ Paired-branch training:
451
+ Branch A: Run with info-request modulation (full system)
452
+ Branch B: Run WITHOUT info-request (baseline)
453
+ Reward = improvement of A over B
454
+ """
455
+ self.optimizer.zero_grad()
456
+
457
+ states = batch['states']
458
+ chars = batch['characteristics']
459
+
460
+ # Branch A: with info-request loop
461
+ output_with = self.system.multi_step_forward(states, chars)
462
+ loss_with = output_with['total_loss']
463
+
464
+ # Branch B: without info-request (set info_query to None at each step)
465
+ # We do this by running forward without passing info_query between steps
466
+ batch_size, T, state_dim = states.shape
467
+ loss_without = None
468
+
469
+ for t in range(T - 1):
470
+ past_state = states[:, max(0, t-1), :]
471
+ current_state = states[:, t, :]
472
+ next_state = states[:, t+1, :]
473
+
474
+ output = self.system(
475
+ past_state, current_state, chars,
476
+ next_state, info_query_prev=None # NO info request
477
+ )
478
+ if output['losses']:
479
+ if loss_without is None:
480
+ loss_without = output['losses']['next_state_loss']
481
+ else:
482
+ loss_without = loss_without + output['losses']['next_state_loss']
483
+
484
+ if loss_without is None:
485
+ loss_without = torch.tensor(0.0)
486
+ else:
487
+ loss_without = loss_without / max(1, T - 1)
488
+
489
+ # Info utility: reward if info-request helps, penalize if not
490
+ improvement = (loss_without - loss_with).detach() # positive = info helped
491
+
492
+ # Total loss: prediction loss + info utility bonus
493
+ total_loss = loss_with - self.config.lambda_info_util * improvement
494
+
495
+ total_loss.backward()
496
+ torch.nn.utils.clip_grad_norm_(self.system.parameters(), self.config.grad_clip)
497
+ self.optimizer.step()
498
+
499
+ return {
500
+ 'loss_with_info': loss_with.item(),
501
+ 'loss_without_info': loss_without.item(),
502
+ 'improvement': improvement.item(),
503
+ 'total_loss': total_loss.item(),
504
+ }
505
+
506
+
507
+ # =============================================================================
508
+ # Memory Population Strategies
509
+ # =============================================================================
510
+
511
+ class MemoryPopulator:
512
+ """
513
+ Strategies for populating the artificial memory with meaningful content.
514
+
515
+ In a real application, memory would be populated by experience / observations.
516
+ Here we provide several strategies for initial content.
517
+ """
518
+
519
+ @staticmethod
520
+ def random_bits(memory: ArtificialMemory):
521
+ """Fill with random bits (baseline)."""
522
+ memory.memory.uniform_(0, 1).round_()
523
+
524
+ @staticmethod
525
+ def structured_patterns(memory: ArtificialMemory):
526
+ """
527
+ Fill with structured patterns that encode different "knowledge types."
528
+
529
+ Memory layout:
530
+ - [0x0000 - 0x3FFF]: Dynamics patterns (state transition rules)
531
+ - [0x4000 - 0x7FFF]: Context patterns (characteristic-dependent info)
532
+ - [0x8000 - 0xBFFF]: History patterns (temporal sequences)
533
+ - [0xC000 - 0xFFFF]: Association patterns (cross-references)
534
+ """
535
+ N = memory.config.num_words
536
+ W = memory.config.word_size
537
+ quarter = N // 4
538
+
539
+ with torch.no_grad():
540
+ # Region 1: Dynamics — repeating patterns (easy to learn)
541
+ for i in range(quarter):
542
+ pattern = torch.zeros(W)
543
+ pattern[i % W] = 1.0 # cyclic single-bit pattern
544
+ memory.memory[i] = pattern
545
+
546
+ # Region 2: Context — characteristic-dependent
547
+ for i in range(quarter, 2 * quarter):
548
+ seed = i - quarter
549
+ torch.manual_seed(seed)
550
+ memory.memory[i] = torch.randint(0, 2, (W,)).float()
551
+
552
+ # Region 3: History — sequential counting in binary
553
+ for i in range(2 * quarter, 3 * quarter):
554
+ binary = torch.zeros(W)
555
+ val = i - 2 * quarter
556
+ for bit in range(min(W, 16)):
557
+ binary[bit] = float((val >> bit) & 1)
558
+ memory.memory[i] = binary
559
+
560
+ # Region 4: Associations — XOR patterns
561
+ for i in range(3 * quarter, N):
562
+ a = memory.memory[i % quarter] # reference region 1
563
+ b = memory.memory[quarter + (i % quarter)] # reference region 2
564
+ memory.memory[i] = ((a + b) % 2) # XOR
565
+
566
+ @staticmethod
567
+ def from_experience(memory: ArtificialMemory, experiences: torch.Tensor):
568
+ """
569
+ Populate memory from observed data.
570
+
571
+ Args:
572
+ experiences: (N, feature_dim) tensor of observed features
573
+ Each feature vector gets encoded to bits and stored
574
+ """
575
+ with torch.no_grad():
576
+ N = min(experiences.shape[0], memory.config.num_words)
577
+ W = memory.config.word_size
578
+
579
+ # Simple quantization: threshold at median
580
+ for i in range(N):
581
+ feat = experiences[i]
582
+ # Truncate/pad to word_size
583
+ if len(feat) >= W:
584
+ bits = (feat[:W] > feat[:W].median()).float()
585
+ else:
586
+ bits = torch.zeros(W)
587
+ bits[:len(feat)] = (feat > feat.median()).float()
588
+ memory.memory[i] = bits
589
+
590
+
591
+ # =============================================================================
592
+ # Evaluation
593
+ # =============================================================================
594
+
595
+ class Evaluator:
596
+ """Evaluation metrics for the LeWorld system."""
597
+
598
+ @staticmethod
599
+ def prediction_accuracy(
600
+ system: LeWorldSystem,
601
+ dataloader: DataLoader,
602
+ n_steps: int = 5
603
+ ) -> Dict:
604
+ """
605
+ Evaluate next-state prediction accuracy.
606
+
607
+ Metrics:
608
+ - MSE: mean squared error of state predictions
609
+ - MAE: mean absolute error
610
+ - Multi-step MSE: prediction error at different horizons
611
+ - Routing diversity: how varied the SLM selections are
612
+ """
613
+ system.eval()
614
+ total_mse = 0.0
615
+ total_mae = 0.0
616
+ step_mses = [0.0] * n_steps
617
+ all_masks = []
618
+ n_batches = 0
619
+
620
+ with torch.no_grad():
621
+ for batch in dataloader:
622
+ states = batch['states']
623
+ chars = batch['characteristics']
624
+
625
+ output = system.multi_step_forward(states, chars, n_steps)
626
+
627
+ # Ground truth future states
628
+ gt_future = states[:, 1:n_steps+1, :]
629
+ pred_future = output['predictions'][:, :n_steps, :]
630
+
631
+ actual_steps = min(n_steps, pred_future.shape[1])
632
+
633
+ mse = F.mse_loss(pred_future[:, :actual_steps], gt_future[:, :actual_steps])
634
+ mae = F.l1_loss(pred_future[:, :actual_steps], gt_future[:, :actual_steps])
635
+
636
+ total_mse += mse.item()
637
+ total_mae += mae.item()
638
+
639
+ # Per-step MSE
640
+ for t in range(actual_steps):
641
+ step_mse = F.mse_loss(pred_future[:, t], gt_future[:, t])
642
+ step_mses[t] += step_mse.item()
643
+
644
+ # Collect routing masks
645
+ all_masks.append(output['masks'])
646
+ n_batches += 1
647
+
648
+ # Routing diversity: entropy of SLM usage
649
+ all_masks = torch.cat(all_masks, dim=0) # (total, T, n_slms)
650
+ usage_per_slm = all_masks.mean(dim=(0, 1)) # (n_slms,)
651
+ routing_entropy = -(usage_per_slm * torch.log(usage_per_slm + 1e-8)).sum().item()
652
+
653
+ system.train()
654
+
655
+ return {
656
+ 'mse': total_mse / max(1, n_batches),
657
+ 'mae': total_mae / max(1, n_batches),
658
+ 'step_mses': [m / max(1, n_batches) for m in step_mses],
659
+ 'routing_entropy': routing_entropy,
660
+ 'slm_usage': usage_per_slm.tolist(),
661
+ }
662
+
663
+
664
+ # =============================================================================
665
+ # Full Training Pipeline
666
+ # =============================================================================
667
+
668
+ def run_training(
669
+ system: LeWorldSystem,
670
+ train_config: TrainingConfig,
671
+ num_train_sequences: int = 1000,
672
+ num_val_sequences: int = 200,
673
+ ):
674
+ """Execute the full 3-phase training pipeline."""
675
+
676
+ print("=" * 70)
677
+ print("LeWorld Training Pipeline")
678
+ print("=" * 70)
679
+
680
+ # Populate memory with structured patterns
681
+ print("\n[Setup] Populating artificial memory...")
682
+ MemoryPopulator.structured_patterns(system.memory)
683
+
684
+ # Create datasets
685
+ print("[Setup] Generating training data...")
686
+ train_dataset = StateTransitionDataset(
687
+ num_sequences=num_train_sequences,
688
+ seq_length=train_config.sequence_length,
689
+ state_dim=train_config.state_dim,
690
+ char_dim=train_config.char_dim,
691
+ memory=system.memory,
692
+ difficulty="medium",
693
+ )
694
+
695
+ val_dataset = StateTransitionDataset(
696
+ num_sequences=num_val_sequences,
697
+ seq_length=train_config.sequence_length,
698
+ state_dim=train_config.state_dim,
699
+ char_dim=train_config.char_dim,
700
+ memory=system.memory,
701
+ difficulty="medium",
702
+ )
703
+
704
+ train_loader = DataLoader(train_dataset, batch_size=train_config.phase1_batch_size, shuffle=True)
705
+ val_loader = DataLoader(val_dataset, batch_size=train_config.phase1_batch_size)
706
+
707
+ evaluator = Evaluator()
708
+
709
+ # ===== Phase 1: Pre-training =====
710
+ print(f"\n{'='*70}")
711
+ print("Phase 1: Pre-training (SLMs + BLM separately)")
712
+ print(f"{'='*70}")
713
+
714
+ phase1 = Phase1Trainer(system, train_config)
715
+
716
+ for epoch in range(min(3, train_config.phase1_epochs)): # shortened for demo
717
+ epoch_slm_loss = 0
718
+ epoch_blm_loss = 0
719
+ n_batches = 0
720
+
721
+ for batch in train_loader:
722
+ slm_metrics = phase1.train_slms_step(batch)
723
+ blm_metrics = phase1.train_blm_step(batch)
724
+
725
+ epoch_slm_loss += slm_metrics['slm_loss']
726
+ epoch_blm_loss += blm_metrics['blm_loss']
727
+ n_batches += 1
728
+
729
+ print(f" Epoch {epoch+1}: SLM loss={epoch_slm_loss/n_batches:.4f}, "
730
+ f"BLM loss={epoch_blm_loss/n_batches:.4f}")
731
+
732
+ # Evaluate after Phase 1
733
+ val_metrics = evaluator.prediction_accuracy(system, val_loader, n_steps=5)
734
+ print(f" Phase 1 eval: MSE={val_metrics['mse']:.4f}, "
735
+ f"Routing entropy={val_metrics['routing_entropy']:.4f}")
736
+
737
+ # ===== Phase 2: Joint Training =====
738
+ print(f"\n{'='*70}")
739
+ print("Phase 2: End-to-End Joint Training")
740
+ print(f"{'='*70}")
741
+
742
+ phase2 = Phase2Trainer(system, train_config)
743
+ train_loader2 = DataLoader(train_dataset, batch_size=train_config.phase2_batch_size, shuffle=True)
744
+ val_loader2 = DataLoader(val_dataset, batch_size=train_config.phase2_batch_size)
745
+
746
+ for epoch in range(min(5, train_config.phase2_epochs)): # shortened for demo
747
+ epoch_loss = 0
748
+ n_batches = 0
749
+
750
+ for batch in train_loader2:
751
+ metrics = phase2.train_step(batch)
752
+ epoch_loss += metrics['total_loss']
753
+ n_batches += 1
754
+
755
+ print(f" Epoch {epoch+1}: loss={epoch_loss/n_batches:.4f}, "
756
+ f"temp={metrics['temperature']:.4f}")
757
+
758
+ val_metrics = evaluator.prediction_accuracy(system, val_loader2, n_steps=5)
759
+ print(f" Phase 2 eval: MSE={val_metrics['mse']:.4f}, "
760
+ f"Routing entropy={val_metrics['routing_entropy']:.4f}, "
761
+ f"SLM usage={[f'{u:.2f}' for u in val_metrics['slm_usage']]}")
762
+
763
+ # ===== Phase 3: Info-Request Refinement =====
764
+ print(f"\n{'='*70}")
765
+ print("Phase 3: Info-Request Cooperative Refinement")
766
+ print(f"{'='*70}")
767
+
768
+ phase3 = Phase3Trainer(system, train_config)
769
+
770
+ for epoch in range(min(3, train_config.phase3_epochs)): # shortened for demo
771
+ epoch_loss = 0
772
+ epoch_improvement = 0
773
+ n_batches = 0
774
+
775
+ for batch in train_loader2:
776
+ metrics = phase3.train_step(batch)
777
+ epoch_loss += metrics['total_loss']
778
+ epoch_improvement += metrics['improvement']
779
+ n_batches += 1
780
+
781
+ print(f" Epoch {epoch+1}: loss={epoch_loss/n_batches:.4f}, "
782
+ f"info improvement={epoch_improvement/n_batches:.4f}")
783
+
784
+ # Final evaluation
785
+ print(f"\n{'='*70}")
786
+ print("Final Evaluation")
787
+ print(f"{'='*70}")
788
+
789
+ final_metrics = evaluator.prediction_accuracy(system, val_loader2, n_steps=5)
790
+ print(f" Final MSE: {final_metrics['mse']:.4f}")
791
+ print(f" Final MAE: {final_metrics['mae']:.4f}")
792
+ print(f" Per-step MSE: {[f'{m:.4f}' for m in final_metrics['step_mses']]}")
793
+ print(f" Routing entropy: {final_metrics['routing_entropy']:.4f}")
794
+ print(f" SLM usage: {[f'{u:.2f}' for u in final_metrics['slm_usage']]}")
795
+
796
+ return final_metrics
797
+
798
+
799
+ # =============================================================================
800
+ # Entry Point
801
+ # =============================================================================
802
+
803
+ if __name__ == "__main__":
804
+ # Build system
805
+ mem_config = MemoryConfig()
806
+ slm_config = SLMConfig()
807
+ blm_config = BLMConfig()
808
+ train_config = TrainingConfig(sequence_length=10) # shorter for demo
809
+
810
+ system = LeWorldSystem(mem_config, slm_config, blm_config)
811
+ count_params(system, "Full LeWorld System")
812
+
813
+ # Run training
814
+ metrics = run_training(
815
+ system, train_config,
816
+ num_train_sequences=100, # small for demo
817
+ num_val_sequences=30,
818
+ )
819
+
820
+ print("\n✅ Training pipeline complete!")