| """ |
| LeWorld Training System |
| ======================= |
| 3-Phase training procedure: |
| Phase 1: Pre-train components separately |
| Phase 2: End-to-end joint training |
| Phase 3: Cooperative refinement with info-request loop |
| |
| Plus: Memory population strategies, data generation, evaluation. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| import math |
| import random |
| from typing import Dict, List, Optional, Tuple |
| from dataclasses import dataclass |
|
|
| from leworld_architecture import ( |
| LeWorldSystem, MemoryConfig, SLMConfig, BLMConfig, |
| ArtificialMemory, SmallLeWorldModel, BigLeWorldModel, |
| count_params |
| ) |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class TrainingConfig:
    """Full training configuration for all three phases."""

    # Phase 1: separate pre-training of SLMs and BLM
    phase1_lr: float = 1e-3
    phase1_epochs: int = 50
    phase1_batch_size: int = 32

    # Phase 2: end-to-end joint training
    phase2_lr: float = 3e-4
    phase2_epochs: int = 100
    phase2_batch_size: int = 16
    phase2_warmup_steps: int = 500

    # Phase 3: info-request refinement
    phase3_lr: float = 1e-4
    phase3_epochs: int = 50
    phase3_batch_size: int = 16

    # Shared optimization and data dimensions
    weight_decay: float = 0.01
    grad_clip: float = 1.0
    state_dim: int = 64
    char_dim: int = 32
    sequence_length: int = 20

    # Auxiliary loss weights
    lambda_balance: float = 0.01
    lambda_diversity: float = 0.001
    lambda_entropy: float = 0.01
    lambda_info_util: float = 0.1

    # Router temperature annealing
    temp_anneal_rate: float = 3e-5
    temp_min: float = 0.1


# =============================================================================
# Synthetic data generation
# =============================================================================

class StateTransitionDataset(Dataset):
    """
    Generates synthetic state-transition sequences for training.

    Each sequence has:
      - states that evolve according to fixed linear dynamics,
      - characteristics that stay constant within a sequence,
      - ground-truth "relevant memory region" labels (for Phase 1 SLM
        pre-training).

    The key insight: the state transitions DEPEND on the contents of specific
    memory regions. This creates a genuine need for memory retrieval — the
    model cannot predict the next state without reading the right memory.
    """

    def __init__(
        self,
        num_sequences: int,
        seq_length: int,
        state_dim: int,
        char_dim: int,
        memory: ArtificialMemory,
        difficulty: str = "easy",
    ):
        self.num_sequences = num_sequences
        self.seq_length = seq_length
        self.state_dim = state_dim
        self.char_dim = char_dim
        self.memory = memory

        self.data = self._generate_sequences(difficulty)

    def _generate_sequences(self, difficulty: str) -> List[Dict]:
        """Generate synthetic state-transition sequences."""
        data = []
        mem_size = self.memory.config.num_words

        for _ in range(self.num_sequences):
            # Fixed per-sequence characteristics.
            characteristics = torch.randn(self.char_dim)

            # Difficulty controls how many memory regions drive the dynamics.
            if difficulty == "easy":
                n_relevant = 1
            elif difficulty == "medium":
                n_relevant = 2
            else:
                n_relevant = 3

            relevant_addrs = []
            for _ in range(n_relevant):
                start = random.randint(0, mem_size - 256)
                length = random.randint(16, 128)
                relevant_addrs.append((start, start + length))

            states = torch.zeros(self.seq_length, self.state_dim)
            states[0] = torch.randn(self.state_dim)

            with torch.no_grad():
                # Combine the seeds from ALL relevant regions before rolling
                # out the dynamics. (Re-rolling the trajectory once per region
                # would overwrite it each time, so only the last region would
                # actually matter.)
                seeds = []
                for addr_start, addr_end in relevant_addrs:
                    mem_bits = self.memory.memory[addr_start:addr_end].mean(dim=0)

                    # Map averaged bits from [0, 1] to [-1, 1] and tile to state_dim.
                    transition_seed_raw = mem_bits * 2 - 1
                    seeds.append(transition_seed_raw.repeat(
                        math.ceil(self.state_dim / len(transition_seed_raw))
                    )[:self.state_dim])
                transition_seed = torch.stack(seeds).mean(dim=0)

                # Tile the characteristics vector to state_dim as well.
                char_padded = characteristics.repeat(
                    math.ceil(self.state_dim / len(characteristics))
                )[:self.state_dim]

                for t in range(1, self.seq_length):
                    noise = torch.randn(self.state_dim) * 0.1
                    states[t] = (
                        0.8 * states[t - 1]
                        + 0.15 * transition_seed
                        + 0.05 * char_padded
                        + noise
                    )

            data.append({
                'states': states,
                'characteristics': characteristics,
                'relevant_addrs': relevant_addrs,
                'n_relevant': n_relevant,
            })

        return data

    def __len__(self):
        return self.num_sequences

    def __getitem__(self, idx):
        item = self.data[idx]

        # Zero-pad the address lists to a fixed width of 3 regions so that
        # batches collate cleanly; unused slots stay at (0, 0).
        padded_starts = torch.zeros(3, dtype=torch.long)
        padded_ends = torch.zeros(3, dtype=torch.long)
        for i, (s, e) in enumerate(item['relevant_addrs']):
            padded_starts[i] = s
            padded_ends[i] = e

        return {
            'states': item['states'],
            'characteristics': item['characteristics'],
            'relevant_starts': padded_starts,
            'relevant_ends': padded_ends,
            'n_relevant': item['n_relevant'],
        }
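

def _demo_memory_dependence() -> None:
    """Illustrative sketch, not called during training (assumes
    ArtificialMemory is constructible from a bare MemoryConfig): with
    identical RNG seeds, two datasets differ only through the memory
    contents, so any gap between their trajectories comes from the
    memory-derived transition seeds."""
    memory = ArtificialMemory(MemoryConfig())
    MemoryPopulator.structured_patterns(memory)

    random.seed(0); torch.manual_seed(0)
    ds_a = StateTransitionDataset(4, 10, 64, 32, memory)

    with torch.no_grad():
        memory.memory.bernoulli_(0.5)  # scramble every stored bit

    random.seed(0); torch.manual_seed(0)
    ds_b = StateTransitionDataset(4, 10, 64, 32, memory)

    # A nonzero gap shows the dynamics genuinely depend on memory contents.
    print((ds_a[0]['states'] - ds_b[0]['states']).abs().mean().item())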


# =============================================================================
# Phase 1: separate pre-training
# =============================================================================

class Phase1Trainer:
    """
    Pre-train SLMs and the BLM separately.

    SLMs: given (past_state, current_state, characteristics), learn to output
    address ranges that point at the "relevant" memory regions.
    Loss: cross-entropy on the factorized start address (high/low digits)
    plus cross-entropy on the read-range length.

    BLM: given oracle memory reads (from the known-relevant regions), learn
    to predict the next state.
    Loss: MSE between predicted and actual next state.
    """

    def __init__(self, system: LeWorldSystem, config: TrainingConfig):
        self.system = system
        self.config = config

        self.slm_optimizer = optim.AdamW(
            system.slms.parameters(),
            lr=config.phase1_lr,
            weight_decay=config.weight_decay
        )
        self.blm_optimizer = optim.AdamW(
            list(system.blm.parameters()) + list(system.memory.parameters()),
            lr=config.phase1_lr,
            weight_decay=config.weight_decay
        )

    def train_slms_step(self, batch: Dict) -> Dict:
        """
        Train the SLMs to find relevant memory regions.

        SLM i is supervised against padded region i; this assumes at most
        three SLMs, matching the 3-slot padding in the dataset. SLMs whose
        slot is unused see the dummy target (0, 0).
        """
        self.slm_optimizer.zero_grad()

        states = batch['states']
        chars = batch['characteristics']
        target_starts = batch['relevant_starts']
        target_ends = batch['relevant_ends']

        total_loss = None

        for i, slm in enumerate(self.system.slms):
            # Train on the first transition only — a deliberate simplification;
            # the relevant regions are constant across the sequence anyway.
            past_state = states[:, 0, :]
            current_state = states[:, 1, :]

            output = slm(past_state, current_state, chars)

            # Factorize the target start address into two smaller "digits"
            # so each classification head stays cheap.
            tgt_start = target_starts[:, i].long()

            half_space = slm.address_head.half_space
            tgt_high = tgt_start // half_space
            tgt_low = tgt_start % half_space

            addr_loss = (
                F.cross_entropy(output['start_logits_high'], tgt_high) +
                F.cross_entropy(output['start_logits_low'], tgt_low)
            )

            # Range length as a class index in [0, max_read_range - 1].
            tgt_range = (target_ends[:, i] - target_starts[:, i]).clamp(
                1, self.system.memory.config.max_read_range
            ) - 1
            range_loss = F.cross_entropy(output['range_logits'], tgt_range.long())

            slm_loss = addr_loss + 0.5 * range_loss

            if total_loss is None:
                total_loss = slm_loss
            else:
                total_loss = total_loss + slm_loss

        total_loss = total_loss / len(self.system.slms)
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.system.slms.parameters(), self.config.grad_clip)
        self.slm_optimizer.step()

        return {'slm_loss': total_loss.item()}

    def train_blm_step(self, batch: Dict) -> Dict:
        """
        Train the BLM to predict the next state given oracle memory reads.

        Oracle: we read from the KNOWN relevant memory regions (ground truth),
        so the BLM learns the dynamics without depending on SLM quality.
        """
        self.blm_optimizer.zero_grad()

        states = batch['states']
        chars = batch['characteristics']
        target_starts = batch['relevant_starts']
        target_ends = batch['relevant_ends']

        batch_size = states.shape[0]

        # Build oracle reads plus placeholder SLM outputs, one per padded slot.
        oracle_reads = []
        slm_fake_outputs = []
        for i in range(3):
            _, encoded, _ = self.system.memory.read(
                target_starts[:, i], target_ends[:, i]
            )
            oracle_reads.append(encoded)

            # The zero-vector width (128) must match the SLM hidden size
            # expected by the BLM's input projection.
            fake_hidden = torch.zeros(batch_size, 128)
            slm_fake_outputs.append({
                'hidden': fake_hidden,
                'start_addr': target_starts[:, i],
                'end_addr': target_ends[:, i],
                'confidence': torch.ones(batch_size),
            })

        # Teacher-forced next-state prediction over the whole sequence.
        total_loss = None
        for t in range(states.shape[1] - 1):
            past_state = states[:, max(0, t - 1), :]
            current_state = states[:, t, :]
            next_state = states[:, t + 1, :]

            blm_out = self.system.blm(
                past_state, current_state,
                slm_fake_outputs, oracle_reads
            )

            loss = F.mse_loss(blm_out['next_state'], next_state)
            if total_loss is None:
                total_loss = loss
            else:
                total_loss = total_loss + loss

        total_loss = total_loss / (states.shape[1] - 1)
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            list(self.system.blm.parameters()) + list(self.system.memory.parameters()),
            self.config.grad_clip
        )
        self.blm_optimizer.step()

        return {'blm_loss': total_loss.item()}
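

def _demo_address_factorization(half_space: int = 256) -> None:
    """Worked example of the factorized address targets used above (a sketch;
    `half_space` here stands in for the model's `address_head.half_space`):
    predicting two small softmaxes is much cheaper than one softmax over the
    entire address space."""
    addr = 0x1234                                       # 4660
    high, low = addr // half_space, addr % half_space   # 18, 52
    assert high * half_space + low == addr
    # Phase 1 treats `high` and `low` as two independent class labels.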


# =============================================================================
# Phase 2: end-to-end joint training
# =============================================================================

class Phase2Trainer:
    """
    Joint training of the entire system end-to-end.

    The full pipeline runs: SLMs → Memory Read → BLM → Next State.

    Key challenge: gradient flow through discrete decisions.
      - SLM address selection: soft attention + hard address (ST trick)
      - BLM routing: straight-through sigmoid

    Losses:
      1. next_state_loss: primary prediction accuracy
      2. balance_loss: keeps SLM routing balanced
      3. diversity_loss: pushes SLMs to read different memory regions
      4. info_utility_loss: the BLM's info request should improve future
         predictions
    """

    def __init__(self, system: LeWorldSystem, config: TrainingConfig):
        self.system = system
        self.config = config

        self.optimizer = optim.AdamW(
            system.parameters(),
            lr=config.phase2_lr,
            weight_decay=config.weight_decay
        )

        # Stepped once per epoch by the caller (see run_training).
        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=max(1, config.phase2_epochs // 3), T_mult=2
        )

        self.global_step = 0

    def train_step(self, batch: Dict) -> Dict:
        """Full end-to-end training step."""
        self.optimizer.zero_grad()

        states = batch['states']
        chars = batch['characteristics']

        output = self.system.multi_step_forward(states, chars)

        loss = output['total_loss']
        loss.backward()

        torch.nn.utils.clip_grad_norm_(
            self.system.parameters(), self.config.grad_clip
        )

        # Linear LR warmup over the first phase2_warmup_steps steps; after
        # that, the per-epoch cosine scheduler owns the learning rate.
        if self.global_step < self.config.phase2_warmup_steps:
            warm = (self.global_step + 1) / self.config.phase2_warmup_steps
            for group in self.optimizer.param_groups:
                group['lr'] = self.config.phase2_lr * warm

        self.optimizer.step()

        # Anneal the router temperature toward temp_min.
        self.global_step += 1
        self.system.blm.router.anneal_temperature(
            self.global_step,
            self.config.temp_anneal_rate,
            self.config.temp_min
        )

        return {
            'total_loss': loss.item(),
            'temperature': self.system.blm.router.temperature.item(),
            'step': self.global_step,
        }
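

def _demo_straight_through(logits: torch.Tensor) -> torch.Tensor:
    """Minimal sketch of the straight-through (ST) trick named in the
    Phase2Trainer docstring: the forward pass uses the hard 0/1 decision,
    the backward pass uses the soft sigmoid's gradient. The production
    router's internals may differ; this only illustrates the mechanism."""
    soft = torch.sigmoid(logits)
    hard = (soft > 0.5).float()
    # Forward value equals `hard`; gradient flows as if the output were `soft`.
    return hard + soft - soft.detach()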


# =============================================================================
# Phase 3: info-request cooperative refinement
# =============================================================================

class Phase3Trainer:
    """
    Refinement phase: train the info-request mechanism.

    The BLM learns to generate useful "what info do I need?" queries that
    improve the SLMs' memory retrieval at the NEXT timestep.

    Training signal: compare prediction quality WITH vs WITHOUT info-request
    modulation. If the info request helped, reinforce it; if not, penalize it.

    This is inspired by ProactAgent (arxiv:2604.20572) paired-branch reward.
    """

    def __init__(self, system: LeWorldSystem, config: TrainingConfig):
        self.system = system
        self.config = config

        # Info-request parameters train at full LR; everything else is kept
        # nearly frozen, the memory most of all.
        info_params = set(id(p) for p in system.blm.info_request.parameters())
        info_params.update(id(p) for p in system.info_to_slm.parameters())

        other_blm_params = [p for p in system.blm.parameters() if id(p) not in info_params]

        self.optimizer = optim.AdamW([
            {'params': list(system.blm.info_request.parameters())
                       + list(system.info_to_slm.parameters()),
             'lr': config.phase3_lr},
            {'params': list(system.slms.parameters()), 'lr': config.phase3_lr * 0.1},
            {'params': other_blm_params, 'lr': config.phase3_lr * 0.1},
            {'params': list(system.memory.parameters()), 'lr': config.phase3_lr * 0.01},
        ], weight_decay=config.weight_decay)

    def train_step(self, batch: Dict) -> Dict:
        """
        Paired-branch training:
          Branch A: run WITH info-request modulation (full system)
          Branch B: run WITHOUT info-request (baseline, no gradients)
          Reward = improvement of A over B
        """
        self.optimizer.zero_grad()

        states = batch['states']
        chars = batch['characteristics']

        # Branch A: full system with info-request feedback.
        output_with = self.system.multi_step_forward(states, chars)
        loss_with = output_with['total_loss']

        # Branch B: step-by-step baseline with info_query_prev=None. It only
        # provides a reference value, so gradients are not needed.
        _, T, _ = states.shape
        loss_without = None

        with torch.no_grad():
            for t in range(T - 1):
                past_state = states[:, max(0, t - 1), :]
                current_state = states[:, t, :]
                next_state = states[:, t + 1, :]

                output = self.system(
                    past_state, current_state, chars,
                    next_state, info_query_prev=None
                )
                if output['losses']:
                    if loss_without is None:
                        loss_without = output['losses']['next_state_loss']
                    else:
                        loss_without = loss_without + output['losses']['next_state_loss']

        if loss_without is None:
            loss_without = torch.zeros((), device=states.device)
        else:
            loss_without = loss_without / max(1, T - 1)

        # Scalar advantage: positive when the info request helped.
        improvement = (loss_without - loss_with).detach()

        # NOTE: because `improvement` is detached, merely *adding* it to the
        # loss would contribute no gradient at all. Instead, use it as a
        # scalar advantage that scales the with-info loss, amplifying the
        # learning signal when the info request demonstrably helped.
        advantage = 1.0 + self.config.lambda_info_util * improvement.clamp(-1.0, 1.0)
        total_loss = advantage * loss_with

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.system.parameters(), self.config.grad_clip)
        self.optimizer.step()

        return {
            'loss_with_info': loss_with.item(),
            'loss_without_info': loss_without.item(),
            'improvement': improvement.item(),
            'total_loss': total_loss.item(),
        }
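
# Worked numbers for the paired-branch signal above (illustrative): if the
# with-info branch reaches loss 0.40 while the baseline sits at 0.50, then
# improvement = 0.10 and, with lambda_info_util = 0.1, the with-info loss is
# scaled by 1.01 — a slightly amplified gradient for the branch whose info
# requests demonstrably paid off.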


# =============================================================================
# Memory population strategies
# =============================================================================

class MemoryPopulator:
    """
    Strategies for populating the artificial memory with meaningful content.

    In a real application, memory would be populated from experience and
    observations. Here we provide several strategies for initial content.
    """

    @staticmethod
    def random_bits(memory: ArtificialMemory):
        """Fill with random bits (baseline)."""
        with torch.no_grad():
            memory.memory.uniform_(0, 1).round_()

    @staticmethod
    def structured_patterns(memory: ArtificialMemory):
        """
        Fill with structured patterns that encode different "knowledge types."

        Memory layout (hex ranges shown for the default 0x10000-word memory;
        in general, each block is one quarter of num_words):
          - [0x0000 - 0x3FFF]: dynamics patterns (state transition rules)
          - [0x4000 - 0x7FFF]: context patterns (characteristic-dependent info)
          - [0x8000 - 0xBFFF]: history patterns (temporal sequences)
          - [0xC000 - 0xFFFF]: association patterns (cross-references)
        """
        N = memory.config.num_words
        W = memory.config.word_size
        quarter = N // 4

        with torch.no_grad():
            # Dynamics: one-hot patterns cycling through bit positions.
            for i in range(quarter):
                pattern = torch.zeros(W)
                pattern[i % W] = 1.0
                memory.memory[i] = pattern

            # Context: reproducible pseudo-random words, seeded per address
            # via a local generator so the global RNG state stays untouched.
            for i in range(quarter, 2 * quarter):
                gen = torch.Generator().manual_seed(i - quarter)
                memory.memory[i] = torch.randint(0, 2, (W,), generator=gen).float()

            # History: binary-counter encoding of the address offset.
            for i in range(2 * quarter, 3 * quarter):
                binary = torch.zeros(W)
                val = i - 2 * quarter
                for bit in range(min(W, 16)):
                    binary[bit] = float((val >> bit) & 1)
                memory.memory[i] = binary

            # Associations: XOR of a dynamics word and a context word.
            for i in range(3 * quarter, N):
                a = memory.memory[i % quarter]
                b = memory.memory[quarter + (i % quarter)]
                memory.memory[i] = (a + b) % 2

    @staticmethod
    def from_experience(memory: ArtificialMemory, experiences: torch.Tensor):
        """
        Populate memory from observed data.

        Args:
            experiences: (N, feature_dim) tensor of observed features.
                Each feature vector is binarized by thresholding at its
                median and stored as one memory word.
        """
        with torch.no_grad():
            N = min(experiences.shape[0], memory.config.num_words)
            W = memory.config.word_size

            for i in range(N):
                feat = experiences[i]

                if len(feat) >= W:
                    bits = (feat[:W] > feat[:W].median()).float()
                else:
                    bits = torch.zeros(W)
                    bits[:len(feat)] = (feat > feat.median()).float()
                memory.memory[i] = bits
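
# Usage sketch (illustrative; assumes ArtificialMemory is constructible from
# a bare MemoryConfig): binarize a batch of observed features into memory.
#
#   memory = ArtificialMemory(MemoryConfig())
#   feats = torch.randn(512, 64)        # 512 observations, 64 features each
#   MemoryPopulator.from_experience(memory, feats)
#   # memory.memory[:512] now holds the median-thresholded bit patterns.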


# =============================================================================
# Evaluation
# =============================================================================

class Evaluator:
    """Evaluation metrics for the LeWorld system."""

    @staticmethod
    def prediction_accuracy(
        system: LeWorldSystem,
        dataloader: DataLoader,
        n_steps: int = 5
    ) -> Dict:
        """
        Evaluate next-state prediction accuracy.

        Metrics:
          - MSE: mean squared error of state predictions
          - MAE: mean absolute error
          - Per-step MSE: prediction error at each horizon up to n_steps
          - Routing entropy: how varied the SLM selections are

        Assumes sequences are longer than n_steps and that the system output
        contains 'predictions' and per-step routing 'masks'.
        """
        system.eval()
        total_mse = 0.0
        total_mae = 0.0
        step_mses = [0.0] * n_steps
        all_masks = []
        n_batches = 0

        with torch.no_grad():
            for batch in dataloader:
                states = batch['states']
                chars = batch['characteristics']

                output = system.multi_step_forward(states, chars, n_steps)

                # Ground truth for horizons 1..n_steps.
                gt_future = states[:, 1:n_steps + 1, :]
                pred_future = output['predictions'][:, :n_steps, :]

                actual_steps = min(n_steps, pred_future.shape[1])

                mse = F.mse_loss(pred_future[:, :actual_steps], gt_future[:, :actual_steps])
                mae = F.l1_loss(pred_future[:, :actual_steps], gt_future[:, :actual_steps])

                total_mse += mse.item()
                total_mae += mae.item()

                # Error at each individual horizon.
                for t in range(actual_steps):
                    step_mse = F.mse_loss(pred_future[:, t], gt_future[:, t])
                    step_mses[t] += step_mse.item()

                all_masks.append(output['masks'])
                n_batches += 1

        # Routing diversity: entropy of the average per-SLM usage.
        all_masks = torch.cat(all_masks, dim=0)
        usage_per_slm = all_masks.mean(dim=(0, 1))
        routing_entropy = -(usage_per_slm * torch.log(usage_per_slm + 1e-8)).sum().item()

        system.train()

        return {
            'mse': total_mse / max(1, n_batches),
            'mae': total_mae / max(1, n_batches),
            'step_mses': [m / max(1, n_batches) for m in step_mses],
            'routing_entropy': routing_entropy,
            'slm_usage': usage_per_slm.tolist(),
        }
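

def _demo_routing_entropy() -> None:
    """Sketch of the routing-entropy metric above: uniform usage over three
    SLMs gives the maximum entropy ln(3) ~= 1.099, while a collapsed router
    that always picks the same SLM gives ~= 0."""
    for usage in (torch.tensor([1 / 3, 1 / 3, 1 / 3]),   # balanced router
                  torch.tensor([1.0, 0.0, 0.0])):        # collapsed router
        print(-(usage * torch.log(usage + 1e-8)).sum().item())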


# =============================================================================
# Full training pipeline
# =============================================================================

def run_training(
    system: LeWorldSystem,
    train_config: TrainingConfig,
    num_train_sequences: int = 1000,
    num_val_sequences: int = 200,
):
    """Execute the full 3-phase training pipeline.

    Note: the epoch counts are capped below via min(...) so this pipeline
    doubles as a quick smoke test; remove the caps for full-length training.
    """

    print("=" * 70)
    print("LeWorld Training Pipeline")
    print("=" * 70)

    # Populate memory before generating data, since the synthetic dynamics
    # are derived from the memory contents.
    print("\n[Setup] Populating artificial memory...")
    MemoryPopulator.structured_patterns(system.memory)

    print("[Setup] Generating training data...")
    train_dataset = StateTransitionDataset(
        num_sequences=num_train_sequences,
        seq_length=train_config.sequence_length,
        state_dim=train_config.state_dim,
        char_dim=train_config.char_dim,
        memory=system.memory,
        difficulty="medium",
    )

    val_dataset = StateTransitionDataset(
        num_sequences=num_val_sequences,
        seq_length=train_config.sequence_length,
        state_dim=train_config.state_dim,
        char_dim=train_config.char_dim,
        memory=system.memory,
        difficulty="medium",
    )

    train_loader = DataLoader(train_dataset, batch_size=train_config.phase1_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=train_config.phase1_batch_size)

    evaluator = Evaluator()

    # ----- Phase 1 -----
    print(f"\n{'='*70}")
    print("Phase 1: Pre-training (SLMs + BLM separately)")
    print(f"{'='*70}")

    phase1 = Phase1Trainer(system, train_config)

    for epoch in range(min(3, train_config.phase1_epochs)):
        epoch_slm_loss = 0.0
        epoch_blm_loss = 0.0
        n_batches = 0

        for batch in train_loader:
            slm_metrics = phase1.train_slms_step(batch)
            blm_metrics = phase1.train_blm_step(batch)

            epoch_slm_loss += slm_metrics['slm_loss']
            epoch_blm_loss += blm_metrics['blm_loss']
            n_batches += 1

        print(f"  Epoch {epoch+1}: SLM loss={epoch_slm_loss/n_batches:.4f}, "
              f"BLM loss={epoch_blm_loss/n_batches:.4f}")

    val_metrics = evaluator.prediction_accuracy(system, val_loader, n_steps=5)
    print(f"  Phase 1 eval: MSE={val_metrics['mse']:.4f}, "
          f"Routing entropy={val_metrics['routing_entropy']:.4f}")

    # ----- Phase 2 -----
    print(f"\n{'='*70}")
    print("Phase 2: End-to-End Joint Training")
    print(f"{'='*70}")

    phase2 = Phase2Trainer(system, train_config)
    train_loader2 = DataLoader(train_dataset, batch_size=train_config.phase2_batch_size, shuffle=True)
    val_loader2 = DataLoader(val_dataset, batch_size=train_config.phase2_batch_size)

    for epoch in range(min(5, train_config.phase2_epochs)):
        epoch_loss = 0.0
        n_batches = 0

        for batch in train_loader2:
            metrics = phase2.train_step(batch)
            epoch_loss += metrics['total_loss']
            n_batches += 1

        phase2.scheduler.step()  # cosine schedule advances once per epoch

        print(f"  Epoch {epoch+1}: loss={epoch_loss/n_batches:.4f}, "
              f"temp={metrics['temperature']:.4f}")

    val_metrics = evaluator.prediction_accuracy(system, val_loader2, n_steps=5)
    print(f"  Phase 2 eval: MSE={val_metrics['mse']:.4f}, "
          f"Routing entropy={val_metrics['routing_entropy']:.4f}, "
          f"SLM usage={[f'{u:.2f}' for u in val_metrics['slm_usage']]}")

    # ----- Phase 3 -----
    print(f"\n{'='*70}")
    print("Phase 3: Info-Request Cooperative Refinement")
    print(f"{'='*70}")

    phase3 = Phase3Trainer(system, train_config)

    for epoch in range(min(3, train_config.phase3_epochs)):
        epoch_loss = 0.0
        epoch_improvement = 0.0
        n_batches = 0

        for batch in train_loader2:
            metrics = phase3.train_step(batch)
            epoch_loss += metrics['total_loss']
            epoch_improvement += metrics['improvement']
            n_batches += 1

        print(f"  Epoch {epoch+1}: loss={epoch_loss/n_batches:.4f}, "
              f"info improvement={epoch_improvement/n_batches:.4f}")

    # ----- Final evaluation -----
    print(f"\n{'='*70}")
    print("Final Evaluation")
    print(f"{'='*70}")

    final_metrics = evaluator.prediction_accuracy(system, val_loader2, n_steps=5)
    print(f"  Final MSE: {final_metrics['mse']:.4f}")
    print(f"  Final MAE: {final_metrics['mae']:.4f}")
    print(f"  Per-step MSE: {[f'{m:.4f}' for m in final_metrics['step_mses']]}")
    print(f"  Routing entropy: {final_metrics['routing_entropy']:.4f}")
    print(f"  SLM usage: {[f'{u:.2f}' for u in final_metrics['slm_usage']]}")

    return final_metrics


# =============================================================================
# Entry point
# =============================================================================

if __name__ == "__main__":
    # Small demo run: default configs, short sequences, tiny dataset.
    mem_config = MemoryConfig()
    slm_config = SLMConfig()
    blm_config = BLMConfig()
    train_config = TrainingConfig(sequence_length=10)

    system = LeWorldSystem(mem_config, slm_config, blm_config)
    count_params(system, "Full LeWorld System")

    metrics = run_training(
        system, train_config,
        num_train_sequences=100,
        num_val_sequences=30,
    )

    print("\n✅ Training pipeline complete!")