""" LeWorld Training System ======================= 3-Phase training procedure: Phase 1: Pre-train components separately Phase 2: End-to-end joint training Phase 3: Cooperative refinement with info-request loop Plus: Memory population strategies, data generation, evaluation. """ import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import Dataset, DataLoader import math import random from typing import Dict, List, Optional, Tuple from dataclasses import dataclass from leworld_architecture import ( LeWorldSystem, MemoryConfig, SLMConfig, BLMConfig, ArtificialMemory, SmallLeWorldModel, BigLeWorldModel, count_params ) # ============================================================================= # Training Configuration # ============================================================================= @dataclass class TrainingConfig: """Full training configuration.""" # Phase 1: Pre-training phase1_lr: float = 1e-3 phase1_epochs: int = 50 phase1_batch_size: int = 32 # Phase 2: Joint training phase2_lr: float = 3e-4 phase2_epochs: int = 100 phase2_batch_size: int = 16 phase2_warmup_steps: int = 500 # Phase 3: Refinement phase3_lr: float = 1e-4 phase3_epochs: int = 50 phase3_batch_size: int = 16 # General weight_decay: float = 0.01 grad_clip: float = 1.0 state_dim: int = 64 char_dim: int = 32 sequence_length: int = 20 # timesteps per sequence # Loss weights lambda_balance: float = 0.01 # routing balance lambda_diversity: float = 0.001 # address diversity lambda_entropy: float = 0.01 # routing entropy lambda_info_util: float = 0.1 # info request utility # Temperature annealing temp_anneal_rate: float = 3e-5 temp_min: float = 0.1 # ============================================================================= # Synthetic Data Generation # ============================================================================= class StateTransitionDataset(Dataset): """ Generates synthetic state transition sequences for training. Each sequence has: - States that evolve according to learnable dynamics - Characteristics that stay fixed per sequence - Ground-truth "useful memory" labels (for Phase 1 SLM pre-training) The key insight: we embed patterns into memory, and the state transitions DEPEND on what's in specific memory regions. This creates a genuine need for memory retrieval — the model can't predict next state without reading the right memory. 
""" def __init__( self, num_sequences: int, seq_length: int, state_dim: int, char_dim: int, memory: ArtificialMemory, difficulty: str = "easy", # easy, medium, hard ): self.num_sequences = num_sequences self.seq_length = seq_length self.state_dim = state_dim self.char_dim = char_dim self.memory = memory # Generate all sequences upfront self.data = self._generate_sequences(difficulty) def _generate_sequences(self, difficulty: str) -> List[Dict]: """Generate synthetic state-transition sequences.""" data = [] mem_size = self.memory.config.num_words for _ in range(self.num_sequences): # Static characteristics for this sequence characteristics = torch.randn(self.char_dim) # Choose "relevant" memory regions (ground truth for SLM training) if difficulty == "easy": n_relevant = 1 # only one memory region matters elif difficulty == "medium": n_relevant = 2 else: n_relevant = 3 relevant_addrs = [] for _ in range(n_relevant): start = random.randint(0, mem_size - 256) length = random.randint(16, 128) relevant_addrs.append((start, start + length)) # Generate state sequence where transitions depend on memory content states = torch.zeros(self.seq_length, self.state_dim) states[0] = torch.randn(self.state_dim) # The transition rule: next_state = f(current_state, memory_content) # We use a simple linear rule seeded by the memory content with torch.no_grad(): for addr_start, addr_end in relevant_addrs: mem_bits = self.memory.memory[addr_start:addr_end].mean(dim=0) # Memory content influences the transition dynamics # Pad/tile mem_bits to state_dim transition_seed_raw = mem_bits * 2 - 1 # map 0,1 → -1,1 transition_seed = transition_seed_raw.repeat( math.ceil(self.state_dim / len(transition_seed_raw)) )[:self.state_dim] # Pad/tile characteristics to state_dim char_padded = characteristics.repeat( math.ceil(self.state_dim / len(characteristics)) )[:self.state_dim] for t in range(1, self.seq_length): noise = torch.randn(self.state_dim) * 0.1 # State evolves based on current state + memory influence states[t] = ( 0.8 * states[t-1] + 0.15 * transition_seed + 0.05 * char_padded + noise ) data.append({ 'states': states, # (seq_length, state_dim) 'characteristics': characteristics, # (char_dim,) 'relevant_addrs': relevant_addrs, # list of (start, end) tuples 'n_relevant': n_relevant, }) return data def __len__(self): return self.num_sequences def __getitem__(self, idx): item = self.data[idx] # Pad relevant addresses to fixed length (3 = max n_slms) padded_starts = torch.zeros(3, dtype=torch.long) padded_ends = torch.zeros(3, dtype=torch.long) for i, (s, e) in enumerate(item['relevant_addrs']): padded_starts[i] = s padded_ends[i] = e return { 'states': item['states'], 'characteristics': item['characteristics'], 'relevant_starts': padded_starts, 'relevant_ends': padded_ends, 'n_relevant': item['n_relevant'], } # ============================================================================= # Phase 1: Pre-training (Components Separately) # ============================================================================= class Phase1Trainer: """ Pre-train SLMs and BLM separately. SLMs: Given (past_state, current_state, characteristics), learn to output address ranges that point to "relevant" memory regions. Loss: distance between predicted address range and ground-truth relevant region. BLM: Given perfect memory reads, learn to predict next state. Loss: MSE between predicted and actual next state. 
""" def __init__(self, system: LeWorldSystem, config: TrainingConfig): self.system = system self.config = config # Separate optimizers for SLMs and BLM self.slm_optimizer = optim.AdamW( system.slms.parameters(), lr=config.phase1_lr, weight_decay=config.weight_decay ) self.blm_optimizer = optim.AdamW( list(system.blm.parameters()) + list(system.memory.parameters()), lr=config.phase1_lr, weight_decay=config.weight_decay ) def train_slms_step(self, batch: Dict) -> Dict: """ Train SLMs to find relevant memory regions. Loss: |predicted_addr - target_addr| normalized by address space. """ self.slm_optimizer.zero_grad() states = batch['states'] # (B, T, state_dim) chars = batch['characteristics'] # (B, char_dim) target_starts = batch['relevant_starts'] # (B, 3) target_ends = batch['relevant_ends'] # (B, 3) total_loss = None # For each SLM, train to find the corresponding relevant region for i, slm in enumerate(self.system.slms): # Use first two timesteps as past/current past_state = states[:, 0, :] current_state = states[:, 1, :] output = slm(past_state, current_state, chars) # Use logits (differentiable) instead of hard addresses # Target: which high/low byte corresponds to the target address tgt_start = target_starts[:, i].long() half_space = slm.address_head.half_space # 256 tgt_high = tgt_start // half_space # high byte tgt_low = tgt_start % half_space # low byte # Cross-entropy over address components (differentiable!) addr_loss = ( F.cross_entropy(output['start_logits_high'], tgt_high) + F.cross_entropy(output['start_logits_low'], tgt_low) ) # Range length loss tgt_range = (target_ends[:, i] - target_starts[:, i]).clamp(1, self.system.memory.config.max_read_range) - 1 range_loss = F.cross_entropy(output['range_logits'], tgt_range.long()) slm_loss = addr_loss + 0.5 * range_loss if total_loss is None: total_loss = slm_loss else: total_loss = total_loss + slm_loss total_loss = total_loss / len(self.system.slms) total_loss.backward() torch.nn.utils.clip_grad_norm_(self.system.slms.parameters(), self.config.grad_clip) self.slm_optimizer.step() return {'slm_loss': total_loss.item()} def train_blm_step(self, batch: Dict) -> Dict: """ Train BLM to predict next state given oracle memory reads. Oracle: we read from the KNOWN relevant memory regions (ground truth). 
""" self.blm_optimizer.zero_grad() states = batch['states'] chars = batch['characteristics'] target_starts = batch['relevant_starts'] target_ends = batch['relevant_ends'] batch_size = states.shape[0] # Read oracle memory oracle_reads = [] slm_fake_outputs = [] for i in range(3): _, encoded, _ = self.system.memory.read( target_starts[:, i], target_ends[:, i] ) oracle_reads.append(encoded) # Create fake SLM output (just need hidden state) fake_hidden = torch.zeros(batch_size, 128) # SLM d_model = 128 slm_fake_outputs.append({ 'hidden': fake_hidden, 'start_addr': target_starts[:, i], 'end_addr': target_ends[:, i], 'confidence': torch.ones(batch_size), }) # BLM forward with oracle reads total_loss = None for t in range(states.shape[1] - 1): past_state = states[:, max(0, t-1), :] current_state = states[:, t, :] next_state = states[:, t+1, :] blm_out = self.system.blm( past_state, current_state, slm_fake_outputs, oracle_reads ) loss = F.mse_loss(blm_out['next_state'], next_state) if total_loss is None: total_loss = loss else: total_loss = total_loss + loss total_loss = total_loss / (states.shape[1] - 1) total_loss.backward() torch.nn.utils.clip_grad_norm_( list(self.system.blm.parameters()) + list(self.system.memory.parameters()), self.config.grad_clip ) self.blm_optimizer.step() return {'blm_loss': total_loss.item()} # ============================================================================= # Phase 2: End-to-End Joint Training # ============================================================================= class Phase2Trainer: """ Joint training of the entire system end-to-end. The full pipeline runs: SLMs → Memory Read → BLM → Next State Key challenge: gradient flow through discrete decisions - SLM address selection: use soft attention + hard address (ST trick) - BLM routing: use straight-through sigmoid Losses: 1. next_state_loss: primary prediction accuracy 2. balance_loss: balanced SLM routing 3. diversity_loss: SLMs read different memory regions 4. info_utility_loss: BLM's info request improves future predictions """ def __init__(self, system: LeWorldSystem, config: TrainingConfig): self.system = system self.config = config # Single optimizer for everything self.optimizer = optim.AdamW( system.parameters(), lr=config.phase2_lr, weight_decay=config.weight_decay ) # Learning rate scheduler self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( self.optimizer, T_0=config.phase2_epochs // 3, T_mult=2 ) self.global_step = 0 def train_step(self, batch: Dict) -> Dict: """Full end-to-end training step.""" self.optimizer.zero_grad() states = batch['states'] chars = batch['characteristics'] # Multi-step forward output = self.system.multi_step_forward(states, chars) loss = output['total_loss'] loss.backward() # Gradient clipping torch.nn.utils.clip_grad_norm_( self.system.parameters(), self.config.grad_clip ) self.optimizer.step() # Temperature annealing for router self.global_step += 1 self.system.blm.router.anneal_temperature( self.global_step, self.config.temp_anneal_rate, self.config.temp_min ) return { 'total_loss': loss.item(), 'temperature': self.system.blm.router.temperature.item(), 'step': self.global_step, } # ============================================================================= # Phase 3: Cooperative Refinement with Info-Request Loop # ============================================================================= class Phase3Trainer: """ Refinement phase: train the info-request mechanism. The BLM learns to generate useful "what info do I need?" 

# =============================================================================
# Phase 3: Cooperative Refinement with Info-Request Loop
# =============================================================================

class Phase3Trainer:
    """
    Refinement phase: train the info-request mechanism.

    The BLM learns to generate useful "what info do I need?" queries that
    improve the SLMs' memory retrieval at the NEXT timestep.

    Training signal: compare prediction quality WITH vs WITHOUT info-request
    modulation. If the info request helped → reward; if not → penalize.
    This is inspired by the ProactAgent (arxiv:2604.20572) paired-branch
    reward.
    """

    def __init__(self, system: LeWorldSystem, config: TrainingConfig):
        self.system = system
        self.config = config

        # Optimizer: higher LR for info-request modules, lower for the rest
        info_params = set(id(p) for p in system.blm.info_request.parameters())
        info_params.update(id(p) for p in system.info_to_slm.parameters())
        other_blm_params = [
            p for p in system.blm.parameters() if id(p) not in info_params
        ]

        self.optimizer = optim.AdamW([
            {'params': list(system.blm.info_request.parameters())
                       + list(system.info_to_slm.parameters()),
             'lr': config.phase3_lr},
            {'params': list(system.slms.parameters()),
             'lr': config.phase3_lr * 0.1},
            {'params': other_blm_params,
             'lr': config.phase3_lr * 0.1},
            {'params': list(system.memory.parameters()),
             'lr': config.phase3_lr * 0.01},
        ], weight_decay=config.weight_decay)

    def train_step(self, batch: Dict) -> Dict:
        """
        Paired-branch training:
          Branch A: run with info-request modulation (full system)
          Branch B: run WITHOUT info-request (baseline)
          Reward = improvement of A over B
        """
        self.optimizer.zero_grad()

        states = batch['states']
        chars = batch['characteristics']

        # Branch A: with info-request loop
        output_with = self.system.multi_step_forward(states, chars)
        loss_with = output_with['total_loss']

        # Branch B: without info-request (info_query set to None at each step).
        # We do this by running forward without passing info_query between steps.
        batch_size, T, state_dim = states.shape
        loss_without = None
        for t in range(T - 1):
            past_state = states[:, max(0, t-1), :]
            current_state = states[:, t, :]
            next_state = states[:, t+1, :]

            output = self.system(
                past_state, current_state, chars, next_state,
                info_query_prev=None  # NO info request
            )
            if output['losses']:
                if loss_without is None:
                    loss_without = output['losses']['next_state_loss']
                else:
                    loss_without = loss_without + output['losses']['next_state_loss']

        if loss_without is None:
            loss_without = torch.tensor(0.0)
        else:
            loss_without = loss_without / max(1, T - 1)

        # Info utility: reward if the info request helps, penalize if not.
        improvement = (loss_without - loss_with).detach()  # positive = info helped

        # Total loss: prediction loss + info utility bonus. Note that because
        # `improvement` is detached, the bonus shifts the reported loss value
        # but not the gradients: all gradient signal comes from Branch A,
        # whose forward pass includes the info-request path. `improvement`
        # mainly serves as a logged reward for monitoring.
        total_loss = loss_with - self.config.lambda_info_util * improvement

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.system.parameters(), self.config.grad_clip)
        self.optimizer.step()

        return {
            'loss_with_info': loss_with.item(),
            'loss_without_info': loss_without.item(),
            'improvement': improvement.item(),
            'total_loss': total_loss.item(),
        }
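
# -----------------------------------------------------------------------------
# Illustrative sketch of the paired-branch arithmetic above, with made-up
# numbers. It also demonstrates the detach semantics noted in train_step:
# the improvement bonus changes the loss value, not its gradient.
# -----------------------------------------------------------------------------
def _demo_paired_branch_reward(lambda_info_util: float = 0.1):
    """Tiny numeric example of the Phase 3 reward computation."""
    loss_with = torch.tensor(0.40, requires_grad=True)  # Branch A (with info)
    loss_without = torch.tensor(0.55)                   # Branch B (baseline)

    improvement = (loss_without - loss_with).detach()   # 0.15 → info helped
    total_loss = loss_with - lambda_info_util * improvement

    total_loss.backward()
    assert abs(total_loss.item() - 0.385) < 1e-6
    assert loss_with.grad.item() == 1.0  # gradient unaffected by the detached bonus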

# =============================================================================
# Memory Population Strategies
# =============================================================================

class MemoryPopulator:
    """
    Strategies for populating the artificial memory with meaningful content.

    In a real application, memory would be populated by experience /
    observations. Here we provide several strategies for initial content.
    """

    @staticmethod
    def random_bits(memory: ArtificialMemory):
        """Fill with random bits (baseline)."""
        memory.memory.uniform_(0, 1).round_()

    @staticmethod
    def structured_patterns(memory: ArtificialMemory):
        """
        Fill with structured patterns that encode different "knowledge types."

        Memory layout (hex ranges assume the default 64K words; the code
        splits whatever size is actually configured into quarters):
          - [0x0000 - 0x3FFF]: Dynamics patterns (state transition rules)
          - [0x4000 - 0x7FFF]: Context patterns (characteristic-dependent info)
          - [0x8000 - 0xBFFF]: History patterns (temporal sequences)
          - [0xC000 - 0xFFFF]: Association patterns (cross-references)
        """
        N = memory.config.num_words
        W = memory.config.word_size
        quarter = N // 4

        with torch.no_grad():
            # Region 1: Dynamics — repeating patterns (easy to learn)
            for i in range(quarter):
                pattern = torch.zeros(W)
                pattern[i % W] = 1.0  # cyclic single-bit pattern
                memory.memory[i] = pattern

            # Region 2: Context — characteristic-dependent. Use a local
            # generator so we don't clobber the global RNG state.
            for i in range(quarter, 2 * quarter):
                gen = torch.Generator().manual_seed(i - quarter)
                memory.memory[i] = torch.randint(0, 2, (W,), generator=gen).float()

            # Region 3: History — sequential counting in binary
            for i in range(2 * quarter, 3 * quarter):
                binary = torch.zeros(W)
                val = i - 2 * quarter
                for bit in range(min(W, 16)):
                    binary[bit] = float((val >> bit) & 1)
                memory.memory[i] = binary

            # Region 4: Associations — XOR patterns
            for i in range(3 * quarter, N):
                a = memory.memory[i % quarter]              # reference region 1
                b = memory.memory[quarter + (i % quarter)]  # reference region 2
                memory.memory[i] = (a + b) % 2              # XOR for 0/1 bits

    @staticmethod
    def from_experience(memory: ArtificialMemory, experiences: torch.Tensor):
        """
        Populate memory from observed data.

        Args:
            experiences: (N, feature_dim) tensor of observed features.
                Each feature vector gets quantized to bits and stored.
        """
        with torch.no_grad():
            N = min(experiences.shape[0], memory.config.num_words)
            W = memory.config.word_size

            # Simple quantization: threshold at the median
            for i in range(N):
                feat = experiences[i]
                # Truncate/pad to word_size
                if len(feat) >= W:
                    bits = (feat[:W] > feat[:W].median()).float()
                else:
                    bits = torch.zeros(W)
                    bits[:len(feat)] = (feat > feat.median()).float()
                memory.memory[i] = bits
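
# -----------------------------------------------------------------------------
# Illustrative sketch of two pattern encodings used above, detached from any
# real memory object: the little-endian binary counter of the History region
# and the (a + b) % 2 formulation of XOR used in the Associations region.
# -----------------------------------------------------------------------------
def _demo_pattern_encodings(word_size: int = 16):
    """Verify the binary-counting and XOR-by-mod-2 encodings on toy words."""
    # History region: value 5 -> bits 1,0,1,0,... (least significant bit first)
    val = 5
    binary = torch.zeros(word_size)
    for bit in range(word_size):
        binary[bit] = float((val >> bit) & 1)
    assert binary[:4].tolist() == [1.0, 0.0, 1.0, 0.0]

    # Associations region: for 0/1 tensors, (a + b) % 2 is exactly XOR
    a = torch.tensor([0., 1., 0., 1.])
    b = torch.tensor([0., 0., 1., 1.])
    assert torch.equal((a + b) % 2, torch.tensor([0., 1., 1., 0.]))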

# =============================================================================
# Evaluation
# =============================================================================

class Evaluator:
    """Evaluation metrics for the LeWorld system."""

    @staticmethod
    def prediction_accuracy(
        system: LeWorldSystem,
        dataloader: DataLoader,
        n_steps: int = 5
    ) -> Dict:
        """
        Evaluate next-state prediction accuracy.

        Metrics:
          - MSE: mean squared error of state predictions
          - MAE: mean absolute error
          - Multi-step MSE: prediction error at different horizons
          - Routing diversity: how varied the SLM selections are
        """
        system.eval()

        total_mse = 0.0
        total_mae = 0.0
        step_mses = [0.0] * n_steps
        all_masks = []
        n_batches = 0

        with torch.no_grad():
            for batch in dataloader:
                states = batch['states']
                chars = batch['characteristics']

                output = system.multi_step_forward(states, chars, n_steps)

                # Ground-truth future states
                gt_future = states[:, 1:n_steps+1, :]
                pred_future = output['predictions'][:, :n_steps, :]
                actual_steps = min(n_steps, pred_future.shape[1], gt_future.shape[1])

                mse = F.mse_loss(pred_future[:, :actual_steps], gt_future[:, :actual_steps])
                mae = F.l1_loss(pred_future[:, :actual_steps], gt_future[:, :actual_steps])
                total_mse += mse.item()
                total_mae += mae.item()

                # Per-step MSE
                for t in range(actual_steps):
                    step_mse = F.mse_loss(pred_future[:, t], gt_future[:, t])
                    step_mses[t] += step_mse.item()

                # Collect routing masks
                all_masks.append(output['masks'])
                n_batches += 1

        # Routing diversity: entropy of SLM usage. Raw usage fractions do not
        # sum to 1 across SLMs, so normalize them into a distribution first.
        all_masks = torch.cat(all_masks, dim=0)     # (total, T, n_slms)
        usage_per_slm = all_masks.mean(dim=(0, 1))  # (n_slms,)
        usage_probs = usage_per_slm / (usage_per_slm.sum() + 1e-8)
        routing_entropy = -(usage_probs * torch.log(usage_probs + 1e-8)).sum().item()

        system.train()

        return {
            'mse': total_mse / max(1, n_batches),
            'mae': total_mae / max(1, n_batches),
            'step_mses': [m / max(1, n_batches) for m in step_mses],
            'routing_entropy': routing_entropy,
            'slm_usage': usage_per_slm.tolist(),
        }
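
# -----------------------------------------------------------------------------
# Illustrative sketch of the routing-entropy metric on made-up usage numbers:
# uniform usage across n SLMs yields the maximum entropy log(n), while a
# collapsed router (one SLM does everything) yields approximately 0.
# -----------------------------------------------------------------------------
def _demo_routing_entropy():
    """Entropy of normalized SLM usage for uniform vs. collapsed routing."""
    def entropy(usage: torch.Tensor) -> float:
        probs = usage / (usage.sum() + 1e-8)
        return -(probs * torch.log(probs + 1e-8)).sum().item()

    uniform = torch.tensor([0.5, 0.5, 0.5])     # all SLMs used equally
    collapsed = torch.tensor([0.99, 0.0, 0.0])  # router always picks SLM 0

    assert abs(entropy(uniform) - math.log(3)) < 1e-4
    assert entropy(collapsed) < 0.01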

# =============================================================================
# Full Training Pipeline
# =============================================================================

def run_training(
    system: LeWorldSystem,
    train_config: TrainingConfig,
    num_train_sequences: int = 1000,
    num_val_sequences: int = 200,
):
    """Execute the full 3-phase training pipeline."""
    print("=" * 70)
    print("LeWorld Training Pipeline")
    print("=" * 70)

    # Populate memory with structured patterns
    print("\n[Setup] Populating artificial memory...")
    MemoryPopulator.structured_patterns(system.memory)

    # Create datasets
    print("[Setup] Generating training data...")
    train_dataset = StateTransitionDataset(
        num_sequences=num_train_sequences,
        seq_length=train_config.sequence_length,
        state_dim=train_config.state_dim,
        char_dim=train_config.char_dim,
        memory=system.memory,
        difficulty="medium",
    )
    val_dataset = StateTransitionDataset(
        num_sequences=num_val_sequences,
        seq_length=train_config.sequence_length,
        state_dim=train_config.state_dim,
        char_dim=train_config.char_dim,
        memory=system.memory,
        difficulty="medium",
    )

    train_loader = DataLoader(train_dataset, batch_size=train_config.phase1_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=train_config.phase1_batch_size)

    evaluator = Evaluator()

    # ===== Phase 1: Pre-training =====
    print(f"\n{'='*70}")
    print("Phase 1: Pre-training (SLMs + BLM separately)")
    print(f"{'='*70}")

    phase1 = Phase1Trainer(system, train_config)
    for epoch in range(min(3, train_config.phase1_epochs)):  # shortened for demo
        epoch_slm_loss = 0.0
        epoch_blm_loss = 0.0
        n_batches = 0
        for batch in train_loader:
            slm_metrics = phase1.train_slms_step(batch)
            blm_metrics = phase1.train_blm_step(batch)
            epoch_slm_loss += slm_metrics['slm_loss']
            epoch_blm_loss += blm_metrics['blm_loss']
            n_batches += 1
        print(f"  Epoch {epoch+1}: SLM loss={epoch_slm_loss/n_batches:.4f}, "
              f"BLM loss={epoch_blm_loss/n_batches:.4f}")

    # Evaluate after Phase 1
    val_metrics = evaluator.prediction_accuracy(system, val_loader, n_steps=5)
    print(f"  Phase 1 eval: MSE={val_metrics['mse']:.4f}, "
          f"Routing entropy={val_metrics['routing_entropy']:.4f}")

    # ===== Phase 2: Joint Training =====
    print(f"\n{'='*70}")
    print("Phase 2: End-to-End Joint Training")
    print(f"{'='*70}")

    phase2 = Phase2Trainer(system, train_config)
    train_loader2 = DataLoader(train_dataset, batch_size=train_config.phase2_batch_size, shuffle=True)
    val_loader2 = DataLoader(val_dataset, batch_size=train_config.phase2_batch_size)

    for epoch in range(min(5, train_config.phase2_epochs)):  # shortened for demo
        epoch_loss = 0.0
        n_batches = 0
        for batch in train_loader2:
            metrics = phase2.train_step(batch)
            epoch_loss += metrics['total_loss']
            n_batches += 1
        phase2.scheduler.step()  # advance the cosine schedule once per epoch
        print(f"  Epoch {epoch+1}: loss={epoch_loss/n_batches:.4f}, "
              f"temp={metrics['temperature']:.4f}")

    val_metrics = evaluator.prediction_accuracy(system, val_loader2, n_steps=5)
    print(f"  Phase 2 eval: MSE={val_metrics['mse']:.4f}, "
          f"Routing entropy={val_metrics['routing_entropy']:.4f}, "
          f"SLM usage={[f'{u:.2f}' for u in val_metrics['slm_usage']]}")

    # ===== Phase 3: Info-Request Refinement =====
    print(f"\n{'='*70}")
    print("Phase 3: Info-Request Cooperative Refinement")
    print(f"{'='*70}")

    phase3 = Phase3Trainer(system, train_config)
    for epoch in range(min(3, train_config.phase3_epochs)):  # shortened for demo
        epoch_loss = 0.0
        epoch_improvement = 0.0
        n_batches = 0
        for batch in train_loader2:
            metrics = phase3.train_step(batch)
            epoch_loss += metrics['total_loss']
            epoch_improvement += metrics['improvement']
            n_batches += 1
        print(f"  Epoch {epoch+1}: loss={epoch_loss/n_batches:.4f}, "
              f"info improvement={epoch_improvement/n_batches:.4f}")

    # Final evaluation
    print(f"\n{'='*70}")
    print("Final Evaluation")
    print(f"{'='*70}")

    final_metrics = evaluator.prediction_accuracy(system, val_loader2, n_steps=5)
    print(f"  Final MSE: {final_metrics['mse']:.4f}")
    print(f"  Final MAE: {final_metrics['mae']:.4f}")
    print(f"  Per-step MSE: {[f'{m:.4f}' for m in final_metrics['step_mses']]}")
    print(f"  Routing entropy: {final_metrics['routing_entropy']:.4f}")
    print(f"  SLM usage: {[f'{u:.2f}' for u in final_metrics['slm_usage']]}")

    return final_metrics


# =============================================================================
# Entry Point
# =============================================================================

if __name__ == "__main__":
    # Build the system
    mem_config = MemoryConfig()
    slm_config = SLMConfig()
    blm_config = BLMConfig()
    train_config = TrainingConfig(sequence_length=10)  # shorter for demo

    system = LeWorldSystem(mem_config, slm_config, blm_config)
    count_params(system, "Full LeWorld System")

    # Run training
    metrics = run_training(
        system,
        train_config,
        num_train_sequences=100,  # small for demo
        num_val_sequences=30,
    )

    print("\n✅ Training pipeline complete!")