| """ |
| LeWorld Memory Architecture — Complete Implementation |
| ===================================================== |
| Component 1: Artificial Memory (CPU-style bit storage) |
| Component 2: SLMs (Small LeWorld Models, ~1.5M params each) |
| Component 3: BLM (Big LeWorld Model, ~12M params) |
| Component 4: Full System with training loop |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from dataclasses import dataclass |
| from typing import Tuple, List, Optional |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class MemoryConfig:
    """CPU-style artificial memory configuration."""
    # Total number of addressable words; storage is (num_words, word_size).
    num_words: int = 65536
    # Bits per word (each stored as a 0.0/1.0 float).
    word_size: int = 32
    # log2(num_words). NOTE(review): not read anywhere in this file — confirm
    # it is needed before relying on it.
    address_bits: int = 16
    # Maximum number of words a single contiguous read can return.
    max_read_range: int = 256
|
|
|
|
@dataclass
class SLMConfig:
    """Small LeWorld Model configuration (~1.5M params)."""
    # Transformer width of each SLM.
    d_model: int = 128
    # Attention heads per transformer layer.
    n_heads: int = 4
    # Number of self-attention layers after the cross-attention fusion.
    n_layers: int = 2
    # Dimensionality of past/current state vectors.
    state_dim: int = 64
    # Dimensionality of the static characteristics vector.
    char_dim: int = 32
    # Size of the addressable memory; presumably must equal
    # MemoryConfig.num_words — TODO confirm at construction time.
    address_space: int = 65536
    # Maximum words per read; should match MemoryConfig.max_read_range.
    max_read_range: int = 256
    # Dropout used in attention and FFN sublayers.
    dropout: float = 0.1
|
|
|
|
@dataclass
class BLMConfig:
    """Big LeWorld Model configuration (~12M params)."""
    # Transformer width of the BLM.
    d_model: int = 384
    # Attention heads per transformer layer.
    n_heads: int = 6
    # Number of self-attention layers.
    n_layers: int = 6
    # Dimensionality of state vectors (must match SLMConfig.state_dim).
    state_dim: int = 64
    # Number of SLMs the router selects over.
    n_slms: int = 3
    # NOTE(review): not read anywhere in this file — BigLeWorldModel
    # hard-codes 128 for memory embeddings instead; confirm intent.
    memory_read_dim: int = 256
    # Width of the "what do I need next?" query vector.
    info_query_dim: int = 128
    # Dropout used in attention and FFN sublayers.
    dropout: float = 0.1
|
|
|
|
| |
| |
| |
|
|
class ArtificialMemory(nn.Module):
    """
    CPU-style bit-level memory with address-range access.

    Stores data as actual bits (0/1 float tensors), organized into
    addressable words. Supports:
    - READ(start_addr, end_addr) → returns bit block + encoding + mask
    - WRITE(start_addr, data) → thresholds decoded bits into memory
    - Bit-to-embedding projection (for neural network consumption)

    This mimics how a CPU accesses RAM:
    - Each address points to a word (word_size bits)
    - Contiguous reads fetch a range of words
    - No inherent "meaning" — bits are just bits until interpreted

    The 128-d encoder output width is consumed by BigLeWorldModel's
    memory_encoder, which hard-codes the same 128 — keep them in sync.
    """

    def __init__(self, config: "MemoryConfig"):
        super().__init__()
        self.config = config

        # Non-trainable bit store, initialised with random 0/1 words.
        # A buffer (not a Parameter): it moves with .to(device) but is
        # never optimised.
        self.register_buffer(
            'memory',
            torch.randint(0, 2, (config.num_words, config.word_size)).float()
        )

        # word_size raw bits -> 128-d embedding for network consumption.
        self.bit_encoder = nn.Sequential(
            nn.Linear(config.word_size, 64),
            nn.GELU(),
            nn.Linear(64, 128),
            nn.LayerNorm(128)
        )

        # 128-d embedding -> per-bit probabilities in [0, 1].
        self.bit_decoder = nn.Sequential(
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, config.word_size),
            nn.Sigmoid()
        )

    def read(
        self,
        start_addr: torch.Tensor,
        end_addr: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Read a contiguous range of words from memory.

        Args:
            start_addr: (batch,) integer tensor of start addresses
            end_addr: (batch,) integer tensor of end addresses (exclusive);
                clamped to at most start_addr + max_read_range

        Returns:
            bit_block: (batch, max_read_range, word_size) raw bits,
                zero-padded past each sample's requested range
            encoded: (batch, max_read_range, 128) encoded memory content,
                zeroed outside the valid range
            valid_mask: (batch, max_read_range) bool mask of in-range slots
        """
        max_range = self.config.max_read_range

        # Clamp the request to legal bounds. All bound tensors are created
        # on the address tensor's device — the previous implementation
        # built the upper bound with torch.tensor(...) on the CPU, which
        # raises a cross-device error when addresses live on CUDA.
        start_addr = start_addr.clamp(0, self.config.num_words - 1)
        upper = torch.minimum(
            start_addr + max_range,
            torch.full_like(start_addr, self.config.num_words),
        )
        end_addr = torch.maximum(end_addr, start_addr)
        end_addr = torch.minimum(end_addr, upper)

        # Gather max_range consecutive word addresses per sample.
        offsets = torch.arange(max_range, device=start_addr.device).unsqueeze(0)
        addresses = (start_addr.unsqueeze(1) + offsets).clamp(
            0, self.config.num_words - 1
        )

        # Positions beyond each sample's requested range are masked out.
        range_lengths = (end_addr - start_addr).unsqueeze(1)
        valid_mask = offsets < range_lengths

        bit_block = self.memory[addresses]
        bit_block = bit_block * valid_mask.unsqueeze(-1).float()

        # Encode, then re-mask: the encoder (biases + LayerNorm) does not
        # map zero words to zero vectors.
        encoded = self.bit_encoder(bit_block)
        encoded = encoded * valid_mask.unsqueeze(-1).float()

        return bit_block, encoded, valid_mask

    def write(self, start_addr: torch.Tensor, data: torch.Tensor):
        """
        Decode embeddings to hard bits and store them in memory.

        Args:
            start_addr: (batch,) start addresses
            data: (batch, n_words, 128) encoded data to write

        Note:
            The memory buffer is outside the autograd graph, so this write
            is NOT differentiable — no gradient reaches bit_decoder through
            it. (The previous straight-through arithmetic here was dead
            code for the same reason, and could additionally store values a
            rounding error away from exact 0/1.)
        """
        n_words = data.shape[1]

        with torch.no_grad():
            # Threshold decoder probabilities to exact 0/1 bits.
            bit_probs = self.bit_decoder(data)
            hard_bits = (bit_probs > 0.5).float()

            # Per-sample copy, truncated at the end of memory.
            for b in range(start_addr.shape[0]):
                addr = int(start_addr[b])
                end = min(addr + n_words, self.config.num_words)
                self.memory[addr:end] = hard_bits[b, :end - addr]

    def soft_read(self, attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Content-based soft read using attention weights over entire memory.
        Used for differentiable end-to-end training.

        Args:
            attention_weights: (batch, num_words) soft address distribution

        Returns:
            encoded: (batch, 128) weighted memory content
        """
        # Encode every word, then take the attention-weighted mixture.
        all_encoded = self.bit_encoder(self.memory)
        encoded = torch.matmul(attention_weights, all_encoded)
        return encoded
|
|
|
|
| |
| |
| |
|
|
class StateEncoder(nn.Module):
    """Fuses past and current state vectors into one d_model embedding."""

    def __init__(self, state_dim: int, d_model: int):
        super().__init__()
        self.past_proj = nn.Linear(state_dim, d_model)
        self.curr_proj = nn.Linear(state_dim, d_model)
        self.combiner = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
        )

    def forward(self, past_state: torch.Tensor, current_state: torch.Tensor) -> torch.Tensor:
        """
        Args:
            past_state: (batch, state_dim)
            current_state: (batch, state_dim)
        Returns:
            (batch, d_model) joint representation of both states.
        """
        # Project each state, apply a non-linearity, then mix the pair.
        projections = [
            F.gelu(self.past_proj(past_state)),
            F.gelu(self.curr_proj(current_state)),
        ]
        return self.combiner(torch.cat(projections, dim=-1))
|
|
|
|
class CharacteristicsEncoder(nn.Module):
    """Projects static characteristics/context into the embedding space."""

    def __init__(self, char_dim: int, d_model: int):
        super().__init__()
        layers = [
            nn.Linear(char_dim, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, characteristics: torch.Tensor) -> torch.Tensor:
        """Map (batch, char_dim) -> (batch, d_model)."""
        return self.encoder(characteristics)
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: residual self-attention, then residual FFN."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a (batch, seq, d_model) tensor; shape preserved."""
        # Self-attention sublayer (pre-norm, attention weights discarded).
        pre = self.norm1(x)
        attended, _ = self.attn(pre, pre, pre)
        x = x + attended
        # Position-wise feed-forward sublayer (pre-norm).
        return x + self.ffn(self.norm2(x))
|
|
|
|
class CrossAttentionBlock(nn.Module):
    """Cross-attention block: the query stream attends into a context stream."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.norm_q = nn.LayerNorm(d_model)
        self.norm_kv = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm_ff = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, query: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Args:
            query: (batch, q_len, d_model) stream to enrich
            context: (batch, kv_len, d_model) stream attended into
        Returns:
            (batch, q_len, d_model) query enriched with context.
        """
        q_in = self.norm_q(query)
        kv_in = self.norm_kv(context)
        mixed, _ = self.cross_attn(q_in, kv_in, kv_in)
        out = query + mixed
        return out + self.ffn(self.norm_ff(out))
|
|
|
|
class AddressHead(nn.Module):
    """
    Maps a hidden vector to a memory address range plus a confidence score.

    Two decoding paths are exposed:
    1. HARD: argmax over the logits (used here, for inference)
    2. SOFT: the raw logits are returned so callers can build a
       differentiable attention over memory instead.

    The start address is factored into high/low halves of addr_bits // 2
    bits each, so each linear head only predicts over 2^(addr_bits//2)
    classes. NOTE(review): with an odd addr_bits the two halves cover less
    than the full address space — confirm addr_bits is always even.
    """

    def __init__(self, d_model: int, address_space: int, max_range: int):
        super().__init__()
        self.address_space = address_space
        self.max_range = max_range

        self.addr_bits = int(math.log2(address_space))
        assert 2 ** self.addr_bits == address_space, "address_space must be power of 2"

        # Factorised addressing: address = high * half_space + low.
        self.half_bits = self.addr_bits // 2
        self.half_space = 2 ** self.half_bits

        self.start_high = nn.Linear(d_model, self.half_space)
        self.start_low = nn.Linear(d_model, self.half_space)

        # How many words to read, as a classification over 1..max_range.
        self.range_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, max_range),
        )

        # Scalar confidence in [0, 1] for this read proposal.
        self.confidence_head = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
            nn.Sigmoid(),
        )

    def forward(self, hidden: torch.Tensor) -> dict:
        """
        Args:
            hidden: (batch, d_model)

        Returns:
            dict with:
                start_addr: (batch,) integer addresses
                end_addr: (batch,) integer addresses (clamped into range)
                range_length: (batch,) words to read, in [1, max_range]
                confidence: (batch,) read confidence score in [0, 1]
                start_logits_high: (batch, half_space) for soft addressing
                start_logits_low: (batch, half_space) for soft addressing
                range_logits: (batch, max_range) for soft range selection
        """
        high_logits = self.start_high(hidden)
        low_logits = self.start_low(hidden)

        # Hard decode: compose the two argmax halves into a flat address.
        start_addr = high_logits.argmax(dim=-1) * self.half_space + low_logits.argmax(dim=-1)

        # Hard decode the range; +1 maps class index 0..max_range-1 to
        # a length of 1..max_range words.
        range_logits = self.range_head(hidden)
        range_length = 1 + range_logits.argmax(dim=-1)
        end_addr = torch.clamp(start_addr + range_length, max=self.address_space - 1)

        confidence = self.confidence_head(hidden).squeeze(-1)

        return {
            'start_addr': start_addr,
            'end_addr': end_addr,
            'range_length': range_length,
            'confidence': confidence,
            'start_logits_high': high_logits,
            'start_logits_low': low_logits,
            'range_logits': range_logits,
        }
|
|
|
|
class SmallLeWorldModel(nn.Module):
    """
    SLM: Small LeWorld Model (~1.5M params)

    Maps (past_state, current_state, characteristics) to a memory address
    range pointing at the memory believed most useful for next-state
    prediction.

    Pipeline:
    1. Encode past + current state into one vector
    2. Encode characteristics
    3. Cross-attend: state attends to characteristics
    4. Self-attention transformer trunk
    5. Address head emits (start_addr, end_addr, confidence) + logits
    """

    def __init__(self, config: SLMConfig, slm_id: int = 0):
        super().__init__()
        self.config = config
        self.slm_id = slm_id

        # Input encoders.
        self.state_encoder = StateEncoder(config.state_dim, config.d_model)
        self.char_encoder = CharacteristicsEncoder(config.char_dim, config.d_model)

        # Fuse state with static context.
        self.cross_attn = CrossAttentionBlock(config.d_model, config.n_heads, config.dropout)

        # Self-attention trunk.
        self.transformer_layers = nn.ModuleList(
            TransformerBlock(config.d_model, config.n_heads, config.dropout)
            for _ in range(config.n_layers)
        )
        self.final_norm = nn.LayerNorm(config.d_model)

        # Output: memory address range proposal.
        self.address_head = AddressHead(config.d_model, config.address_space, config.max_read_range)

    def forward(
        self,
        past_state: torch.Tensor,
        current_state: torch.Tensor,
        characteristics: torch.Tensor,
    ) -> dict:
        """
        Map state + characteristics to a memory address range.

        Returns the address head's dict augmented with this SLM's final
        hidden vector ('hidden') and its identifier ('slm_id').
        """
        char_repr = self.char_encoder(characteristics)
        state_repr = self.state_encoder(past_state, current_state)

        # Treat each vector as a length-1 sequence so it fits the
        # attention modules, then fuse state with characteristics.
        hidden = self.cross_attn(state_repr.unsqueeze(1), char_repr.unsqueeze(1))

        for block in self.transformer_layers:
            hidden = block(hidden)

        hidden = self.final_norm(hidden).squeeze(1)

        result = self.address_head(hidden)
        result['hidden'] = hidden
        result['slm_id'] = self.slm_id
        return result
|
|
|
|
| |
| |
| |
|
|
class StraightThroughSigmoid(torch.autograd.Function):
    """
    Hard 0/1 gate with a smooth surrogate gradient.

    Forward emits (sigmoid(logits) > 0.5) as floats; backward passes the
    gradient of the *soft* sigmoid, p * (1 - p) — a straight-through
    estimator in the spirit of Gumbel-softmax / Switch-style routing
    (cf. Jang et al. 2017).
    """
    @staticmethod
    def forward(ctx, logits):
        soft = torch.sigmoid(logits)
        ctx.save_for_backward(soft)
        return torch.gt(soft, 0.5).float()

    @staticmethod
    def backward(ctx, grad_output):
        (soft,) = ctx.saved_tensors
        # d/dx sigmoid(x) = p(1 - p): gradient of the relaxed gate.
        return soft * (1 - soft) * grad_output
|
|
|
|
class BLMRouter(nn.Module):
    """
    Selects which SLMs participate at this step via a per-SLM binary gate,
    e.g. a mask like [1, 0, 1].

    A straight-through estimator makes the hard 0/1 mask differentiable, a
    fallback guarantees at least one SLM is active per sample, and two
    auxiliary losses (load balance + gate entropy) regularise routing.
    """

    def __init__(self, d_model: int, n_slms: int):
        super().__init__()
        self.n_slms = n_slms

        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, n_slms),
        )

        # Annealed temperature, kept as a buffer so it follows .to(device)
        # but is never optimised.
        self.register_buffer('temperature', torch.tensor(1.0))

    def forward(self, state_repr: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            state_repr: (batch, d_model) encoded current state

        Returns:
            binary_mask: (batch, n_slms) hard 0/1 selection
            routing_info: dict with probs, mask, auxiliary losses, logits
        """
        raw_logits = self.gate(state_repr)

        # Temperature-scaled gate probabilities (floor avoids div-by-~0).
        temp = self.temperature.clamp(min=0.1)
        probs = torch.sigmoid(raw_logits / temp)

        # Straight-through: hard values forward, sigmoid gradient backward.
        hard = (probs > 0.5).float()
        binary_mask = probs + (hard - probs).detach()

        # Guarantee at least one active SLM per sample: force-select the
        # highest-probability SLM wherever the whole row gated to zero.
        none_selected = (binary_mask.sum(dim=-1) == 0)
        if none_selected.any():
            best = probs[none_selected].argmax(dim=-1)
            fallback = torch.zeros_like(probs[none_selected])
            fallback.scatter_(1, best.unsqueeze(1), 1.0)
            binary_mask[none_selected] = fallback

        # Load balance: push each SLM's mean usage toward 1 / n_slms.
        usage = binary_mask.mean(dim=0)
        balance_loss = ((usage - 1.0 / self.n_slms) ** 2).sum()

        # Binary entropy of the gate probabilities.
        gate_entropy = -(probs * torch.log(probs + 1e-8) +
                         (1 - probs) * torch.log(1 - probs + 1e-8))
        entropy_loss = gate_entropy.mean()

        routing_info = {
            'probs': probs,
            'binary_mask': binary_mask,
            'balance_loss': balance_loss,
            'entropy_loss': entropy_loss,
            'logits': raw_logits,
        }

        return binary_mask, routing_info

    def anneal_temperature(self, step: int, anneal_rate: float = 3e-5, min_temp: float = 0.1):
        """Anneal temperature: start warm (exploratory), cool down (decisive)."""
        self.temperature.fill_(max(min_temp, math.exp(-anneal_rate * step)))
|
|
|
|
class InfoRequestHead(nn.Module):
    """
    Emits a query vector meaning "what information do I need next?".

    Rather than passively consuming whatever the SLMs return, the BLM
    produces this query so that the NEXT timestep's SLM inputs can be
    modulated toward the memory regions it expects to need.
    """

    def __init__(self, d_model: int, query_dim: int):
        super().__init__()
        self.query_generator = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, query_dim),
            nn.LayerNorm(query_dim),
        )

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Map the BLM state (batch, d_model) to a query (batch, query_dim)."""
        query = self.query_generator(hidden)
        return query
|
|
|
|
class BigLeWorldModel(nn.Module):
    """
    BLM: Big LeWorld Model (~12M params)

    Two roles:
    1. ROUTER: Select which SLMs to activate (binary mask)
    2. PREDICTOR: Given selected memory contents, predict next state

    Plus: Info-Request Head that asks "what information is needed next?"

    Architecture:
    1. Encode current state → routing decision
    2. Receive memory reads from selected SLMs
    3. Transformer processes (state + memories)
    4. Predict next state
    5. Generate info request for next timestep
    """

    def __init__(self, config: BLMConfig):
        super().__init__()
        self.config = config

        # Raw current state (state_dim) → model space (d_model).
        self.state_encoder = nn.Sequential(
            nn.Linear(config.state_dim, config.d_model),
            nn.GELU(),
            nn.LayerNorm(config.d_model)
        )

        # Encoded memory reads → model space.
        # NOTE(review): input width 128 is hard-coded to match
        # ArtificialMemory's bit_encoder output — keep the two in sync.
        self.memory_encoder = nn.Sequential(
            nn.Linear(128, config.d_model),
            nn.GELU(),
            nn.LayerNorm(config.d_model)
        )

        # SLM hidden vectors → model space.
        # NOTE(review): 128 assumes SLMConfig.d_model == 128 — confirm.
        self.slm_hidden_encoder = nn.Sequential(
            nn.Linear(128, config.d_model),
            nn.GELU(),
            nn.LayerNorm(config.d_model)
        )

        # Decides which SLMs contribute this step.
        self.router = BLMRouter(config.d_model, config.n_slms)

        # Main self-attention trunk over [CLS, state, memory tokens].
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(config.d_model, config.n_heads, config.dropout)
            for _ in range(config.n_layers)
        ])
        self.final_norm = nn.LayerNorm(config.d_model)

        # Predicts the next state from the CLS token.
        self.next_state_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Linear(config.d_model, config.state_dim)
        )

        # Emits the "what do I need next?" query from the CLS token.
        self.info_request = InfoRequestHead(config.d_model, config.info_query_dim)

        # Learned CLS token plus additive type embeddings marking which
        # positions are state vs. memory tokens.
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02)
        self.state_type_embed = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02)
        self.memory_type_embed = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02)

    def forward(
        self,
        past_state: torch.Tensor,
        current_state: torch.Tensor,
        slm_outputs: List[dict],
        memory_reads: List[torch.Tensor],
        info_query_prev: Optional[torch.Tensor] = None,
    ) -> dict:
        """
        Full BLM forward pass.

        Args:
            past_state: (batch, state_dim). NOTE(review): accepted but
                never referenced in this method — confirm intent.
            current_state: (batch, state_dim) drives routing and prediction.
            slm_outputs: one dict per SLM (needs key 'hidden').
            memory_reads: one encoded memory block per SLM,
                presumably (batch, max_read_range, 128) — TODO confirm.
            info_query_prev: NOTE(review): accepted but unused here; the
                query is consumed upstream in LeWorldSystem — confirm.

        Returns:
            dict with next_state, binary_mask, info_query, routing_info,
            cls_output.
        """
        batch_size = current_state.shape[0]

        # Encode current state into model space.
        state_enc = self.state_encoder(current_state)

        # Decide which SLMs participate (hard 0/1 with ST gradients).
        binary_mask, routing_info = self.router(state_enc)

        # Build two tokens per SLM (its hidden vector + its pooled memory
        # read), each multiplied by that SLM's gate so de-selected SLMs
        # contribute zero vectors.
        memory_tokens = []
        for i, (slm_out, mem_read) in enumerate(zip(slm_outputs, memory_reads)):
            gate = binary_mask[:, i:i+1]

            # Gate the SLM's hidden summary: (batch, d_model) * (batch, 1).
            slm_hidden = self.slm_hidden_encoder(slm_out['hidden'])
            slm_hidden = slm_hidden * gate

            # Gate the encoded memory words: gate broadcast to
            # (batch, 1, 1) over (batch, n_words, d_model).
            mem_enc = self.memory_encoder(mem_read)
            mem_enc = mem_enc * gate.unsqueeze(-1)

            # Mean-pool the read into a single token.
            # NOTE(review): mean over the full padded range dilutes short
            # reads (zero padding included) — confirm intended.
            mem_pooled = mem_enc.mean(dim=1, keepdim=True)

            memory_tokens.append(slm_hidden.unsqueeze(1))
            memory_tokens.append(mem_pooled)

        # Assemble [CLS, state, memory...] with additive type embeddings.
        cls = self.cls_token.expand(batch_size, -1, -1)
        state_token = state_enc.unsqueeze(1) + self.state_type_embed

        # 2 tokens per SLM → (batch, 2 * n_slms, d_model).
        mem_sequence = torch.cat(memory_tokens, dim=1)
        mem_sequence = mem_sequence + self.memory_type_embed

        sequence = torch.cat([cls, state_token, mem_sequence], dim=1)

        # Self-attention trunk over the assembled token sequence.
        hidden = sequence
        for layer in self.transformer_layers:
            hidden = layer(hidden)
        hidden = self.final_norm(hidden)

        # CLS token aggregates the whole sequence.
        cls_output = hidden[:, 0, :]

        # Predict next state from the aggregate.
        next_state_pred = self.next_state_head(cls_output)

        # Emit the info request consumed at the next timestep.
        info_query = self.info_request(cls_output)

        return {
            'next_state': next_state_pred,
            'binary_mask': binary_mask,
            'info_query': info_query,
            'routing_info': routing_info,
            'cls_output': cls_output,
        }
|
|
|
|
| |
| |
| |
|
|
class LeWorldSystem(nn.Module):
    """
    Complete LeWorld Memory Architecture.

    Orchestrates:
    - Artificial Memory (bit-level storage)
    - n_slms SLMs (produce memory address ranges)
    - 1 BLM (selects SLMs, reads memory, predicts next state)

    Per-step flow:
    1. BLM sees current state → routes to SLMs
    2. SLMs produce address ranges
    3. Memory is read at those ranges
    4. BLM aggregates memory + state → predicts next state
    5. BLM generates info-request for next step

    Losses:
    - next_state_loss: MSE between predicted and actual next state
    - balance_loss: encourage balanced SLM usage
    - diversity_loss: encourage SLMs to read different memory regions
    """

    def __init__(
        self,
        mem_config: Optional["MemoryConfig"] = None,
        slm_config: Optional["SLMConfig"] = None,
        blm_config: Optional["BLMConfig"] = None,
    ):
        """
        Args:
            mem_config / slm_config / blm_config: component configurations;
                None builds a fresh default config. (Previously the defaults
                were dataclass *instances* evaluated once at definition time
                and shared by every LeWorldSystem — the classic
                mutable-default-argument pitfall.)
        """
        super().__init__()

        # Build per-instance config objects rather than sharing defaults.
        mem_config = MemoryConfig() if mem_config is None else mem_config
        slm_config = SLMConfig() if slm_config is None else slm_config
        blm_config = BLMConfig() if blm_config is None else blm_config

        # Shared bit-level memory.
        self.memory = ArtificialMemory(mem_config)

        # One SLM per routing slot; each proposes a memory address range.
        self.slms = nn.ModuleList([
            SmallLeWorldModel(slm_config, slm_id=i)
            for i in range(blm_config.n_slms)
        ])

        # Routes over the SLMs, aggregates reads, predicts the next state.
        self.blm = BigLeWorldModel(blm_config)

        # Maps the BLM's info query back into SLM state space so it can
        # modulate the next step's SLM inputs.
        self.info_to_slm = nn.Linear(blm_config.info_query_dim, slm_config.state_dim)

        self.config = {
            'mem': mem_config,
            'slm': slm_config,
            'blm': blm_config,
        }

    def forward(
        self,
        past_state: torch.Tensor,
        current_state: torch.Tensor,
        characteristics: torch.Tensor,
        next_state_target: Optional[torch.Tensor] = None,
        info_query_prev: Optional[torch.Tensor] = None,
    ) -> dict:
        """
        Full system forward pass for one timestep.

        Args:
            past_state: (batch, state_dim)
            current_state: (batch, state_dim)
            characteristics: (batch, char_dim) static context
            next_state_target: optional (batch, state_dim); when provided,
                training losses are computed.
            info_query_prev: optional (batch, info_query_dim) query emitted
                by the BLM at the previous timestep.

        Returns:
            dict with next_state, binary_mask, info_query, slm_outputs,
            memory_reads, losses (empty dict when no target), routing_info.
        """
        # The previous info request softly nudges the state the SLMs see.
        if info_query_prev is not None:
            info_modulation = self.info_to_slm(info_query_prev)
            modulated_state = current_state + 0.1 * info_modulation
        else:
            modulated_state = current_state

        # Every SLM proposes an address range (gating happens in the BLM).
        slm_outputs = [
            slm(past_state, modulated_state, characteristics)
            for slm in self.slms
        ]

        # Read memory at each proposed range (encoded view only).
        memory_reads = []
        for slm_out in slm_outputs:
            _, encoded, _ = self.memory.read(
                slm_out['start_addr'],
                slm_out['end_addr']
            )
            memory_reads.append(encoded)

        blm_output = self.blm(
            past_state, current_state,
            slm_outputs, memory_reads,
            info_query_prev
        )

        losses = {}
        if next_state_target is not None:
            losses['next_state_loss'] = F.mse_loss(
                blm_output['next_state'], next_state_target
            )

            losses['balance_loss'] = blm_output['routing_info']['balance_loss']

            # Encourage the SLMs to point at different memory regions.
            # NOTE(review): start_addr comes from an argmax, so this term
            # carries no gradient into the SLMs — confirm this is intended.
            addresses = torch.stack([
                slm_out['start_addr'].float() for slm_out in slm_outputs
            ], dim=1)
            addr_diff = torch.cdist(addresses.unsqueeze(-1), addresses.unsqueeze(-1))
            losses['diversity_loss'] = -addr_diff.mean()

            losses['total_loss'] = (
                losses['next_state_loss']
                + 0.01 * losses['balance_loss']
                + 0.001 * losses['diversity_loss']
            )

        return {
            'next_state': blm_output['next_state'],
            'binary_mask': blm_output['binary_mask'],
            'info_query': blm_output['info_query'],
            'slm_outputs': slm_outputs,
            'memory_reads': memory_reads,
            'losses': losses,
            'routing_info': blm_output['routing_info'],
        }

    def multi_step_forward(
        self,
        states: torch.Tensor,
        characteristics: torch.Tensor,
        n_steps: Optional[int] = None,
    ) -> dict:
        """
        Run the system over multiple timesteps with teacher forcing
        (ground-truth states feed each step; only the info query chains).

        Args:
            states: (batch, T, state_dim) ground-truth trajectory; T >= 2.
            characteristics: (batch, char_dim) static context
            n_steps: number of prediction steps; defaults to T - 1.

        Returns:
            dict with predictions (batch, steps, state_dim), masks,
            the mean total_loss over steps, and the final info query.
        """
        batch_size, T, state_dim = states.shape
        if n_steps is None:
            n_steps = T - 1
        steps = min(n_steps, T - 1)

        all_predictions = []
        all_masks = []
        total_loss = None
        info_query = None

        for t in range(steps):
            # At t == 0 there is no history, so past == current.
            past_state = states[:, max(0, t - 1), :]
            current_state = states[:, t, :]
            next_state_target = states[:, t + 1, :]

            output = self.forward(
                past_state, current_state, characteristics,
                next_state_target, info_query
            )

            all_predictions.append(output['next_state'])
            all_masks.append(output['binary_mask'])
            # The info query feeds forward into the next step's SLMs.
            info_query = output['info_query']

            if output['losses']:
                step_loss = output['losses']['total_loss']
                total_loss = step_loss if total_loss is None else total_loss + step_loss

        if total_loss is None:
            total_loss = torch.tensor(0.0, device=states.device)
        return {
            'predictions': torch.stack(all_predictions, dim=1),
            'masks': torch.stack(all_masks, dim=1),
            'total_loss': total_loss / max(1, steps),
            'final_info_query': info_query,
        }
|
|
|
|
| |
| |
| |
|
|
def count_params(model, name="Model"):
    """Print a parameter-count summary for *model* and return the total."""
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(n for n, _ in sizes)
    trainable = sum(n for n, grad in sizes if grad)
    print(f"\n{'='*60}")
    print(f"{name}: {total:,} total params ({trainable:,} trainable)")
    print(f"{'='*60}")

    # Per-child breakdown, skipping parameter-free children.
    for child_name, child in model.named_children():
        child_params = sum(p.numel() for p in child.parameters())
        if child_params > 0:
            print(f"  {child_name}: {child_params:,}")

    return total
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Smoke-test script: builds the full system, prints parameter counts,
    # then runs a single-step and a multi-step forward pass on random data.
    print("LeWorld Memory Architecture — Component Verification")
    print("=" * 60)

    # Default configurations for all three components.
    mem_config = MemoryConfig()
    slm_config = SLMConfig()
    blm_config = BLMConfig()

    # Assemble memory + SLMs + BLM.
    system = LeWorldSystem(mem_config, slm_config, blm_config)

    # Parameter budget per component.
    print("\n--- Parameter Counts ---")
    count_params(system.memory, "Artificial Memory")
    for i, slm in enumerate(system.slms):
        count_params(slm, f"SLM-{i}")
    count_params(system.blm, "BLM")
    count_params(system, "Full System")

    # Single-step forward pass on random tensors.
    print("\n--- Forward Pass Test ---")
    batch_size = 4
    state_dim = slm_config.state_dim
    char_dim = slm_config.char_dim

    past_state = torch.randn(batch_size, state_dim)
    current_state = torch.randn(batch_size, state_dim)
    characteristics = torch.randn(batch_size, char_dim)
    next_state = torch.randn(batch_size, state_dim)

    # Passing a target makes the system compute training losses too.
    output = system(past_state, current_state, characteristics, next_state)

    print(f"Next state prediction shape: {output['next_state'].shape}")
    print(f"Binary mask (SLM selection): {output['binary_mask']}")
    print(f"Info query shape: {output['info_query'].shape}")
    print(f"Losses: {output['losses']}")

    # Teacher-forced rollout over a random trajectory of T states.
    print("\n--- Multi-Step Test ---")
    T = 10
    states = torch.randn(batch_size, T, state_dim)

    ms_output = system.multi_step_forward(states, characteristics)
    print(f"Predictions shape: {ms_output['predictions'].shape}")
    print(f"Masks shape: {ms_output['masks'].shape}")
    print(f"Average loss: {ms_output['total_loss'].item():.4f}")

    # Show which SLMs the router picked at each timestep (first sample).
    print("\n--- Routing Patterns Over Time ---")
    masks = ms_output['masks'][0].detach()
    for t in range(masks.shape[0]):
        mask = masks[t].int().tolist()
        print(f"  Step {t}: SLMs selected = {mask}")

    print("\n✅ All components verified successfully!")