inv0krr committed
Commit 52bd035 · verified · 1 Parent(s): 2f44c12

Add core architecture: Memory, SLM, BLM, full system

Files changed (1)
  1. leworld_architecture.py +990 -0
leworld_architecture.py ADDED
@@ -0,0 +1,990 @@
"""
LeWorld Memory Architecture — Complete Implementation
=====================================================
Component 1: Artificial Memory (CPU-style bit storage)
Component 2: SLMs (Small LeWorld Models, ~1.5M params each)
Component 3: BLM (Big LeWorld Model, ~12M params)
Component 4: Full System with training loop
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Tuple, List, Optional


# =============================================================================
# Configuration
# =============================================================================

@dataclass
class MemoryConfig:
    """CPU-style artificial memory configuration."""
    num_words: int = 65536       # 64K addressable words (like 64K RAM)
    word_size: int = 32          # 32 bits per word
    address_bits: int = 16       # 2^16 = 65536 addresses
    max_read_range: int = 256    # max words per single read operation


@dataclass
class SLMConfig:
    """Small LeWorld Model configuration (~1.5M params)."""
    d_model: int = 128           # internal dimension
    n_heads: int = 4             # attention heads
    n_layers: int = 2            # transformer layers
    state_dim: int = 64          # state vector dimension
    char_dim: int = 32           # characteristics vector dimension
    address_space: int = 65536   # must match MemoryConfig.num_words
    max_read_range: int = 256    # must match MemoryConfig.max_read_range
    dropout: float = 0.1


@dataclass
class BLMConfig:
    """Big LeWorld Model configuration (~12M params)."""
    d_model: int = 384           # internal dimension
    n_heads: int = 6             # attention heads
    n_layers: int = 6            # transformer layers
    state_dim: int = 64          # state vector dimension
    n_slms: int = 3              # number of SLMs to route over
    memory_read_dim: int = 256   # dimension of encoded memory reads
    info_query_dim: int = 128    # dimension of "what info do I need" query
    dropout: float = 0.1

# =============================================================================
# Component 1: Artificial Memory
# =============================================================================

class ArtificialMemory(nn.Module):
    """
    CPU-style bit-level memory with address-range access.

    Stores data as actual bits (0/1 tensors), organized into addressable words.
    Supports:
    - READ(start_addr, end_addr) → returns bit block
    - WRITE(start_addr, data) → writes bits to memory
    - Bit-to-embedding projection (for neural network consumption)

    This mimics how a CPU accesses RAM:
    - Each address points to a word (32 bits)
    - Contiguous reads fetch a range of words
    - No inherent "meaning" — bits are just bits until interpreted
    """

    def __init__(self, config: MemoryConfig):
        super().__init__()
        self.config = config

        # The actual memory: (num_words, word_size) binary tensor
        # Initialized randomly — represents "existing knowledge base"
        self.register_buffer(
            'memory',
            torch.randint(0, 2, (config.num_words, config.word_size)).float()
        )

        # Bit-to-embedding projection: converts raw bits into dense vectors
        # This is learnable — the system learns what bit patterns mean
        self.bit_encoder = nn.Sequential(
            nn.Linear(config.word_size, 64),
            nn.GELU(),
            nn.Linear(64, 128),
            nn.LayerNorm(128)
        )

        # Write projection: converts dense vectors back to bit probabilities
        self.bit_decoder = nn.Sequential(
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, config.word_size),
            nn.Sigmoid()  # output probabilities for each bit
        )

    def read(self, start_addr: torch.Tensor, end_addr: torch.Tensor
             ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Read a contiguous range of words from memory.

        Args:
            start_addr: (batch,) integer tensor of start addresses
            end_addr: (batch,) integer tensor of end addresses

        Returns:
            bit_block: (batch, max_range, word_size) raw bits
            encoded: (batch, max_range, 128) encoded memory content
            valid_mask: (batch, max_range) True where the address lies in [start, end)
        """
        batch_size = start_addr.shape[0]
        max_range = self.config.max_read_range

        # Clamp addresses to valid range (element-wise, staying on the input device)
        start_addr = start_addr.clamp(0, self.config.num_words - 1)
        upper = torch.minimum(
            start_addr + max_range,
            torch.full_like(start_addr, self.config.num_words)
        )
        end_addr = torch.minimum(torch.maximum(end_addr, start_addr), upper)

        # Gather memory contents for each batch element
        # Create index tensor for the address ranges
        offsets = torch.arange(max_range, device=start_addr.device).unsqueeze(0)  # (1, max_range)
        addresses = start_addr.unsqueeze(1) + offsets  # (batch, max_range)
        addresses = addresses.clamp(0, self.config.num_words - 1)

        # Create validity mask (addresses within [start, end) are valid)
        range_lengths = (end_addr - start_addr).unsqueeze(1)  # (batch, 1)
        valid_mask = offsets < range_lengths  # (batch, max_range)

        # Gather bits
        bit_block = self.memory[addresses]  # (batch, max_range, word_size)
        bit_block = bit_block * valid_mask.unsqueeze(-1).float()  # zero out invalid

        # Encode bits to dense vectors
        encoded = self.bit_encoder(bit_block)  # (batch, max_range, 128)
        encoded = encoded * valid_mask.unsqueeze(-1).float()

        return bit_block, encoded, valid_mask

    def write(self, start_addr: torch.Tensor, data: torch.Tensor):
        """
        Write data to memory via a hard (thresholded) bit write.

        The decoder output passes through a straight-through threshold, but the
        stored bits are detached: the memory buffer itself is not a gradient path.

        Args:
            start_addr: (batch,) start addresses
            data: (batch, n_words, 128) encoded data to write
        """
        n_words = data.shape[1]

        # Decode to bit probabilities
        bit_probs = self.bit_decoder(data)  # (batch, n_words, word_size)

        # Hard bits via straight-through
        hard_bits = (bit_probs > 0.5).float()
        bits_to_write = hard_bits - bit_probs.detach() + bit_probs  # ST trick

        # Write to memory (if ranges overlap, later batch elements win)
        for b in range(start_addr.shape[0]):
            addr = int(start_addr[b].item())
            end = min(addr + n_words, self.config.num_words)
            actual_n = end - addr
            self.memory[addr:end] = bits_to_write[b, :actual_n].detach()

    def soft_read(self, attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Content-based soft read using attention weights over the entire memory.
        Used for differentiable end-to-end training.

        Args:
            attention_weights: (batch, num_words) soft address distribution

        Returns:
            encoded: (batch, 128) weighted memory content
        """
        # Encode all memory (expensive but differentiable)
        all_encoded = self.bit_encoder(self.memory)  # (num_words, 128)
        # Weighted sum
        encoded = torch.matmul(attention_weights, all_encoded)  # (batch, 128)
        return encoded

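# Illustrative usage sketch (an added example, not part of the system wiring below;
# shapes assume the default MemoryConfig — write() is not otherwise called in this file):
#
#   mem = ArtificialMemory(MemoryConfig())
#   start = torch.tensor([0, 4096])
#   end = torch.tensor([8, 4104])
#   bits, encoded, mask = mem.read(start, end)   # (2, 256, 32), (2, 256, 128), (2, 256)
#   mem.write(start, encoded[:, :8, :])          # decode the 8 valid words back into bits
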
# =============================================================================
# Component 2: Small LeWorld Model (SLM)
# =============================================================================

class StateEncoder(nn.Module):
    """Encodes past_state and current_state into a joint representation."""

    def __init__(self, state_dim: int, d_model: int):
        super().__init__()
        self.past_proj = nn.Linear(state_dim, d_model)
        self.curr_proj = nn.Linear(state_dim, d_model)
        self.combiner = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model)
        )

    def forward(self, past_state: torch.Tensor, current_state: torch.Tensor) -> torch.Tensor:
        """
        Args:
            past_state: (batch, state_dim)
            current_state: (batch, state_dim)
        Returns:
            combined: (batch, d_model)
        """
        past_enc = F.gelu(self.past_proj(past_state))
        curr_enc = F.gelu(self.curr_proj(current_state))
        combined = self.combiner(torch.cat([past_enc, curr_enc], dim=-1))
        return combined


class CharacteristicsEncoder(nn.Module):
    """Encodes static characteristics/context."""

    def __init__(self, char_dim: int, d_model: int):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(char_dim, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model)
        )

    def forward(self, characteristics: torch.Tensor) -> torch.Tensor:
        return self.encoder(characteristics)


class TransformerBlock(nn.Module):
    """Standard transformer block with pre-norm."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention with pre-norm
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        # FFN with pre-norm
        x = x + self.ffn(self.norm2(x))
        return x


class CrossAttentionBlock(nn.Module):
    """Cross-attention: state attends to characteristics."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.norm_q = nn.LayerNorm(d_model)
        self.norm_kv = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm_ff = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, query: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        normed_q = self.norm_q(query)
        normed_kv = self.norm_kv(context)
        attn_out, _ = self.cross_attn(normed_q, normed_kv, normed_kv)
        x = query + attn_out
        x = x + self.ffn(self.norm_ff(x))
        return x


class AddressHead(nn.Module):
    """
    Produces memory address range (start_addr, end_addr) from hidden state.

    Uses two approaches:
    1. HARD mode: argmax over address space (for inference)
    2. SOFT mode: attention weights over memory (for differentiable training)
    """

    def __init__(self, d_model: int, address_space: int, max_range: int):
        super().__init__()
        self.address_space = address_space
        self.max_range = max_range

        # Produce start address logits
        # We don't have a linear over 65K — that's too many params
        # Instead: predict address as composition of sub-addresses (like product keys)
        self.addr_bits = int(math.log2(address_space))  # 16 for 65536
        assert 2 ** self.addr_bits == address_space, "address_space must be power of 2"

        # Split address into high byte and low byte (8+8 = 16 bits)
        self.half_bits = self.addr_bits // 2  # 8
        self.half_space = 2 ** self.half_bits  # 256

        # Predict high and low parts separately (product key approach)
        self.start_high = nn.Linear(d_model, self.half_space)  # 256 outputs
        self.start_low = nn.Linear(d_model, self.half_space)   # 256 outputs

        # Predict range length (how many words to read)
        self.range_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, max_range)
        )

        # Confidence head
        self.confidence_head = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
            nn.Sigmoid()
        )

    def forward(self, hidden: torch.Tensor) -> dict:
        """
        Args:
            hidden: (batch, d_model)

        Returns:
            dict with:
                start_addr: (batch,) integer addresses
                end_addr: (batch,) integer addresses
                range_length: (batch,) how many words to read
                confidence: (batch,) read confidence score
                start_logits_high: (batch, 256) for soft addressing
                start_logits_low: (batch, 256) for soft addressing
                range_logits: (batch, max_range) for soft range selection
        """
        batch_size = hidden.shape[0]

        # Product-key address generation
        high_logits = self.start_high(hidden)  # (batch, 256)
        low_logits = self.start_low(hidden)    # (batch, 256)

        # Hard address via argmax
        high_idx = high_logits.argmax(dim=-1)  # (batch,)
        low_idx = low_logits.argmax(dim=-1)    # (batch,)
        start_addr = high_idx * self.half_space + low_idx  # (batch,) 0..65535

        # Range length
        range_logits = self.range_head(hidden)  # (batch, max_range)
        range_length = range_logits.argmax(dim=-1) + 1  # (batch,) 1..max_range
        end_addr = (start_addr + range_length).clamp(max=self.address_space - 1)

        # Confidence
        confidence = self.confidence_head(hidden).squeeze(-1)  # (batch,)

        return {
            'start_addr': start_addr,
            'end_addr': end_addr,
            'range_length': range_length,
            'confidence': confidence,
            'start_logits_high': high_logits,
            'start_logits_low': low_logits,
            'range_logits': range_logits,
        }

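# Worked example of the product-key addressing above (added for illustration): with
# half_space = 256, a high index of 3 and a low index of 17 compose to
# start_addr = 3 * 256 + 17 = 785; a predicted range length of 64 then gives
# end_addr = 785 + 64 = 849, i.e. a request for memory words [785, 849).
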
class SmallLeWorldModel(nn.Module):
    """
    SLM: Small LeWorld Model (~1.5M params)

    Takes (past_state, current_state, characteristics) and produces
    a memory address range pointing to the most useful memory for
    next-state prediction.

    Architecture:
    1. Encode past + current state → state representation
    2. Encode characteristics
    3. Cross-attend: state attends to characteristics
    4. Self-attention transformer layers
    5. Address head: output (start_addr, end_addr, confidence)
    """

    def __init__(self, config: SLMConfig, slm_id: int = 0):
        super().__init__()
        self.config = config
        self.slm_id = slm_id

        # Encoders
        self.state_encoder = StateEncoder(config.state_dim, config.d_model)
        self.char_encoder = CharacteristicsEncoder(config.char_dim, config.d_model)

        # Cross-attention: state ← characteristics
        self.cross_attn = CrossAttentionBlock(config.d_model, config.n_heads, config.dropout)

        # Self-attention transformer
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(config.d_model, config.n_heads, config.dropout)
            for _ in range(config.n_layers)
        ])
        self.final_norm = nn.LayerNorm(config.d_model)

        # Address output head
        self.address_head = AddressHead(config.d_model, config.address_space, config.max_read_range)

    def forward(
        self,
        past_state: torch.Tensor,       # (batch, state_dim)
        current_state: torch.Tensor,    # (batch, state_dim)
        characteristics: torch.Tensor,  # (batch, char_dim)
    ) -> dict:
        """
        Forward pass: state + characteristics → memory address range.

        Returns dict with address info + internal hidden state.
        """
        # Encode states
        state_repr = self.state_encoder(past_state, current_state)  # (batch, d_model)

        # Encode characteristics
        char_repr = self.char_encoder(characteristics)  # (batch, d_model)

        # Cross-attention: state queries characteristics
        # Unsqueeze to sequence dim for attention
        state_seq = state_repr.unsqueeze(1)  # (batch, 1, d_model)
        char_seq = char_repr.unsqueeze(1)    # (batch, 1, d_model)

        enriched = self.cross_attn(state_seq, char_seq)  # (batch, 1, d_model)

        # Self-attention layers
        hidden = enriched
        for layer in self.transformer_layers:
            hidden = layer(hidden)

        hidden = self.final_norm(hidden)
        hidden = hidden.squeeze(1)  # (batch, d_model)

        # Produce address range
        addr_output = self.address_head(hidden)
        addr_output['hidden'] = hidden  # keep for BLM to use
        addr_output['slm_id'] = self.slm_id

        return addr_output

# =============================================================================
# Component 3: Big LeWorld Model (BLM)
# =============================================================================

class StraightThroughSigmoid(torch.autograd.Function):
    """
    Binary gate: hard 0/1 in forward, sigmoid gradient in backward.
    Cf. the straight-through estimator (Bengio et al., 2013), Gumbel-Softmax
    (Jang et al., 2017), and Switch Transformer routing. Kept as a reusable
    helper; BLMRouter below applies the same trick inline.
    """
    @staticmethod
    def forward(ctx, logits):
        probs = torch.sigmoid(logits)
        ctx.save_for_backward(probs)
        return (probs > 0.5).float()

    @staticmethod
    def backward(ctx, grad_output):
        probs, = ctx.saved_tensors
        # Sigmoid derivative: p * (1-p)
        return grad_output * probs * (1 - probs)


class BLMRouter(nn.Module):
    """
    Routes/selects which SLMs to activate.
    Produces a binary mask like [1, 0, 1].

    Uses a straight-through sigmoid for differentiable binary selection.
    Includes a load-balancing loss to prevent degenerate routing.
    """

    def __init__(self, d_model: int, n_slms: int):
        super().__init__()
        self.n_slms = n_slms

        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, n_slms)
        )

        # Temperature for annealing (start warm, cool down)
        self.register_buffer('temperature', torch.tensor(1.0))

    def forward(self, state_repr: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            state_repr: (batch, d_model) encoded current state

        Returns:
            binary_mask: (batch, n_slms) hard 0/1 selection
            routing_info: dict with probs, losses, etc.
        """
        logits = self.gate(state_repr)  # (batch, n_slms)

        # Scale by temperature
        scaled_logits = logits / self.temperature.clamp(min=0.1)

        probs = torch.sigmoid(scaled_logits)  # (batch, n_slms)

        # Straight-through binary: hard in forward, soft in backward
        hard_mask = (probs > 0.5).float()
        binary_mask = hard_mask - probs.detach() + probs  # straight-through trick

        # Ensure at least one SLM is selected (don't want all zeros)
        # If all zeros, force-select the highest probability SLM
        all_zero = (binary_mask.sum(dim=-1) == 0)  # (batch,)
        if all_zero.any():
            max_idx = probs[all_zero].argmax(dim=-1)
            forced = torch.zeros_like(probs[all_zero])
            forced.scatter_(1, max_idx.unsqueeze(1), 1.0)
            binary_mask[all_zero] = forced

        # Load balance loss: encourage roughly equal usage of SLMs
        usage_per_slm = binary_mask.mean(dim=0)  # (n_slms,)
        target_usage = 1.0 / self.n_slms
        balance_loss = ((usage_per_slm - target_usage) ** 2).sum()

        # Entropy loss: encourage decisive routing (not all ~0.5).
        # Reported in routing_info; not currently added to the training loss.
        entropy = -(probs * torch.log(probs + 1e-8) +
                    (1 - probs) * torch.log(1 - probs + 1e-8))
        entropy_loss = entropy.mean()

        routing_info = {
            'probs': probs,
            'binary_mask': binary_mask,
            'balance_loss': balance_loss,
            'entropy_loss': entropy_loss,
            'logits': logits,
        }

        return binary_mask, routing_info

    def anneal_temperature(self, step: int, anneal_rate: float = 3e-5, min_temp: float = 0.1):
        """Anneal temperature: start warm (exploratory), cool down (decisive)."""
        new_temp = max(min_temp, math.exp(-anneal_rate * step))
        self.temperature.fill_(new_temp)

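# How the straight-through selection in BLMRouter.forward behaves (added illustration):
# for probs = 0.7 the forward value is 1.0 (since 0.7 > 0.5), while
# binary_mask = hard_mask - probs.detach() + probs makes the backward gradient equal
# d(probs)/d(logits) = probs * (1 - probs) / temperature, so the gate remains trainable.
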
class InfoRequestHead(nn.Module):
    """
    Produces a query vector representing "what information do I need next?"

    This is the key innovation: instead of passively receiving all SLM outputs,
    the BLM actively requests specific information. This query modulates which
    memory regions the SLMs should focus on in the NEXT timestep.
    """

    def __init__(self, d_model: int, query_dim: int):
        super().__init__()
        self.query_generator = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, query_dim),
            nn.LayerNorm(query_dim)
        )

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden: (batch, d_model) BLM's internal state
        Returns:
            info_query: (batch, query_dim) "what do I need next?"
        """
        return self.query_generator(hidden)


class BigLeWorldModel(nn.Module):
    """
    BLM: Big LeWorld Model (~12M params)

    Two roles:
    1. ROUTER: Select which SLMs to activate (binary mask)
    2. PREDICTOR: Given selected memory contents, predict next state

    Plus: Info-Request Head that asks "what information is needed next?"

    Architecture:
    1. Encode current state → routing decision
    2. Receive memory reads from selected SLMs
    3. Transformer processes (state + memories)
    4. Predict next state
    5. Generate info request for next timestep
    """

    def __init__(self, config: BLMConfig):
        super().__init__()
        self.config = config

        # State encoder (maps state_dim → d_model)
        self.state_encoder = nn.Sequential(
            nn.Linear(config.state_dim, config.d_model),
            nn.GELU(),
            nn.LayerNorm(config.d_model)
        )

        # Memory read encoder (maps encoded memory → d_model)
        self.memory_encoder = nn.Sequential(
            nn.Linear(128, config.d_model),  # 128 from ArtificialMemory bit_encoder
            nn.GELU(),
            nn.LayerNorm(config.d_model)
        )

        # SLM hidden state encoder (maps SLM hidden → d_model)
        self.slm_hidden_encoder = nn.Sequential(
            nn.Linear(128, config.d_model),  # 128 = SLM d_model
            nn.GELU(),
            nn.LayerNorm(config.d_model)
        )

        # Router: selects which SLMs to use
        self.router = BLMRouter(config.d_model, config.n_slms)

        # Transformer backbone
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(config.d_model, config.n_heads, config.dropout)
            for _ in range(config.n_layers)
        ])
        self.final_norm = nn.LayerNorm(config.d_model)

        # Prediction heads
        self.next_state_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Linear(config.d_model, config.state_dim)
        )

        # Info request head: "what do I need next?"
        self.info_request = InfoRequestHead(config.d_model, config.info_query_dim)

        # Learnable tokens
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02)
        self.state_type_embed = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02)
        self.memory_type_embed = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02)

    def forward(
        self,
        past_state: torch.Tensor,          # (batch, state_dim)
        current_state: torch.Tensor,       # (batch, state_dim)
        slm_outputs: List[dict],           # list of SLM output dicts
        memory_reads: List[torch.Tensor],  # list of (batch, range, 128) encoded memory
        info_query_prev: Optional[torch.Tensor] = None,  # (batch, query_dim) from previous step
    ) -> dict:
        """
        Full BLM forward pass.

        Returns:
            dict with next_state, binary_mask, info_query, losses, etc.
        """
        batch_size = current_state.shape[0]

        # 1. Encode current state for routing decision
        state_enc = self.state_encoder(current_state)  # (batch, d_model)

        # 2. Route: select which SLMs to use
        binary_mask, routing_info = self.router(state_enc)  # (batch, n_slms)

        # 3. Aggregate selected memory reads
        # For each SLM, apply its binary gate and encode its memory read
        memory_tokens = []
        for i, (slm_out, mem_read) in enumerate(zip(slm_outputs, memory_reads)):
            gate = binary_mask[:, i:i+1]  # (batch, 1)

            # Gate the SLM's hidden representation
            slm_hidden = self.slm_hidden_encoder(slm_out['hidden'])  # (batch, d_model)
            slm_hidden = slm_hidden * gate  # zero if SLM not selected

            # Gate and encode the memory read
            # mem_read: (batch, range_len, 128)
            mem_enc = self.memory_encoder(mem_read)  # (batch, range_len, d_model)
            mem_enc = mem_enc * gate.unsqueeze(-1)   # zero if SLM not selected

            # Pool memory read to single token (mean pool over range)
            mem_pooled = mem_enc.mean(dim=1, keepdim=True)  # (batch, 1, d_model)

            memory_tokens.append(slm_hidden.unsqueeze(1))  # SLM hidden as token
            memory_tokens.append(mem_pooled)               # memory content as token

        # 4. Build input sequence for transformer
        # [CLS] + [state] + [slm_0_hidden, slm_0_mem, slm_1_hidden, slm_1_mem, ...]
        cls = self.cls_token.expand(batch_size, -1, -1)
        state_token = state_enc.unsqueeze(1) + self.state_type_embed  # (batch, 1, d_model)

        # Add memory type embedding to memory tokens
        mem_sequence = torch.cat(memory_tokens, dim=1)  # (batch, 2*n_slms, d_model)
        mem_sequence = mem_sequence + self.memory_type_embed

        sequence = torch.cat([cls, state_token, mem_sequence], dim=1)
        # Shape: (batch, 1 + 1 + 2*n_slms, d_model)

        # 5. Transformer processing
        hidden = sequence
        for layer in self.transformer_layers:
            hidden = layer(hidden)
        hidden = self.final_norm(hidden)

        # 6. Extract predictions from CLS token
        cls_output = hidden[:, 0, :]  # (batch, d_model)

        # 7. Predict next state
        next_state_pred = self.next_state_head(cls_output)  # (batch, state_dim)

        # 8. Generate info request for next timestep
        info_query = self.info_request(cls_output)  # (batch, query_dim)

        return {
            'next_state': next_state_pred,
            'binary_mask': binary_mask,
            'info_query': info_query,
            'routing_info': routing_info,
            'cls_output': cls_output,
        }


# =============================================================================
# Component 4: Full LeWorld System
# =============================================================================

class LeWorldSystem(nn.Module):
    """
    Complete LeWorld Memory Architecture.

    Orchestrates:
    - Artificial Memory (bit-level storage)
    - 3 SLMs (produce memory address ranges)
    - 1 BLM (selects SLMs, reads memory, predicts next state)

    Per-step flow (repeated over a trajectory during training):
    1. BLM sees current state → routes to SLMs
    2. Selected SLMs produce address ranges
    3. Memory is read at those ranges
    4. BLM aggregates memory + state → predicts next state
    5. BLM generates info-request for next step

    Losses:
    - next_state_loss: MSE between predicted and actual next state
    - routing_balance_loss: encourage balanced SLM usage
    - address_diversity_loss: encourage SLMs to read different memory regions
    - info_utility_loss: did the info request lead to useful retrievals?
      (described here; not yet implemented below)
    """

    def __init__(
        self,
        mem_config: Optional[MemoryConfig] = None,
        slm_config: Optional[SLMConfig] = None,
        blm_config: Optional[BLMConfig] = None,
    ):
        super().__init__()
        # Build default configs lazily rather than as shared default arguments
        mem_config = mem_config if mem_config is not None else MemoryConfig()
        slm_config = slm_config if slm_config is not None else SLMConfig()
        blm_config = blm_config if blm_config is not None else BLMConfig()

        # Artificial Memory
        self.memory = ArtificialMemory(mem_config)

        # SLMs (blm_config.n_slms of them, default 3)
        self.slms = nn.ModuleList([
            SmallLeWorldModel(slm_config, slm_id=i)
            for i in range(blm_config.n_slms)
        ])

        # BLM
        self.blm = BigLeWorldModel(blm_config)

        # Info-query → SLM modulation: the BLM's info request
        # influences what SLMs look for in the next timestep
        self.info_to_slm = nn.Linear(blm_config.info_query_dim, slm_config.state_dim)

        self.config = {
            'mem': mem_config,
            'slm': slm_config,
            'blm': blm_config,
        }

    def forward(
        self,
        past_state: torch.Tensor,       # (batch, state_dim)
        current_state: torch.Tensor,    # (batch, state_dim)
        characteristics: torch.Tensor,  # (batch, char_dim)
        next_state_target: Optional[torch.Tensor] = None,  # (batch, state_dim) for training
        info_query_prev: Optional[torch.Tensor] = None,    # from previous timestep
    ) -> dict:
        """
        Full system forward pass.
        """
        batch_size = current_state.shape[0]

        # If we have a previous info query, modulate the current state
        # This is how the BLM's "what do I need?" influences retrieval
        if info_query_prev is not None:
            info_modulation = self.info_to_slm(info_query_prev)  # (batch, state_dim)
            modulated_state = current_state + 0.1 * info_modulation  # gentle modulation
        else:
            modulated_state = current_state

        # 1. Run all SLMs to get address ranges
        slm_outputs = []
        for slm in self.slms:
            out = slm(past_state, modulated_state, characteristics)
            slm_outputs.append(out)

        # 2. Read memory at each SLM's address range
        memory_reads = []
        for slm_out in slm_outputs:
            _, encoded, valid_mask = self.memory.read(
                slm_out['start_addr'],
                slm_out['end_addr']
            )
            memory_reads.append(encoded)

        # 3. BLM processes everything
        blm_output = self.blm(
            past_state, current_state,
            slm_outputs, memory_reads,
            info_query_prev
        )

        # 4. Compute losses if training
        losses = {}
        if next_state_target is not None:
            # Primary loss: next state prediction
            losses['next_state_loss'] = F.mse_loss(
                blm_output['next_state'], next_state_target
            )

            # Routing balance loss
            losses['balance_loss'] = blm_output['routing_info']['balance_loss']

            # Address diversity loss: penalize SLMs for reading the same regions
            addresses = torch.stack([
                slm_out['start_addr'].float() for slm_out in slm_outputs
            ], dim=1)  # (batch, n_slms)
            # Pairwise distance between SLM addresses (want to maximize).
            # Note: start_addr comes from an argmax, so this term is non-differentiable
            # w.r.t. the SLMs; it affects the reported loss but contributes no gradient.
            addr_diff = torch.cdist(addresses.unsqueeze(-1), addresses.unsqueeze(-1))
            diversity_loss = -addr_diff.mean()  # negative = encourage large distances
            losses['diversity_loss'] = diversity_loss

            # Total loss
            losses['total_loss'] = (
                losses['next_state_loss']
                + 0.01 * losses['balance_loss']
                + 0.001 * losses['diversity_loss']
            )

        return {
            'next_state': blm_output['next_state'],
            'binary_mask': blm_output['binary_mask'],
            'info_query': blm_output['info_query'],
            'slm_outputs': slm_outputs,
            'memory_reads': memory_reads,
            'losses': losses,
            'routing_info': blm_output['routing_info'],
        }

    def multi_step_forward(
        self,
        states: torch.Tensor,           # (batch, T, state_dim) sequence of states
        characteristics: torch.Tensor,  # (batch, char_dim) static
        n_steps: Optional[int] = None,
    ) -> dict:
        """
        Run the system over multiple timesteps with teacher forcing: each step is
        conditioned on the ground-truth current state, while the BLM's info query
        is carried over from the previous step.
        """
        batch_size, T, state_dim = states.shape
        if n_steps is None:
            n_steps = T - 1  # predict all future states

        all_predictions = []
        all_masks = []
        total_loss = None
        info_query = None

        for t in range(min(n_steps, T - 1)):
            past_state = states[:, max(0, t-1), :]
            current_state = states[:, t, :]
            next_state_target = states[:, t+1, :]

            output = self.forward(
                past_state, current_state, characteristics,
                next_state_target, info_query
            )

            all_predictions.append(output['next_state'])
            all_masks.append(output['binary_mask'])
            info_query = output['info_query']

            if output['losses']:
                if total_loss is None:
                    total_loss = output['losses']['total_loss']
                else:
                    total_loss = total_loss + output['losses']['total_loss']

        if total_loss is None:
            total_loss = torch.tensor(0.0, device=states.device)
        return {
            'predictions': torch.stack(all_predictions, dim=1),
            'masks': torch.stack(all_masks, dim=1),
            'total_loss': total_loss / max(1, min(n_steps, T - 1)),
            'final_info_query': info_query,
        }

# =============================================================================
# Parameter Count Verification
# =============================================================================

def count_params(model, name="Model"):
    """Count and display parameter breakdown."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n{'='*60}")
    print(f"{name}: {total:,} total params ({trainable:,} trainable)")
    print(f"{'='*60}")

    for child_name, child in model.named_children():
        child_params = sum(p.numel() for p in child.parameters())
        if child_params > 0:
            print(f"  {child_name}: {child_params:,}")

    return total


# =============================================================================
# Demo / Test
# =============================================================================

if __name__ == "__main__":
    print("LeWorld Memory Architecture — Component Verification")
    print("=" * 60)

    # Configs
    mem_config = MemoryConfig()
    slm_config = SLMConfig()
    blm_config = BLMConfig()

    # Build system
    system = LeWorldSystem(mem_config, slm_config, blm_config)

    # Count parameters
    print("\n--- Parameter Counts ---")
    count_params(system.memory, "Artificial Memory")
    for i, slm in enumerate(system.slms):
        count_params(slm, f"SLM-{i}")
    count_params(system.blm, "BLM")
    count_params(system, "Full System")

    # Test forward pass
    print("\n--- Forward Pass Test ---")
    batch_size = 4
    state_dim = slm_config.state_dim
    char_dim = slm_config.char_dim

    past_state = torch.randn(batch_size, state_dim)
    current_state = torch.randn(batch_size, state_dim)
    characteristics = torch.randn(batch_size, char_dim)
    next_state = torch.randn(batch_size, state_dim)

    output = system(past_state, current_state, characteristics, next_state)

    print(f"Next state prediction shape: {output['next_state'].shape}")
    print(f"Binary mask (SLM selection): {output['binary_mask']}")
    print(f"Info query shape: {output['info_query'].shape}")
    print(f"Losses: {output['losses']}")

    # Test multi-step
    print("\n--- Multi-Step Test ---")
    T = 10
    states = torch.randn(batch_size, T, state_dim)

    ms_output = system.multi_step_forward(states, characteristics)
    print(f"Predictions shape: {ms_output['predictions'].shape}")
    print(f"Masks shape: {ms_output['masks'].shape}")
    print(f"Average loss: {ms_output['total_loss'].item():.4f}")

    # Show routing patterns over time
    print("\n--- Routing Patterns Over Time ---")
    masks = ms_output['masks'][0].detach()  # first batch element
    for t in range(masks.shape[0]):
        mask = masks[t].int().tolist()
        print(f"  Step {t}: SLMs selected = {mask}")

    print("\n✅ All components verified successfully!")