"""
LeWorld Memory Architecture — Complete Implementation
=====================================================
Component 1: Artificial Memory (CPU-style bit storage)
Component 2: SLMs (Small LeWorld Models, ~1.5M params each)
Component 3: BLM (Big LeWorld Model, ~12M params)
Component 4: Full System with training loop
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Tuple, List, Optional
# =============================================================================
# Configuration
# =============================================================================
@dataclass
class MemoryConfig:
"""CPU-style artificial memory configuration."""
num_words: int = 65536 # 64K addressable words (like 64K RAM)
word_size: int = 32 # 32 bits per word
address_bits: int = 16 # 2^16 = 65536 addresses
max_read_range: int = 256 # max words per single read operation
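# A quick size check at these defaults (simple arithmetic, not from the
# source): 65,536 words * 32 bits = 2,097,152 bits = 256 KiB of addressable
# storage, held as a (65536, 32) float tensor of 0s and 1s.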
@dataclass
class SLMConfig:
"""Small LeWorld Model configuration (~1.5M params)."""
d_model: int = 128 # internal dimension
n_heads: int = 4 # attention heads
n_layers: int = 2 # transformer layers
state_dim: int = 64 # state vector dimension
char_dim: int = 32 # characteristics vector dimension
address_space: int = 65536 # must match MemoryConfig.num_words
max_read_range: int = 256 # must match MemoryConfig.max_read_range
dropout: float = 0.1
@dataclass
class BLMConfig:
"""Big LeWorld Model configuration (~12M params)."""
d_model: int = 384 # internal dimension
n_heads: int = 6 # attention heads
n_layers: int = 6 # transformer layers
state_dim: int = 64 # state vector dimension
n_slms: int = 3 # number of SLMs to route over
memory_read_dim: int = 256 # dimension of encoded memory reads
info_query_dim: int = 128 # dimension of "what info do I need" query
dropout: float = 0.1
# =============================================================================
# Component 1: Artificial Memory
# =============================================================================
class ArtificialMemory(nn.Module):
"""
CPU-style bit-level memory with address-range access.
Stores data as actual bits (0/1 tensors), organized into addressable words.
Supports:
- READ(start_addr, end_addr) → returns bit block
- WRITE(start_addr, data) → writes bits to memory
- Bit-to-embedding projection (for neural network consumption)
This mimics how a CPU accesses RAM:
- Each address points to a word (32 bits)
- Contiguous reads fetch a range of words
- No inherent "meaning" — bits are just bits until interpreted
"""
def __init__(self, config: MemoryConfig):
super().__init__()
self.config = config
# The actual memory: (num_words, word_size) binary tensor
# Initialized randomly — represents "existing knowledge base"
self.register_buffer(
'memory',
torch.randint(0, 2, (config.num_words, config.word_size)).float()
)
# Bit-to-embedding projection: converts raw bits into dense vectors
# This is learnable — the system learns what bit patterns mean
self.bit_encoder = nn.Sequential(
nn.Linear(config.word_size, 64),
nn.GELU(),
nn.Linear(64, 128),
nn.LayerNorm(128)
)
# Write projection: converts dense vectors back to bit probabilities
self.bit_decoder = nn.Sequential(
nn.Linear(128, 64),
nn.GELU(),
nn.Linear(64, config.word_size),
nn.Sigmoid() # output probabilities for each bit
)
def read(self, start_addr: torch.Tensor, end_addr: torch.Tensor) -> torch.Tensor:
"""
Read a contiguous range of words from memory.
Args:
start_addr: (batch,) integer tensor of start addresses
end_addr: (batch,) integer tensor of end addresses
        Returns:
            bit_block: (batch, max_range, word_size) raw bits
            encoded: (batch, max_range, 128) encoded memory content
            valid_mask: (batch, max_range) bool, True inside [start, end)
"""
batch_size = start_addr.shape[0]
max_range = self.config.max_read_range
        # Clamp addresses to valid ranges; derive the upper bound from
        # start_addr directly so every tensor stays on the input device
        start_addr = start_addr.clamp(0, self.config.num_words - 1)
        max_end = (start_addr + max_range).clamp(max=self.config.num_words)
        end_addr = torch.minimum(torch.maximum(end_addr, start_addr), max_end)
# Gather memory contents for each batch element
# Create index tensor for the address ranges
offsets = torch.arange(max_range, device=start_addr.device).unsqueeze(0) # (1, max_range)
addresses = start_addr.unsqueeze(1) + offsets # (batch, max_range)
addresses = addresses.clamp(0, self.config.num_words - 1)
# Create validity mask (addresses within [start, end) are valid)
range_lengths = (end_addr - start_addr).unsqueeze(1) # (batch, 1)
valid_mask = offsets < range_lengths # (batch, max_range)
# Gather bits
bit_block = self.memory[addresses] # (batch, max_range, word_size)
bit_block = bit_block * valid_mask.unsqueeze(-1).float() # zero out invalid
# Encode bits to dense vectors
encoded = self.bit_encoder(bit_block) # (batch, max_range, 128)
encoded = encoded * valid_mask.unsqueeze(-1).float()
return bit_block, encoded, valid_mask
    def write(self, start_addr: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
        """
        Write data to memory. The stored bits are detached, so the buffer
        update itself is a hard (non-differentiable) write; the returned
        straight-through bits keep a gradient path into bit_decoder for any
        caller that wants to supervise the written content.
        Args:
            start_addr: (batch,) start addresses
            data: (batch, n_words, 128) encoded data to write
        Returns:
            bits_to_write: (batch, n_words, word_size) ST-estimated bits
        """
        n_words = data.shape[1]
        # Decode to bit probabilities
        bit_probs = self.bit_decoder(data)  # (batch, n_words, word_size)
        # Straight-through bits: hard 0/1 forward, sigmoid gradient backward
        hard_bits = (bit_probs > 0.5).float()
        bits_to_write = hard_bits - bit_probs.detach() + bit_probs
        # Write to memory (on overlapping ranges, later batch elements win)
        for b in range(start_addr.shape[0]):
            addr = int(start_addr[b])
            end = min(addr + n_words, self.config.num_words)
            self.memory[addr:end] = bits_to_write[b, :end - addr].detach()
        return bits_to_write
def soft_read(self, attention_weights: torch.Tensor) -> torch.Tensor:
"""
Content-based soft read using attention weights over entire memory.
Used for differentiable end-to-end training.
Args:
attention_weights: (batch, num_words) soft address distribution
Returns:
encoded: (batch, 128) weighted memory content
"""
# Encode all memory (expensive but differentiable)
all_encoded = self.bit_encoder(self.memory) # (num_words, 128)
# Weighted sum
encoded = torch.matmul(attention_weights, all_encoded) # (batch, 128)
return encoded
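# A minimal, hedged usage sketch for ArtificialMemory on a tiny config. It
# only illustrates the read/write contract above (shapes, masking, hard
# writes); the config values and addresses are arbitrary, not the defaults.
def _sketch_memory_roundtrip():
    mem = ArtificialMemory(MemoryConfig(num_words=1024, word_size=32,
                                        address_bits=10, max_read_range=16))
    start = torch.tensor([0, 100])
    end = torch.tensor([4, 108])  # ranges of 4 and 8 words
    bits, encoded, mask = mem.read(start, end)
    # bits:    (2, 16, 32) raw 0/1 words, zero-padded past each range
    # encoded: (2, 16, 128) learned dense view of the same words
    # mask:    (2, 16) True inside [start, end)
    mem.write(start, encoded[:, :4, :])  # decode back to bits and store
    return bits.shape, encoded.shape, mask.shape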
# =============================================================================
# Component 2: Small LeWorld Model (SLM)
# =============================================================================
class StateEncoder(nn.Module):
"""Encodes past_state and current_state into a joint representation."""
def __init__(self, state_dim: int, d_model: int):
super().__init__()
self.past_proj = nn.Linear(state_dim, d_model)
self.curr_proj = nn.Linear(state_dim, d_model)
self.combiner = nn.Sequential(
nn.Linear(d_model * 2, d_model),
nn.GELU(),
nn.LayerNorm(d_model)
)
def forward(self, past_state: torch.Tensor, current_state: torch.Tensor) -> torch.Tensor:
"""
Args:
past_state: (batch, state_dim)
current_state: (batch, state_dim)
Returns:
combined: (batch, d_model)
"""
past_enc = F.gelu(self.past_proj(past_state))
curr_enc = F.gelu(self.curr_proj(current_state))
combined = self.combiner(torch.cat([past_enc, curr_enc], dim=-1))
return combined
class CharacteristicsEncoder(nn.Module):
"""Encodes static characteristics/context."""
def __init__(self, char_dim: int, d_model: int):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(char_dim, d_model),
nn.GELU(),
nn.LayerNorm(d_model)
)
def forward(self, characteristics: torch.Tensor) -> torch.Tensor:
return self.encoder(characteristics)
class TransformerBlock(nn.Module):
"""Standard transformer block with pre-norm."""
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.GELU(),
nn.Linear(d_model * 4, d_model),
nn.Dropout(dropout)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Self-attention with pre-norm
normed = self.norm1(x)
attn_out, _ = self.attn(normed, normed, normed)
x = x + attn_out
# FFN with pre-norm
x = x + self.ffn(self.norm2(x))
return x
class CrossAttentionBlock(nn.Module):
"""Cross-attention: state attends to characteristics."""
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
super().__init__()
self.norm_q = nn.LayerNorm(d_model)
self.norm_kv = nn.LayerNorm(d_model)
self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.norm_ff = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.GELU(),
nn.Linear(d_model * 4, d_model),
nn.Dropout(dropout)
)
def forward(self, query: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
normed_q = self.norm_q(query)
normed_kv = self.norm_kv(context)
attn_out, _ = self.cross_attn(normed_q, normed_kv, normed_kv)
x = query + attn_out
x = x + self.ffn(self.norm_ff(x))
return x
class AddressHead(nn.Module):
"""
Produces memory address range (start_addr, end_addr) from hidden state.
Uses two approaches:
1. HARD mode: argmax over address space (for inference)
2. SOFT mode: attention weights over memory (for differentiable training)
"""
def __init__(self, d_model: int, address_space: int, max_range: int):
super().__init__()
self.address_space = address_space
self.max_range = max_range
        # Produce start-address logits. A single linear layer over all 65,536
        # addresses would cost too many parameters, so the address is predicted
        # as a composition of sub-addresses (product-key style).
self.addr_bits = int(math.log2(address_space)) # 16 for 65536
assert 2 ** self.addr_bits == address_space, "address_space must be power of 2"
# Split address into high byte and low byte (8+8 = 16 bits)
self.half_bits = self.addr_bits // 2 # 8
self.half_space = 2 ** self.half_bits # 256
# Predict high and low parts separately (product key approach)
self.start_high = nn.Linear(d_model, self.half_space) # 256 outputs
self.start_low = nn.Linear(d_model, self.half_space) # 256 outputs
# Predict range length (how many words to read)
self.range_head = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, max_range)
)
# Confidence head
self.confidence_head = nn.Sequential(
nn.Linear(d_model, d_model // 4),
nn.GELU(),
nn.Linear(d_model // 4, 1),
nn.Sigmoid()
)
def forward(self, hidden: torch.Tensor) -> dict:
"""
Args:
hidden: (batch, d_model)
Returns:
dict with:
start_addr: (batch,) integer addresses
end_addr: (batch,) integer addresses
range_length: (batch,) how many words to read
confidence: (batch,) read confidence score
start_logits_high: (batch, 256) for soft addressing
start_logits_low: (batch, 256) for soft addressing
range_logits: (batch, max_range) for soft range selection
"""
batch_size = hidden.shape[0]
# Product-key address generation
high_logits = self.start_high(hidden) # (batch, 256)
low_logits = self.start_low(hidden) # (batch, 256)
# Hard address via argmax
high_idx = high_logits.argmax(dim=-1) # (batch,)
low_idx = low_logits.argmax(dim=-1) # (batch,)
start_addr = high_idx * self.half_space + low_idx # (batch,) 0..65535
# Range length
range_logits = self.range_head(hidden) # (batch, max_range)
range_length = range_logits.argmax(dim=-1) + 1 # (batch,) 1..max_range
end_addr = (start_addr + range_length).clamp(max=self.address_space - 1)
# Confidence
confidence = self.confidence_head(hidden).squeeze(-1) # (batch,)
return {
'start_addr': start_addr,
'end_addr': end_addr,
'range_length': range_length,
'confidence': confidence,
'start_logits_high': high_logits,
'start_logits_low': low_logits,
'range_logits': range_logits,
}
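# Worked example of the product-key address composition above (plain
# arithmetic, nothing new): with half_space = 256, high_idx = 3 and
# low_idx = 17 compose to start_addr = 3 * 256 + 17 = 785; conversely
# 785 // 256 = 3 and 785 % 256 = 17. Two 256-way heads therefore cover all
# 65,536 addresses with ~2 * 256 * d_model output weights instead of
# 65,536 * d_model for a single flat head.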
class SmallLeWorldModel(nn.Module):
"""
SLM: Small LeWorld Model (~1.5M params)
Takes (past_state, current_state, characteristics) and produces
a memory address range pointing to the most useful memory for
next-state prediction.
Architecture:
1. Encode past + current state → state representation
2. Encode characteristics
3. Cross-attend: state attends to characteristics
4. Self-attention transformer layers
5. Address head: output (start_addr, end_addr, confidence)
"""
def __init__(self, config: SLMConfig, slm_id: int = 0):
super().__init__()
self.config = config
self.slm_id = slm_id
# Encoders
self.state_encoder = StateEncoder(config.state_dim, config.d_model)
self.char_encoder = CharacteristicsEncoder(config.char_dim, config.d_model)
# Cross-attention: state ← characteristics
self.cross_attn = CrossAttentionBlock(config.d_model, config.n_heads, config.dropout)
# Self-attention transformer
self.transformer_layers = nn.ModuleList([
TransformerBlock(config.d_model, config.n_heads, config.dropout)
for _ in range(config.n_layers)
])
self.final_norm = nn.LayerNorm(config.d_model)
# Address output head
self.address_head = AddressHead(config.d_model, config.address_space, config.max_read_range)
def forward(
self,
past_state: torch.Tensor, # (batch, state_dim)
current_state: torch.Tensor, # (batch, state_dim)
characteristics: torch.Tensor, # (batch, char_dim)
) -> dict:
"""
Forward pass: state + characteristics → memory address range.
Returns dict with address info + internal hidden state.
"""
# Encode states
state_repr = self.state_encoder(past_state, current_state) # (batch, d_model)
# Encode characteristics
char_repr = self.char_encoder(characteristics) # (batch, d_model)
# Cross-attention: state queries characteristics
# Unsqueeze to sequence dim for attention
state_seq = state_repr.unsqueeze(1) # (batch, 1, d_model)
char_seq = char_repr.unsqueeze(1) # (batch, 1, d_model)
enriched = self.cross_attn(state_seq, char_seq) # (batch, 1, d_model)
# Self-attention layers
hidden = enriched
for layer in self.transformer_layers:
hidden = layer(hidden)
hidden = self.final_norm(hidden)
hidden = hidden.squeeze(1) # (batch, d_model)
# Produce address range
addr_output = self.address_head(hidden)
addr_output['hidden'] = hidden # keep for BLM to use
addr_output['slm_id'] = self.slm_id
return addr_output
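# Standalone usage sketch for a single SLM (hedged: random inputs at the
# default config, shown only to illustrate the output contract):
#
#   slm = SmallLeWorldModel(SLMConfig())
#   out = slm(torch.randn(2, 64), torch.randn(2, 64), torch.randn(2, 32))
#   out['start_addr'], out['end_addr']  # (2,) integer tensors in [0, 65535]
#   out['range_length']                 # (2,) in [1, 256]
#   out['confidence']                   # (2,) in (0, 1)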
# =============================================================================
# Component 3: Big LeWorld Model (BLM)
# =============================================================================
class StraightThroughSigmoid(torch.autograd.Function):
"""
Binary gate: hard 0/1 in forward, sigmoid gradient in backward.
From literature: ST-GS (Jang et al. 2017) + Switch Transformer routing.
"""
@staticmethod
def forward(ctx, logits):
probs = torch.sigmoid(logits)
ctx.save_for_backward(probs)
return (probs > 0.5).float()
@staticmethod
def backward(ctx, grad_output):
probs, = ctx.saved_tensors
# Sigmoid derivative: p * (1-p)
return grad_output * probs * (1 - probs)
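# A hedged sanity sketch of the straight-through behaviour (not part of the
# original system; the inputs are arbitrary): the forward pass is hard 0/1,
# yet every logit still receives a sigmoid-shaped gradient.
def _sketch_st_sigmoid():
    logits = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
    gated = StraightThroughSigmoid.apply(logits)  # tensor([0., 1., 1.])
    gated.sum().backward()
    # logits.grad == sigmoid(logits) * (1 - sigmoid(logits)) > 0 everywhere,
    # so even switched-off gates keep receiving a learning signal.
    return gated, logits.grad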
class BLMRouter(nn.Module):
"""
Routes/selects which SLMs to activate.
Produces binary mask like [1, 0, 1].
Uses Straight-Through Sigmoid for differentiable binary selection.
Includes load-balancing loss to prevent degenerate routing.
"""
def __init__(self, d_model: int, n_slms: int):
super().__init__()
self.n_slms = n_slms
self.gate = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, n_slms)
)
# Temperature for annealing (start warm, cool down)
self.register_buffer('temperature', torch.tensor(1.0))
def forward(self, state_repr: torch.Tensor) -> Tuple[torch.Tensor, dict]:
"""
Args:
state_repr: (batch, d_model) encoded current state
Returns:
binary_mask: (batch, n_slms) hard 0/1 selection
routing_info: dict with probs, losses, etc.
"""
logits = self.gate(state_repr) # (batch, n_slms)
# Scale by temperature
scaled_logits = logits / self.temperature.clamp(min=0.1)
probs = torch.sigmoid(scaled_logits) # (batch, n_slms)
# Straight-through binary: hard in forward, soft in backward
hard_mask = (probs > 0.5).float()
binary_mask = hard_mask - probs.detach() + probs # THE ST TRICK
        # Ensure at least one SLM is selected (avoid an all-zero mask):
        # force-select the highest-probability SLM for all-zero rows.
        # torch.where avoids mutating an autograd intermediate in place.
        all_zero = binary_mask.sum(dim=-1, keepdim=True) == 0  # (batch, 1)
        top1 = F.one_hot(probs.argmax(dim=-1), self.n_slms).float()
        binary_mask = torch.where(all_zero, top1, binary_mask)
# Load balance loss: encourage roughly equal usage of SLMs
usage_per_slm = binary_mask.mean(dim=0) # (n_slms,)
target_usage = 1.0 / self.n_slms
balance_loss = ((usage_per_slm - target_usage) ** 2).sum()
# Entropy loss: encourage decisive routing (not all ~0.5)
entropy = -(probs * torch.log(probs + 1e-8) +
(1 - probs) * torch.log(1 - probs + 1e-8))
entropy_loss = entropy.mean()
routing_info = {
'probs': probs,
'binary_mask': binary_mask,
'balance_loss': balance_loss,
'entropy_loss': entropy_loss,
'logits': logits,
}
return binary_mask, routing_info
def anneal_temperature(self, step: int, anneal_rate: float = 3e-5, min_temp: float = 0.1):
"""Anneal temperature: start warm (exploratory), cool down (decisive)."""
new_temp = max(min_temp, math.exp(-anneal_rate * step))
self.temperature.fill_(new_temp)
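# Worked values of the schedule above at the defaults (anneal_rate=3e-5,
# min_temp=0.1), i.e. max(min_temp, exp(-anneal_rate * step)):
#   step       0 -> 1.000 (soft, exploratory routing)
#   step  10,000 -> exp(-0.30) ~ 0.741
#   step  50,000 -> exp(-1.50) ~ 0.223
#   step 100,000 -> exp(-3.00) ~ 0.050, clamped to min_temp = 0.1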
class InfoRequestHead(nn.Module):
"""
Produces a query vector representing "what information do I need next?"
This is the key innovation: instead of passively receiving all SLM outputs,
the BLM actively requests specific information. This query modulates which
memory regions the SLMs should focus on in the NEXT timestep.
"""
def __init__(self, d_model: int, query_dim: int):
super().__init__()
self.query_generator = nn.Sequential(
nn.Linear(d_model, d_model),
nn.GELU(),
nn.Linear(d_model, query_dim),
nn.LayerNorm(query_dim)
)
def forward(self, hidden: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden: (batch, d_model) BLM's internal state
Returns:
info_query: (batch, query_dim) "what do I need next?"
"""
return self.query_generator(hidden)
class BigLeWorldModel(nn.Module):
"""
BLM: Big LeWorld Model (~12M params)
Two roles:
1. ROUTER: Select which SLMs to activate (binary mask)
2. PREDICTOR: Given selected memory contents, predict next state
Plus: Info-Request Head that asks "what information is needed next?"
Architecture:
1. Encode current state → routing decision
2. Receive memory reads from selected SLMs
3. Transformer processes (state + memories)
4. Predict next state
5. Generate info request for next timestep
"""
def __init__(self, config: BLMConfig):
super().__init__()
self.config = config
# State encoder (maps state_dim → d_model)
self.state_encoder = nn.Sequential(
nn.Linear(config.state_dim, config.d_model),
nn.GELU(),
nn.LayerNorm(config.d_model)
)
# Memory read encoder (maps encoded memory → d_model)
self.memory_encoder = nn.Sequential(
nn.Linear(128, config.d_model), # 128 from ArtificialMemory bit_encoder
nn.GELU(),
nn.LayerNorm(config.d_model)
)
# SLM hidden state encoder (maps SLM hidden → d_model)
self.slm_hidden_encoder = nn.Sequential(
            nn.Linear(128, config.d_model),  # must match SLMConfig.d_model (128)
nn.GELU(),
nn.LayerNorm(config.d_model)
)
# Router: selects which SLMs to use
self.router = BLMRouter(config.d_model, config.n_slms)
# Transformer backbone
self.transformer_layers = nn.ModuleList([
TransformerBlock(config.d_model, config.n_heads, config.dropout)
for _ in range(config.n_layers)
])
self.final_norm = nn.LayerNorm(config.d_model)
# Prediction heads
self.next_state_head = nn.Sequential(
nn.Linear(config.d_model, config.d_model),
nn.GELU(),
nn.Linear(config.d_model, config.state_dim)
)
# Info request head: "what do I need next?"
self.info_request = InfoRequestHead(config.d_model, config.info_query_dim)
# Learnable tokens
self.cls_token = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02)
self.state_type_embed = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02)
self.memory_type_embed = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02)
def forward(
self,
        past_state: torch.Tensor,  # (batch, state_dim); currently unused, kept for interface symmetry
current_state: torch.Tensor, # (batch, state_dim)
slm_outputs: List[dict], # list of SLM output dicts
memory_reads: List[torch.Tensor], # list of (batch, range, 128) encoded memory
        info_query_prev: Optional[torch.Tensor] = None,  # (batch, query_dim); consumed upstream by LeWorldSystem
) -> dict:
"""
Full BLM forward pass.
Returns:
dict with next_state, binary_mask, info_query, losses, etc.
"""
batch_size = current_state.shape[0]
# 1. Encode current state for routing decision
state_enc = self.state_encoder(current_state) # (batch, d_model)
# 2. Route: select which SLMs to use
binary_mask, routing_info = self.router(state_enc) # (batch, n_slms)
# 3. Aggregate selected memory reads
# For each SLM, apply its binary gate and encode its memory read
memory_tokens = []
for i, (slm_out, mem_read) in enumerate(zip(slm_outputs, memory_reads)):
gate = binary_mask[:, i:i+1] # (batch, 1)
# Gate the SLM's hidden representation
slm_hidden = self.slm_hidden_encoder(slm_out['hidden']) # (batch, d_model)
slm_hidden = slm_hidden * gate # zero if SLM not selected
# Gate and encode the memory read
# mem_read: (batch, range_len, 128)
mem_enc = self.memory_encoder(mem_read) # (batch, range_len, d_model)
mem_enc = mem_enc * gate.unsqueeze(-1) # zero if SLM not selected
# Pool memory read to single token (mean pool over range)
mem_pooled = mem_enc.mean(dim=1, keepdim=True) # (batch, 1, d_model)
memory_tokens.append(slm_hidden.unsqueeze(1)) # SLM hidden as token
memory_tokens.append(mem_pooled) # memory content as token
# 4. Build input sequence for transformer
# [CLS] + [state] + [slm_0_hidden, slm_0_mem, slm_1_hidden, slm_1_mem, ...]
cls = self.cls_token.expand(batch_size, -1, -1)
state_token = state_enc.unsqueeze(1) + self.state_type_embed # (batch, 1, d_model)
# Add memory type embedding to memory tokens
mem_sequence = torch.cat(memory_tokens, dim=1) # (batch, 2*n_slms, d_model)
mem_sequence = mem_sequence + self.memory_type_embed
sequence = torch.cat([cls, state_token, mem_sequence], dim=1)
# Shape: (batch, 1 + 1 + 2*n_slms, d_model)
# 5. Transformer processing
hidden = sequence
for layer in self.transformer_layers:
hidden = layer(hidden)
hidden = self.final_norm(hidden)
# 6. Extract predictions from CLS token
cls_output = hidden[:, 0, :] # (batch, d_model)
# 7. Predict next state
next_state_pred = self.next_state_head(cls_output) # (batch, state_dim)
# 8. Generate info request for next timestep
info_query = self.info_request(cls_output) # (batch, query_dim)
return {
'next_state': next_state_pred,
'binary_mask': binary_mask,
'info_query': info_query,
'routing_info': routing_info,
'cls_output': cls_output,
}
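# Token-budget check for the BLM input sequence (simple arithmetic on the
# shapes in forward() above): with the default n_slms = 3, the transformer
# sees 1 [CLS] + 1 state token + 2 * 3 SLM-hidden/memory tokens = 8 tokens
# of width d_model = 384 per batch element.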
# =============================================================================
# Component 4: Full LeWorld System
# =============================================================================
class LeWorldSystem(nn.Module):
"""
Complete LeWorld Memory Architecture.
Orchestrates:
- Artificial Memory (bit-level storage)
- 3 SLMs (produce memory address ranges)
- 1 BLM (selects SLMs, reads memory, predicts next state)
Training loop:
1. BLM sees current state → routes to SLMs
2. Selected SLMs produce address ranges
3. Memory is read at those ranges
4. BLM aggregates memory + state → predicts next state
5. BLM generates info-request for next step
Losses:
- next_state_loss: MSE between predicted and actual next state
- routing_balance_loss: encourage balanced SLM usage
- address_diversity_loss: encourage SLMs to read different memory regions
- info_utility_loss: did the info request lead to useful retrievals?
"""
def __init__(
self,
mem_config: MemoryConfig = MemoryConfig(),
slm_config: SLMConfig = SLMConfig(),
blm_config: BLMConfig = BLMConfig(),
):
super().__init__()
# Artificial Memory
self.memory = ArtificialMemory(mem_config)
# 3 SLMs
self.slms = nn.ModuleList([
SmallLeWorldModel(slm_config, slm_id=i)
for i in range(blm_config.n_slms)
])
# BLM
self.blm = BigLeWorldModel(blm_config)
# Info-query → SLM modulation: the BLM's info request
# influences what SLMs look for in the next timestep
self.info_to_slm = nn.Linear(blm_config.info_query_dim, slm_config.state_dim)
self.config = {
'mem': mem_config,
'slm': slm_config,
'blm': blm_config,
}
def forward(
self,
past_state: torch.Tensor, # (batch, state_dim)
current_state: torch.Tensor, # (batch, state_dim)
characteristics: torch.Tensor, # (batch, char_dim)
next_state_target: Optional[torch.Tensor] = None, # (batch, state_dim) for training
info_query_prev: Optional[torch.Tensor] = None, # from previous timestep
) -> dict:
"""
Full system forward pass.
"""
batch_size = current_state.shape[0]
# If we have a previous info query, modulate the current state
# This is how the BLM's "what do I need?" influences retrieval
if info_query_prev is not None:
info_modulation = self.info_to_slm(info_query_prev) # (batch, state_dim)
modulated_state = current_state + 0.1 * info_modulation # gentle modulation
else:
modulated_state = current_state
# 1. Run all 3 SLMs to get address ranges
slm_outputs = []
for slm in self.slms:
out = slm(past_state, modulated_state, characteristics)
slm_outputs.append(out)
# 2. Read memory at each SLM's address range
memory_reads = []
for slm_out in slm_outputs:
_, encoded, valid_mask = self.memory.read(
slm_out['start_addr'],
slm_out['end_addr']
)
memory_reads.append(encoded)
# 3. BLM processes everything
blm_output = self.blm(
past_state, current_state,
slm_outputs, memory_reads,
info_query_prev
)
# 4. Compute losses if training
losses = {}
if next_state_target is not None:
# Primary loss: next state prediction
losses['next_state_loss'] = F.mse_loss(
blm_output['next_state'], next_state_target
)
# Routing balance loss
losses['balance_loss'] = blm_output['routing_info']['balance_loss']
            # Address diversity loss: encourage SLMs to read different regions.
            # Hard argmax addresses carry no gradient, so use each SLM's soft
            # expected start address computed from its product-key logits.
            half_space = self.slms[0].address_head.half_space
            arange = torch.arange(half_space, device=current_state.device,
                                  dtype=torch.float32)
            soft_addrs = torch.stack([
                F.softmax(out['start_logits_high'], dim=-1) @ arange * half_space
                + F.softmax(out['start_logits_low'], dim=-1) @ arange
                for out in slm_outputs
            ], dim=1)  # (batch, n_slms)
            # Pairwise distance between expected addresses (want to maximize)
            addr_diff = torch.cdist(soft_addrs.unsqueeze(-1), soft_addrs.unsqueeze(-1))
            diversity_loss = -addr_diff.mean()  # negative = encourage large distances
            losses['diversity_loss'] = diversity_loss
# Total loss
losses['total_loss'] = (
losses['next_state_loss']
+ 0.01 * losses['balance_loss']
+ 0.001 * losses['diversity_loss']
)
return {
'next_state': blm_output['next_state'],
'binary_mask': blm_output['binary_mask'],
'info_query': blm_output['info_query'],
'slm_outputs': slm_outputs,
'memory_reads': memory_reads,
'losses': losses,
'routing_info': blm_output['routing_info'],
}
def multi_step_forward(
self,
states: torch.Tensor, # (batch, T, state_dim) sequence of states
characteristics: torch.Tensor, # (batch, char_dim) static
        n_steps: Optional[int] = None,
) -> dict:
"""
Run the system over multiple timesteps autoregressively.
For training: teacher forcing with ground-truth states
"""
batch_size, T, state_dim = states.shape
if n_steps is None:
n_steps = T - 1 # predict all future states
all_predictions = []
all_masks = []
total_loss = None
info_query = None
for t in range(min(n_steps, T - 1)):
past_state = states[:, max(0, t-1), :]
current_state = states[:, t, :]
next_state_target = states[:, t+1, :]
output = self.forward(
past_state, current_state, characteristics,
next_state_target, info_query
)
all_predictions.append(output['next_state'])
all_masks.append(output['binary_mask'])
info_query = output['info_query']
if output['losses']:
if total_loss is None:
total_loss = output['losses']['total_loss']
else:
total_loss = total_loss + output['losses']['total_loss']
if total_loss is None:
total_loss = torch.tensor(0.0, device=states.device)
return {
'predictions': torch.stack(all_predictions, dim=1),
'masks': torch.stack(all_masks, dim=1),
'total_loss': total_loss / max(1, min(n_steps, T - 1)),
'final_info_query': info_query,
}
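# A minimal, hedged training-step sketch (not the author's training code;
# `optimizer`, `states`, and `characteristics` are placeholders supplied by
# the caller): backprop through multi_step_forward, clip, step, and cool
# the router so routing becomes decisive over time.
def _sketch_training_step(system: LeWorldSystem,
                          optimizer: torch.optim.Optimizer,
                          states: torch.Tensor,           # (batch, T, state_dim)
                          characteristics: torch.Tensor,  # (batch, char_dim)
                          step: int = 0) -> float:
    optimizer.zero_grad()
    out = system.multi_step_forward(states, characteristics)
    out['total_loss'].backward()
    torch.nn.utils.clip_grad_norm_(system.parameters(), max_norm=1.0)
    optimizer.step()
    system.blm.router.anneal_temperature(step)  # warm -> decisive routing
    return out['total_loss'].item()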
# =============================================================================
# Parameter Count Verification
# =============================================================================
def count_params(model, name="Model"):
"""Count and display parameter breakdown."""
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n{'='*60}")
print(f"{name}: {total:,} total params ({trainable:,} trainable)")
print(f"{'='*60}")
for child_name, child in model.named_children():
child_params = sum(p.numel() for p in child.parameters())
if child_params > 0:
print(f" {child_name}: {child_params:,}")
return total
# =============================================================================
# Demo / Test
# =============================================================================
if __name__ == "__main__":
print("LeWorld Memory Architecture — Component Verification")
print("=" * 60)
# Configs
mem_config = MemoryConfig()
slm_config = SLMConfig()
blm_config = BLMConfig()
# Build system
system = LeWorldSystem(mem_config, slm_config, blm_config)
# Count parameters
print("\n--- Parameter Counts ---")
count_params(system.memory, "Artificial Memory")
for i, slm in enumerate(system.slms):
count_params(slm, f"SLM-{i}")
count_params(system.blm, "BLM")
count_params(system, "Full System")
# Test forward pass
print("\n--- Forward Pass Test ---")
batch_size = 4
state_dim = slm_config.state_dim
char_dim = slm_config.char_dim
past_state = torch.randn(batch_size, state_dim)
current_state = torch.randn(batch_size, state_dim)
characteristics = torch.randn(batch_size, char_dim)
next_state = torch.randn(batch_size, state_dim)
output = system(past_state, current_state, characteristics, next_state)
print(f"Next state prediction shape: {output['next_state'].shape}")
print(f"Binary mask (SLM selection): {output['binary_mask']}")
print(f"Info query shape: {output['info_query'].shape}")
print(f"Losses: {output['losses']}")
# Test multi-step
print("\n--- Multi-Step Test ---")
T = 10
states = torch.randn(batch_size, T, state_dim)
ms_output = system.multi_step_forward(states, characteristics)
print(f"Predictions shape: {ms_output['predictions'].shape}")
print(f"Masks shape: {ms_output['masks'].shape}")
print(f"Average loss: {ms_output['total_loss'].item():.4f}")
# Show routing patterns over time
print("\n--- Routing Patterns Over Time ---")
masks = ms_output['masks'][0].detach() # first batch element
for t in range(masks.shape[0]):
mask = masks[t].int().tolist()
print(f" Step {t}: SLMs selected = {mask}")
print("\n✅ All components verified successfully!")