""" LeWorld Memory Architecture — Complete Implementation ===================================================== Component 1: Artificial Memory (CPU-style bit storage) Component 2: SLMs (Small LeWorld Models, ~1.5M params each) Component 3: BLM (Big LeWorld Model, ~12M params) Component 4: Full System with training loop """ import torch import torch.nn as nn import torch.nn.functional as F import math from dataclasses import dataclass from typing import Tuple, List, Optional # ============================================================================= # Configuration # ============================================================================= @dataclass class MemoryConfig: """CPU-style artificial memory configuration.""" num_words: int = 65536 # 64K addressable words (like 64K RAM) word_size: int = 32 # 32 bits per word address_bits: int = 16 # 2^16 = 65536 addresses max_read_range: int = 256 # max words per single read operation @dataclass class SLMConfig: """Small LeWorld Model configuration (~1.5M params).""" d_model: int = 128 # internal dimension n_heads: int = 4 # attention heads n_layers: int = 2 # transformer layers state_dim: int = 64 # state vector dimension char_dim: int = 32 # characteristics vector dimension address_space: int = 65536 # must match MemoryConfig.num_words max_read_range: int = 256 # must match MemoryConfig.max_read_range dropout: float = 0.1 @dataclass class BLMConfig: """Big LeWorld Model configuration (~12M params).""" d_model: int = 384 # internal dimension n_heads: int = 6 # attention heads n_layers: int = 6 # transformer layers state_dim: int = 64 # state vector dimension n_slms: int = 3 # number of SLMs to route over memory_read_dim: int = 256 # dimension of encoded memory reads info_query_dim: int = 128 # dimension of "what info do I need" query dropout: float = 0.1 # ============================================================================= # Component 1: Artificial Memory # ============================================================================= class ArtificialMemory(nn.Module): """ CPU-style bit-level memory with address-range access. Stores data as actual bits (0/1 tensors), organized into addressable words. Supports: - READ(start_addr, end_addr) → returns bit block - WRITE(start_addr, data) → writes bits to memory - Bit-to-embedding projection (for neural network consumption) This mimics how a CPU accesses RAM: - Each address points to a word (32 bits) - Contiguous reads fetch a range of words - No inherent "meaning" — bits are just bits until interpreted """ def __init__(self, config: MemoryConfig): super().__init__() self.config = config # The actual memory: (num_words, word_size) binary tensor # Initialized randomly — represents "existing knowledge base" self.register_buffer( 'memory', torch.randint(0, 2, (config.num_words, config.word_size)).float() ) # Bit-to-embedding projection: converts raw bits into dense vectors # This is learnable — the system learns what bit patterns mean self.bit_encoder = nn.Sequential( nn.Linear(config.word_size, 64), nn.GELU(), nn.Linear(64, 128), nn.LayerNorm(128) ) # Write projection: converts dense vectors back to bit probabilities self.bit_decoder = nn.Sequential( nn.Linear(128, 64), nn.GELU(), nn.Linear(64, config.word_size), nn.Sigmoid() # output probabilities for each bit ) def read(self, start_addr: torch.Tensor, end_addr: torch.Tensor) -> torch.Tensor: """ Read a contiguous range of words from memory. Args: start_addr: (batch,) integer tensor of start addresses end_addr: (batch,) integer tensor of end addresses Returns: bit_block: (batch, max_range, word_size) raw bits encoded: (batch, max_range, 128) encoded memory content """ batch_size = start_addr.shape[0] max_range = self.config.max_read_range # Clamp addresses to valid range start_addr = start_addr.clamp(0, self.config.num_words - 1) end_addr = end_addr.clamp(start_addr, torch.minimum(start_addr + max_range, torch.tensor(self.config.num_words))) # Gather memory contents for each batch element # Create index tensor for the address ranges offsets = torch.arange(max_range, device=start_addr.device).unsqueeze(0) # (1, max_range) addresses = start_addr.unsqueeze(1) + offsets # (batch, max_range) addresses = addresses.clamp(0, self.config.num_words - 1) # Create validity mask (addresses within [start, end) are valid) range_lengths = (end_addr - start_addr).unsqueeze(1) # (batch, 1) valid_mask = offsets < range_lengths # (batch, max_range) # Gather bits bit_block = self.memory[addresses] # (batch, max_range, word_size) bit_block = bit_block * valid_mask.unsqueeze(-1).float() # zero out invalid # Encode bits to dense vectors encoded = self.bit_encoder(bit_block) # (batch, max_range, 128) encoded = encoded * valid_mask.unsqueeze(-1).float() return bit_block, encoded, valid_mask def write(self, start_addr: torch.Tensor, data: torch.Tensor): """ Write data to memory (differentiable soft-write). Args: start_addr: (batch,) start addresses data: (batch, n_words, 128) encoded data to write """ n_words = data.shape[1] # Decode to bit probabilities bit_probs = self.bit_decoder(data) # (batch, n_words, word_size) # Hard bits via straight-through hard_bits = (bit_probs > 0.5).float() bits_to_write = hard_bits - bit_probs.detach() + bit_probs # ST trick # Write to memory (last batch element wins for simplicity) for b in range(start_addr.shape[0]): addr = start_addr[b].long() end = min(addr + n_words, self.config.num_words) actual_n = end - addr self.memory[addr:end] = bits_to_write[b, :actual_n].detach() def soft_read(self, attention_weights: torch.Tensor) -> torch.Tensor: """ Content-based soft read using attention weights over entire memory. Used for differentiable end-to-end training. Args: attention_weights: (batch, num_words) soft address distribution Returns: encoded: (batch, 128) weighted memory content """ # Encode all memory (expensive but differentiable) all_encoded = self.bit_encoder(self.memory) # (num_words, 128) # Weighted sum encoded = torch.matmul(attention_weights, all_encoded) # (batch, 128) return encoded # ============================================================================= # Component 2: Small LeWorld Model (SLM) # ============================================================================= class StateEncoder(nn.Module): """Encodes past_state and current_state into a joint representation.""" def __init__(self, state_dim: int, d_model: int): super().__init__() self.past_proj = nn.Linear(state_dim, d_model) self.curr_proj = nn.Linear(state_dim, d_model) self.combiner = nn.Sequential( nn.Linear(d_model * 2, d_model), nn.GELU(), nn.LayerNorm(d_model) ) def forward(self, past_state: torch.Tensor, current_state: torch.Tensor) -> torch.Tensor: """ Args: past_state: (batch, state_dim) current_state: (batch, state_dim) Returns: combined: (batch, d_model) """ past_enc = F.gelu(self.past_proj(past_state)) curr_enc = F.gelu(self.curr_proj(current_state)) combined = self.combiner(torch.cat([past_enc, curr_enc], dim=-1)) return combined class CharacteristicsEncoder(nn.Module): """Encodes static characteristics/context.""" def __init__(self, char_dim: int, d_model: int): super().__init__() self.encoder = nn.Sequential( nn.Linear(char_dim, d_model), nn.GELU(), nn.LayerNorm(d_model) ) def forward(self, characteristics: torch.Tensor) -> torch.Tensor: return self.encoder(characteristics) class TransformerBlock(nn.Module): """Standard transformer block with pre-norm.""" def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1): super().__init__() self.norm1 = nn.LayerNorm(d_model) self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True) self.norm2 = nn.LayerNorm(d_model) self.ffn = nn.Sequential( nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model), nn.Dropout(dropout) ) def forward(self, x: torch.Tensor) -> torch.Tensor: # Self-attention with pre-norm normed = self.norm1(x) attn_out, _ = self.attn(normed, normed, normed) x = x + attn_out # FFN with pre-norm x = x + self.ffn(self.norm2(x)) return x class CrossAttentionBlock(nn.Module): """Cross-attention: state attends to characteristics.""" def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1): super().__init__() self.norm_q = nn.LayerNorm(d_model) self.norm_kv = nn.LayerNorm(d_model) self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True) self.norm_ff = nn.LayerNorm(d_model) self.ffn = nn.Sequential( nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model), nn.Dropout(dropout) ) def forward(self, query: torch.Tensor, context: torch.Tensor) -> torch.Tensor: normed_q = self.norm_q(query) normed_kv = self.norm_kv(context) attn_out, _ = self.cross_attn(normed_q, normed_kv, normed_kv) x = query + attn_out x = x + self.ffn(self.norm_ff(x)) return x class AddressHead(nn.Module): """ Produces memory address range (start_addr, end_addr) from hidden state. Uses two approaches: 1. HARD mode: argmax over address space (for inference) 2. SOFT mode: attention weights over memory (for differentiable training) """ def __init__(self, d_model: int, address_space: int, max_range: int): super().__init__() self.address_space = address_space self.max_range = max_range # Produce start address logits # We don't have a linear over 65K — that's too many params # Instead: predict address as composition of sub-addresses (like product keys) self.addr_bits = int(math.log2(address_space)) # 16 for 65536 assert 2 ** self.addr_bits == address_space, "address_space must be power of 2" # Split address into high byte and low byte (8+8 = 16 bits) self.half_bits = self.addr_bits // 2 # 8 self.half_space = 2 ** self.half_bits # 256 # Predict high and low parts separately (product key approach) self.start_high = nn.Linear(d_model, self.half_space) # 256 outputs self.start_low = nn.Linear(d_model, self.half_space) # 256 outputs # Predict range length (how many words to read) self.range_head = nn.Sequential( nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Linear(d_model // 2, max_range) ) # Confidence head self.confidence_head = nn.Sequential( nn.Linear(d_model, d_model // 4), nn.GELU(), nn.Linear(d_model // 4, 1), nn.Sigmoid() ) def forward(self, hidden: torch.Tensor) -> dict: """ Args: hidden: (batch, d_model) Returns: dict with: start_addr: (batch,) integer addresses end_addr: (batch,) integer addresses range_length: (batch,) how many words to read confidence: (batch,) read confidence score start_logits_high: (batch, 256) for soft addressing start_logits_low: (batch, 256) for soft addressing range_logits: (batch, max_range) for soft range selection """ batch_size = hidden.shape[0] # Product-key address generation high_logits = self.start_high(hidden) # (batch, 256) low_logits = self.start_low(hidden) # (batch, 256) # Hard address via argmax high_idx = high_logits.argmax(dim=-1) # (batch,) low_idx = low_logits.argmax(dim=-1) # (batch,) start_addr = high_idx * self.half_space + low_idx # (batch,) 0..65535 # Range length range_logits = self.range_head(hidden) # (batch, max_range) range_length = range_logits.argmax(dim=-1) + 1 # (batch,) 1..max_range end_addr = (start_addr + range_length).clamp(max=self.address_space - 1) # Confidence confidence = self.confidence_head(hidden).squeeze(-1) # (batch,) return { 'start_addr': start_addr, 'end_addr': end_addr, 'range_length': range_length, 'confidence': confidence, 'start_logits_high': high_logits, 'start_logits_low': low_logits, 'range_logits': range_logits, } class SmallLeWorldModel(nn.Module): """ SLM: Small LeWorld Model (~1.5M params) Takes (past_state, current_state, characteristics) and produces a memory address range pointing to the most useful memory for next-state prediction. Architecture: 1. Encode past + current state → state representation 2. Encode characteristics 3. Cross-attend: state attends to characteristics 4. Self-attention transformer layers 5. Address head: output (start_addr, end_addr, confidence) """ def __init__(self, config: SLMConfig, slm_id: int = 0): super().__init__() self.config = config self.slm_id = slm_id # Encoders self.state_encoder = StateEncoder(config.state_dim, config.d_model) self.char_encoder = CharacteristicsEncoder(config.char_dim, config.d_model) # Cross-attention: state ← characteristics self.cross_attn = CrossAttentionBlock(config.d_model, config.n_heads, config.dropout) # Self-attention transformer self.transformer_layers = nn.ModuleList([ TransformerBlock(config.d_model, config.n_heads, config.dropout) for _ in range(config.n_layers) ]) self.final_norm = nn.LayerNorm(config.d_model) # Address output head self.address_head = AddressHead(config.d_model, config.address_space, config.max_read_range) def forward( self, past_state: torch.Tensor, # (batch, state_dim) current_state: torch.Tensor, # (batch, state_dim) characteristics: torch.Tensor, # (batch, char_dim) ) -> dict: """ Forward pass: state + characteristics → memory address range. Returns dict with address info + internal hidden state. """ # Encode states state_repr = self.state_encoder(past_state, current_state) # (batch, d_model) # Encode characteristics char_repr = self.char_encoder(characteristics) # (batch, d_model) # Cross-attention: state queries characteristics # Unsqueeze to sequence dim for attention state_seq = state_repr.unsqueeze(1) # (batch, 1, d_model) char_seq = char_repr.unsqueeze(1) # (batch, 1, d_model) enriched = self.cross_attn(state_seq, char_seq) # (batch, 1, d_model) # Self-attention layers hidden = enriched for layer in self.transformer_layers: hidden = layer(hidden) hidden = self.final_norm(hidden) hidden = hidden.squeeze(1) # (batch, d_model) # Produce address range addr_output = self.address_head(hidden) addr_output['hidden'] = hidden # keep for BLM to use addr_output['slm_id'] = self.slm_id return addr_output # ============================================================================= # Component 3: Big LeWorld Model (BLM) # ============================================================================= class StraightThroughSigmoid(torch.autograd.Function): """ Binary gate: hard 0/1 in forward, sigmoid gradient in backward. From literature: ST-GS (Jang et al. 2017) + Switch Transformer routing. """ @staticmethod def forward(ctx, logits): probs = torch.sigmoid(logits) ctx.save_for_backward(probs) return (probs > 0.5).float() @staticmethod def backward(ctx, grad_output): probs, = ctx.saved_tensors # Sigmoid derivative: p * (1-p) return grad_output * probs * (1 - probs) class BLMRouter(nn.Module): """ Routes/selects which SLMs to activate. Produces binary mask like [1, 0, 1]. Uses Straight-Through Sigmoid for differentiable binary selection. Includes load-balancing loss to prevent degenerate routing. """ def __init__(self, d_model: int, n_slms: int): super().__init__() self.n_slms = n_slms self.gate = nn.Sequential( nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Linear(d_model // 2, n_slms) ) # Temperature for annealing (start warm, cool down) self.register_buffer('temperature', torch.tensor(1.0)) def forward(self, state_repr: torch.Tensor) -> Tuple[torch.Tensor, dict]: """ Args: state_repr: (batch, d_model) encoded current state Returns: binary_mask: (batch, n_slms) hard 0/1 selection routing_info: dict with probs, losses, etc. """ logits = self.gate(state_repr) # (batch, n_slms) # Scale by temperature scaled_logits = logits / self.temperature.clamp(min=0.1) probs = torch.sigmoid(scaled_logits) # (batch, n_slms) # Straight-through binary: hard in forward, soft in backward hard_mask = (probs > 0.5).float() binary_mask = hard_mask - probs.detach() + probs # THE ST TRICK # Ensure at least one SLM is selected (don't want all zeros) # If all zeros, force-select the highest probability SLM all_zero = (binary_mask.sum(dim=-1) == 0) # (batch,) if all_zero.any(): max_idx = probs[all_zero].argmax(dim=-1) forced = torch.zeros_like(probs[all_zero]) forced.scatter_(1, max_idx.unsqueeze(1), 1.0) binary_mask[all_zero] = forced # Load balance loss: encourage roughly equal usage of SLMs usage_per_slm = binary_mask.mean(dim=0) # (n_slms,) target_usage = 1.0 / self.n_slms balance_loss = ((usage_per_slm - target_usage) ** 2).sum() # Entropy loss: encourage decisive routing (not all ~0.5) entropy = -(probs * torch.log(probs + 1e-8) + (1 - probs) * torch.log(1 - probs + 1e-8)) entropy_loss = entropy.mean() routing_info = { 'probs': probs, 'binary_mask': binary_mask, 'balance_loss': balance_loss, 'entropy_loss': entropy_loss, 'logits': logits, } return binary_mask, routing_info def anneal_temperature(self, step: int, anneal_rate: float = 3e-5, min_temp: float = 0.1): """Anneal temperature: start warm (exploratory), cool down (decisive).""" new_temp = max(min_temp, math.exp(-anneal_rate * step)) self.temperature.fill_(new_temp) class InfoRequestHead(nn.Module): """ Produces a query vector representing "what information do I need next?" This is the key innovation: instead of passively receiving all SLM outputs, the BLM actively requests specific information. This query modulates which memory regions the SLMs should focus on in the NEXT timestep. """ def __init__(self, d_model: int, query_dim: int): super().__init__() self.query_generator = nn.Sequential( nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, query_dim), nn.LayerNorm(query_dim) ) def forward(self, hidden: torch.Tensor) -> torch.Tensor: """ Args: hidden: (batch, d_model) BLM's internal state Returns: info_query: (batch, query_dim) "what do I need next?" """ return self.query_generator(hidden) class BigLeWorldModel(nn.Module): """ BLM: Big LeWorld Model (~12M params) Two roles: 1. ROUTER: Select which SLMs to activate (binary mask) 2. PREDICTOR: Given selected memory contents, predict next state Plus: Info-Request Head that asks "what information is needed next?" Architecture: 1. Encode current state → routing decision 2. Receive memory reads from selected SLMs 3. Transformer processes (state + memories) 4. Predict next state 5. Generate info request for next timestep """ def __init__(self, config: BLMConfig): super().__init__() self.config = config # State encoder (maps state_dim → d_model) self.state_encoder = nn.Sequential( nn.Linear(config.state_dim, config.d_model), nn.GELU(), nn.LayerNorm(config.d_model) ) # Memory read encoder (maps encoded memory → d_model) self.memory_encoder = nn.Sequential( nn.Linear(128, config.d_model), # 128 from ArtificialMemory bit_encoder nn.GELU(), nn.LayerNorm(config.d_model) ) # SLM hidden state encoder (maps SLM hidden → d_model) self.slm_hidden_encoder = nn.Sequential( nn.Linear(128, config.d_model), # 128 = SLM d_model nn.GELU(), nn.LayerNorm(config.d_model) ) # Router: selects which SLMs to use self.router = BLMRouter(config.d_model, config.n_slms) # Transformer backbone self.transformer_layers = nn.ModuleList([ TransformerBlock(config.d_model, config.n_heads, config.dropout) for _ in range(config.n_layers) ]) self.final_norm = nn.LayerNorm(config.d_model) # Prediction heads self.next_state_head = nn.Sequential( nn.Linear(config.d_model, config.d_model), nn.GELU(), nn.Linear(config.d_model, config.state_dim) ) # Info request head: "what do I need next?" self.info_request = InfoRequestHead(config.d_model, config.info_query_dim) # Learnable tokens self.cls_token = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02) self.state_type_embed = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02) self.memory_type_embed = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02) def forward( self, past_state: torch.Tensor, # (batch, state_dim) current_state: torch.Tensor, # (batch, state_dim) slm_outputs: List[dict], # list of SLM output dicts memory_reads: List[torch.Tensor], # list of (batch, range, 128) encoded memory info_query_prev: Optional[torch.Tensor] = None, # (batch, query_dim) from previous step ) -> dict: """ Full BLM forward pass. Returns: dict with next_state, binary_mask, info_query, losses, etc. """ batch_size = current_state.shape[0] # 1. Encode current state for routing decision state_enc = self.state_encoder(current_state) # (batch, d_model) # 2. Route: select which SLMs to use binary_mask, routing_info = self.router(state_enc) # (batch, n_slms) # 3. Aggregate selected memory reads # For each SLM, apply its binary gate and encode its memory read memory_tokens = [] for i, (slm_out, mem_read) in enumerate(zip(slm_outputs, memory_reads)): gate = binary_mask[:, i:i+1] # (batch, 1) # Gate the SLM's hidden representation slm_hidden = self.slm_hidden_encoder(slm_out['hidden']) # (batch, d_model) slm_hidden = slm_hidden * gate # zero if SLM not selected # Gate and encode the memory read # mem_read: (batch, range_len, 128) mem_enc = self.memory_encoder(mem_read) # (batch, range_len, d_model) mem_enc = mem_enc * gate.unsqueeze(-1) # zero if SLM not selected # Pool memory read to single token (mean pool over range) mem_pooled = mem_enc.mean(dim=1, keepdim=True) # (batch, 1, d_model) memory_tokens.append(slm_hidden.unsqueeze(1)) # SLM hidden as token memory_tokens.append(mem_pooled) # memory content as token # 4. Build input sequence for transformer # [CLS] + [state] + [slm_0_hidden, slm_0_mem, slm_1_hidden, slm_1_mem, ...] cls = self.cls_token.expand(batch_size, -1, -1) state_token = state_enc.unsqueeze(1) + self.state_type_embed # (batch, 1, d_model) # Add memory type embedding to memory tokens mem_sequence = torch.cat(memory_tokens, dim=1) # (batch, 2*n_slms, d_model) mem_sequence = mem_sequence + self.memory_type_embed sequence = torch.cat([cls, state_token, mem_sequence], dim=1) # Shape: (batch, 1 + 1 + 2*n_slms, d_model) # 5. Transformer processing hidden = sequence for layer in self.transformer_layers: hidden = layer(hidden) hidden = self.final_norm(hidden) # 6. Extract predictions from CLS token cls_output = hidden[:, 0, :] # (batch, d_model) # 7. Predict next state next_state_pred = self.next_state_head(cls_output) # (batch, state_dim) # 8. Generate info request for next timestep info_query = self.info_request(cls_output) # (batch, query_dim) return { 'next_state': next_state_pred, 'binary_mask': binary_mask, 'info_query': info_query, 'routing_info': routing_info, 'cls_output': cls_output, } # ============================================================================= # Component 4: Full LeWorld System # ============================================================================= class LeWorldSystem(nn.Module): """ Complete LeWorld Memory Architecture. Orchestrates: - Artificial Memory (bit-level storage) - 3 SLMs (produce memory address ranges) - 1 BLM (selects SLMs, reads memory, predicts next state) Training loop: 1. BLM sees current state → routes to SLMs 2. Selected SLMs produce address ranges 3. Memory is read at those ranges 4. BLM aggregates memory + state → predicts next state 5. BLM generates info-request for next step Losses: - next_state_loss: MSE between predicted and actual next state - routing_balance_loss: encourage balanced SLM usage - address_diversity_loss: encourage SLMs to read different memory regions - info_utility_loss: did the info request lead to useful retrievals? """ def __init__( self, mem_config: MemoryConfig = MemoryConfig(), slm_config: SLMConfig = SLMConfig(), blm_config: BLMConfig = BLMConfig(), ): super().__init__() # Artificial Memory self.memory = ArtificialMemory(mem_config) # 3 SLMs self.slms = nn.ModuleList([ SmallLeWorldModel(slm_config, slm_id=i) for i in range(blm_config.n_slms) ]) # BLM self.blm = BigLeWorldModel(blm_config) # Info-query → SLM modulation: the BLM's info request # influences what SLMs look for in the next timestep self.info_to_slm = nn.Linear(blm_config.info_query_dim, slm_config.state_dim) self.config = { 'mem': mem_config, 'slm': slm_config, 'blm': blm_config, } def forward( self, past_state: torch.Tensor, # (batch, state_dim) current_state: torch.Tensor, # (batch, state_dim) characteristics: torch.Tensor, # (batch, char_dim) next_state_target: Optional[torch.Tensor] = None, # (batch, state_dim) for training info_query_prev: Optional[torch.Tensor] = None, # from previous timestep ) -> dict: """ Full system forward pass. """ batch_size = current_state.shape[0] # If we have a previous info query, modulate the current state # This is how the BLM's "what do I need?" influences retrieval if info_query_prev is not None: info_modulation = self.info_to_slm(info_query_prev) # (batch, state_dim) modulated_state = current_state + 0.1 * info_modulation # gentle modulation else: modulated_state = current_state # 1. Run all 3 SLMs to get address ranges slm_outputs = [] for slm in self.slms: out = slm(past_state, modulated_state, characteristics) slm_outputs.append(out) # 2. Read memory at each SLM's address range memory_reads = [] for slm_out in slm_outputs: _, encoded, valid_mask = self.memory.read( slm_out['start_addr'], slm_out['end_addr'] ) memory_reads.append(encoded) # 3. BLM processes everything blm_output = self.blm( past_state, current_state, slm_outputs, memory_reads, info_query_prev ) # 4. Compute losses if training losses = {} if next_state_target is not None: # Primary loss: next state prediction losses['next_state_loss'] = F.mse_loss( blm_output['next_state'], next_state_target ) # Routing balance loss losses['balance_loss'] = blm_output['routing_info']['balance_loss'] # Address diversity loss: penalize SLMs for reading same regions addresses = torch.stack([ slm_out['start_addr'].float() for slm_out in slm_outputs ], dim=1) # (batch, n_slms) # Pairwise distance between SLM addresses (want to maximize) addr_diff = torch.cdist(addresses.unsqueeze(-1), addresses.unsqueeze(-1)) diversity_loss = -addr_diff.mean() # negative = encourage large distances losses['diversity_loss'] = diversity_loss # Total loss losses['total_loss'] = ( losses['next_state_loss'] + 0.01 * losses['balance_loss'] + 0.001 * losses['diversity_loss'] ) return { 'next_state': blm_output['next_state'], 'binary_mask': blm_output['binary_mask'], 'info_query': blm_output['info_query'], 'slm_outputs': slm_outputs, 'memory_reads': memory_reads, 'losses': losses, 'routing_info': blm_output['routing_info'], } def multi_step_forward( self, states: torch.Tensor, # (batch, T, state_dim) sequence of states characteristics: torch.Tensor, # (batch, char_dim) static n_steps: int = None, ) -> dict: """ Run the system over multiple timesteps autoregressively. For training: teacher forcing with ground-truth states """ batch_size, T, state_dim = states.shape if n_steps is None: n_steps = T - 1 # predict all future states all_predictions = [] all_masks = [] total_loss = None info_query = None for t in range(min(n_steps, T - 1)): past_state = states[:, max(0, t-1), :] current_state = states[:, t, :] next_state_target = states[:, t+1, :] output = self.forward( past_state, current_state, characteristics, next_state_target, info_query ) all_predictions.append(output['next_state']) all_masks.append(output['binary_mask']) info_query = output['info_query'] if output['losses']: if total_loss is None: total_loss = output['losses']['total_loss'] else: total_loss = total_loss + output['losses']['total_loss'] if total_loss is None: total_loss = torch.tensor(0.0, device=states.device) return { 'predictions': torch.stack(all_predictions, dim=1), 'masks': torch.stack(all_masks, dim=1), 'total_loss': total_loss / max(1, min(n_steps, T - 1)), 'final_info_query': info_query, } # ============================================================================= # Parameter Count Verification # ============================================================================= def count_params(model, name="Model"): """Count and display parameter breakdown.""" total = sum(p.numel() for p in model.parameters()) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"\n{'='*60}") print(f"{name}: {total:,} total params ({trainable:,} trainable)") print(f"{'='*60}") for child_name, child in model.named_children(): child_params = sum(p.numel() for p in child.parameters()) if child_params > 0: print(f" {child_name}: {child_params:,}") return total # ============================================================================= # Demo / Test # ============================================================================= if __name__ == "__main__": print("LeWorld Memory Architecture — Component Verification") print("=" * 60) # Configs mem_config = MemoryConfig() slm_config = SLMConfig() blm_config = BLMConfig() # Build system system = LeWorldSystem(mem_config, slm_config, blm_config) # Count parameters print("\n--- Parameter Counts ---") count_params(system.memory, "Artificial Memory") for i, slm in enumerate(system.slms): count_params(slm, f"SLM-{i}") count_params(system.blm, "BLM") count_params(system, "Full System") # Test forward pass print("\n--- Forward Pass Test ---") batch_size = 4 state_dim = slm_config.state_dim char_dim = slm_config.char_dim past_state = torch.randn(batch_size, state_dim) current_state = torch.randn(batch_size, state_dim) characteristics = torch.randn(batch_size, char_dim) next_state = torch.randn(batch_size, state_dim) output = system(past_state, current_state, characteristics, next_state) print(f"Next state prediction shape: {output['next_state'].shape}") print(f"Binary mask (SLM selection): {output['binary_mask']}") print(f"Info query shape: {output['info_query'].shape}") print(f"Losses: {output['losses']}") # Test multi-step print("\n--- Multi-Step Test ---") T = 10 states = torch.randn(batch_size, T, state_dim) ms_output = system.multi_step_forward(states, characteristics) print(f"Predictions shape: {ms_output['predictions'].shape}") print(f"Masks shape: {ms_output['masks'].shape}") print(f"Average loss: {ms_output['total_loss'].item():.4f}") # Show routing patterns over time print("\n--- Routing Patterns Over Time ---") masks = ms_output['masks'][0].detach() # first batch element for t in range(masks.shape[0]): mask = masks[t].int().tolist() print(f" Step {t}: SLMs selected = {mask}") print("\n✅ All components verified successfully!")