"""
LeWorld Memory Architecture - Complete Implementation
=====================================================
Component 1: Artificial Memory (CPU-style bit storage)
Component 2: SLMs (Small LeWorld Models, ~1.5M params each)
Component 3: BLM (Big LeWorld Model, ~12M params)
Component 4: Full System with training loop
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Tuple, List, Optional


# =============================================================================
# Configuration
# =============================================================================

@dataclass
class MemoryConfig:
    """CPU-style artificial memory configuration."""
    num_words: int = 65536       # 64K addressable words (like 64K RAM)
    word_size: int = 32          # 32 bits per word
    address_bits: int = 16       # 2^16 = 65536 addresses
    max_read_range: int = 256    # max words per single read operation


@dataclass
class SLMConfig:
    """Small LeWorld Model configuration (~1.5M params)."""
    d_model: int = 128           # internal dimension
    n_heads: int = 4             # attention heads
    n_layers: int = 2            # transformer layers
    state_dim: int = 64          # state vector dimension
    char_dim: int = 32           # characteristics vector dimension
    address_space: int = 65536   # must match MemoryConfig.num_words
    max_read_range: int = 256    # must match MemoryConfig.max_read_range
    dropout: float = 0.1


@dataclass
class BLMConfig:
    """Big LeWorld Model configuration (~12M params)."""
    d_model: int = 384           # internal dimension
    n_heads: int = 6             # attention heads
    n_layers: int = 6            # transformer layers
    state_dim: int = 64          # state vector dimension
    n_slms: int = 3              # number of SLMs to route over
    memory_read_dim: int = 256   # dimension of encoded memory reads
    info_query_dim: int = 128    # dimension of "what info do I need" query
    dropout: float = 0.1
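

# --- Editor's sketch (illustrative helper, not part of the original file) ----
# The field comments above say SLMConfig.address_space and max_read_range
# "must match" MemoryConfig; the hypothetical helper below just makes that
# assumption explicit before a system is built.
def _check_config_consistency(mem: MemoryConfig, slm: SLMConfig) -> None:
    """Raise if the SLM address space does not line up with the memory layout."""
    assert slm.address_space == mem.num_words, (
        f"SLMConfig.address_space ({slm.address_space}) must equal "
        f"MemoryConfig.num_words ({mem.num_words})"
    )
    assert slm.max_read_range == mem.max_read_range, (
        f"SLMConfig.max_read_range ({slm.max_read_range}) must equal "
        f"MemoryConfig.max_read_range ({mem.max_read_range})"
    )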


# =============================================================================
# Component 1: Artificial Memory
# =============================================================================

class ArtificialMemory(nn.Module):
    """
    CPU-style bit-level memory with address-range access.
    
    Stores data as actual bits (0/1 tensors), organized into addressable words.
    Supports:
    - READ(start_addr, end_addr) -> returns bit block
    - WRITE(start_addr, data) -> writes bits to memory
    - Bit-to-embedding projection (for neural network consumption)
    
    This mimics how a CPU accesses RAM:
    - Each address points to a word (32 bits)
    - Contiguous reads fetch a range of words
    - No inherent "meaning": bits are just bits until interpreted
    """
    
    def __init__(self, config: MemoryConfig):
        super().__init__()
        self.config = config
        
        # The actual memory: (num_words, word_size) binary tensor
        # Initialized randomly; represents the "existing knowledge base"
        self.register_buffer(
            'memory',
            torch.randint(0, 2, (config.num_words, config.word_size)).float()
        )
        
        # Bit-to-embedding projection: converts raw bits into dense vectors
        # This is learnable: the system learns what bit patterns mean
        self.bit_encoder = nn.Sequential(
            nn.Linear(config.word_size, 64),
            nn.GELU(),
            nn.Linear(64, 128),
            nn.LayerNorm(128)
        )
        
        # Write projection: converts dense vectors back to bit probabilities
        self.bit_decoder = nn.Sequential(
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, config.word_size),
            nn.Sigmoid()  # output probabilities for each bit
        )
    
    def read(self, start_addr: torch.Tensor, end_addr: torch.Tensor) -> torch.Tensor:
        """
        Read a contiguous range of words from memory.
        
        Args:
            start_addr: (batch,) integer tensor of start addresses
            end_addr: (batch,) integer tensor of end addresses
        
        Returns:
            bit_block: (batch, max_range, word_size) raw bits
            encoded: (batch, max_range, 128) encoded memory content
            valid_mask: (batch, max_range) True where the offset falls in [start, end)
        """
        batch_size = start_addr.shape[0]
        max_range = self.config.max_read_range
        
        # Clamp addresses to valid range
        start_addr = start_addr.clamp(0, self.config.num_words - 1)
        # end_addr is bounded to [start_addr, min(start_addr + max_range, num_words)]
        max_end = (start_addr + max_range).clamp(max=self.config.num_words)
        end_addr = torch.minimum(torch.maximum(end_addr, start_addr), max_end)
        
        # Gather memory contents for each batch element
        # Create index tensor for the address ranges
        offsets = torch.arange(max_range, device=start_addr.device).unsqueeze(0)  # (1, max_range)
        addresses = start_addr.unsqueeze(1) + offsets  # (batch, max_range)
        addresses = addresses.clamp(0, self.config.num_words - 1)
        
        # Create validity mask (addresses within [start, end) are valid)
        range_lengths = (end_addr - start_addr).unsqueeze(1)  # (batch, 1)
        valid_mask = offsets < range_lengths  # (batch, max_range)
        
        # Gather bits
        bit_block = self.memory[addresses]  # (batch, max_range, word_size)
        bit_block = bit_block * valid_mask.unsqueeze(-1).float()  # zero out invalid
        
        # Encode bits to dense vectors
        encoded = self.bit_encoder(bit_block)  # (batch, max_range, 128)
        encoded = encoded * valid_mask.unsqueeze(-1).float()
        
        return bit_block, encoded, valid_mask
    
    def write(self, start_addr: torch.Tensor, data: torch.Tensor):
        """
        Write data to memory (soft-write: the bit decoding uses a straight-through
        estimator, but the stored buffer itself holds plain, detached bits).
        
        Args:
            start_addr: (batch,) start addresses
            data: (batch, n_words, 128) encoded data to write
        """
        n_words = data.shape[1]
        
        # Decode to bit probabilities
        bit_probs = self.bit_decoder(data)  # (batch, n_words, word_size)
        
        # Hard bits via straight-through
        hard_bits = (bit_probs > 0.5).float()
        bits_to_write = hard_bits - bit_probs.detach() + bit_probs  # ST trick
        
        # Write to memory (last batch element wins for simplicity)
        for b in range(start_addr.shape[0]):
            addr = start_addr[b].long()
            end = min(addr + n_words, self.config.num_words)
            actual_n = end - addr
            self.memory[addr:end] = bits_to_write[b, :actual_n].detach()
    
    def soft_read(self, attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Content-based soft read using attention weights over entire memory.
        Used for differentiable end-to-end training.
        
        Args:
            attention_weights: (batch, num_words) soft address distribution
        
        Returns:
            encoded: (batch, 128) weighted memory content
        """
        # Encode all memory (expensive but differentiable)
        all_encoded = self.bit_encoder(self.memory)  # (num_words, 128)
        # Weighted sum
        encoded = torch.matmul(attention_weights, all_encoded)  # (batch, 128)
        return encoded
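

# --- Editor's usage sketch (illustrative, not part of the original module) ---
# A minimal example of calling ArtificialMemory.read() directly; the function
# name below is hypothetical and is never invoked by the system.
def _demo_memory_read():
    mem = ArtificialMemory(MemoryConfig())
    start = torch.tensor([0, 1024])          # (batch,) start addresses
    end = torch.tensor([16, 1024 + 300])     # second range gets clamped to max_read_range
    bits, encoded, valid = mem.read(start, end)
    # bits:    (2, 256, 32)  raw 0/1 words, zeroed outside [start, end)
    # encoded: (2, 256, 128) learned dense encoding of those words
    # valid:   (2, 256)      mask of which offsets were actually read
    return bits.shape, encoded.shape, valid.sum(dim=-1)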


# =============================================================================
# Component 2: Small LeWorld Model (SLM)
# =============================================================================

class StateEncoder(nn.Module):
    """Encodes past_state and current_state into a joint representation."""
    
    def __init__(self, state_dim: int, d_model: int):
        super().__init__()
        self.past_proj = nn.Linear(state_dim, d_model)
        self.curr_proj = nn.Linear(state_dim, d_model)
        self.combiner = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model)
        )
    
    def forward(self, past_state: torch.Tensor, current_state: torch.Tensor) -> torch.Tensor:
        """
        Args:
            past_state: (batch, state_dim)
            current_state: (batch, state_dim)
        Returns:
            combined: (batch, d_model)
        """
        past_enc = F.gelu(self.past_proj(past_state))
        curr_enc = F.gelu(self.curr_proj(current_state))
        combined = self.combiner(torch.cat([past_enc, curr_enc], dim=-1))
        return combined


class CharacteristicsEncoder(nn.Module):
    """Encodes static characteristics/context."""
    
    def __init__(self, char_dim: int, d_model: int):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(char_dim, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model)
        )
    
    def forward(self, characteristics: torch.Tensor) -> torch.Tensor:
        return self.encoder(characteristics)


class TransformerBlock(nn.Module):
    """Standard transformer block with pre-norm."""
    
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention with pre-norm
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        # FFN with pre-norm
        x = x + self.ffn(self.norm2(x))
        return x


class CrossAttentionBlock(nn.Module):
    """Cross-attention: state attends to characteristics."""
    
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.norm_q = nn.LayerNorm(d_model)
        self.norm_kv = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm_ff = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
    
    def forward(self, query: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        normed_q = self.norm_q(query)
        normed_kv = self.norm_kv(context)
        attn_out, _ = self.cross_attn(normed_q, normed_kv, normed_kv)
        x = query + attn_out
        x = x + self.ffn(self.norm_ff(x))
        return x


class AddressHead(nn.Module):
    """
    Produces memory address range (start_addr, end_addr) from hidden state.
    
    Uses two approaches:
    1. HARD mode: argmax over address space (for inference)
    2. SOFT mode: attention weights over memory (for differentiable training)
    """
    
    def __init__(self, d_model: int, address_space: int, max_range: int):
        super().__init__()
        self.address_space = address_space
        self.max_range = max_range
        
        # Produce start address logits
        # We don't use a single linear over 65K addresses; that's too many params
        # Instead: predict address as composition of sub-addresses (like product keys)
        self.addr_bits = int(math.log2(address_space))  # 16 for 65536
        assert 2 ** self.addr_bits == address_space, "address_space must be power of 2"
        
        # Split address into high byte and low byte (8+8 = 16 bits)
        self.half_bits = self.addr_bits // 2  # 8
        self.half_space = 2 ** self.half_bits  # 256
        
        # Predict high and low parts separately (product key approach)
        self.start_high = nn.Linear(d_model, self.half_space)  # 256 outputs
        self.start_low = nn.Linear(d_model, self.half_space)   # 256 outputs
        
        # Predict range length (how many words to read)
        self.range_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, max_range)
        )
        
        # Confidence head
        self.confidence_head = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
            nn.Sigmoid()
        )
    
    def forward(self, hidden: torch.Tensor) -> dict:
        """
        Args:
            hidden: (batch, d_model)
        
        Returns:
            dict with:
                start_addr: (batch,) integer addresses
                end_addr: (batch,) integer addresses  
                range_length: (batch,) how many words to read
                confidence: (batch,) read confidence score
                start_logits_high: (batch, 256) for soft addressing
                start_logits_low: (batch, 256) for soft addressing
                range_logits: (batch, max_range) for soft range selection
        """
        batch_size = hidden.shape[0]
        
        # Product-key address generation
        high_logits = self.start_high(hidden)  # (batch, 256)
        low_logits = self.start_low(hidden)    # (batch, 256)
        
        # Hard address via argmax
        high_idx = high_logits.argmax(dim=-1)  # (batch,)
        low_idx = low_logits.argmax(dim=-1)    # (batch,)
        start_addr = high_idx * self.half_space + low_idx  # (batch,) 0..65535
        
        # Range length
        range_logits = self.range_head(hidden)  # (batch, max_range)
        range_length = range_logits.argmax(dim=-1) + 1  # (batch,) 1..max_range
        end_addr = (start_addr + range_length).clamp(max=self.address_space - 1)
        
        # Confidence
        confidence = self.confidence_head(hidden).squeeze(-1)  # (batch,)
        
        return {
            'start_addr': start_addr,
            'end_addr': end_addr,
            'range_length': range_length,
            'confidence': confidence,
            'start_logits_high': high_logits,
            'start_logits_low': low_logits,
            'range_logits': range_logits,
        }
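

# --- Editor's sketch: product-key addressing in AddressHead ------------------
# Illustrative only. A 16-bit start address is composed from an 8-bit "high"
# index and an 8-bit "low" index (start = high * 256 + low), so the head needs
# 2 * 256 outputs instead of a single 65,536-way classifier.
def _demo_product_key_address():
    head = AddressHead(d_model=128, address_space=65536, max_range=256)
    out = head(torch.randn(4, 128))
    high = out['start_logits_high'].argmax(dim=-1)
    low = out['start_logits_low'].argmax(dim=-1)
    assert torch.equal(out['start_addr'], high * 256 + low)   # always in [0, 65535]
    return out['start_addr'], out['range_length'], out['confidence']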


class SmallLeWorldModel(nn.Module):
    """
    SLM: Small LeWorld Model (~1.5M params)
    
    Takes (past_state, current_state, characteristics) and produces
    a memory address range pointing to the most useful memory for
    next-state prediction.
    
    Architecture:
    1. Encode past + current state -> state representation
    2. Encode characteristics
    3. Cross-attend: state attends to characteristics
    4. Self-attention transformer layers
    5. Address head: output (start_addr, end_addr, confidence)
    """
    
    def __init__(self, config: SLMConfig, slm_id: int = 0):
        super().__init__()
        self.config = config
        self.slm_id = slm_id
        
        # Encoders
        self.state_encoder = StateEncoder(config.state_dim, config.d_model)
        self.char_encoder = CharacteristicsEncoder(config.char_dim, config.d_model)
        
        # Cross-attention: state <- characteristics
        self.cross_attn = CrossAttentionBlock(config.d_model, config.n_heads, config.dropout)
        
        # Self-attention transformer
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(config.d_model, config.n_heads, config.dropout)
            for _ in range(config.n_layers)
        ])
        self.final_norm = nn.LayerNorm(config.d_model)
        
        # Address output head
        self.address_head = AddressHead(config.d_model, config.address_space, config.max_read_range)
    
    def forward(
        self,
        past_state: torch.Tensor,      # (batch, state_dim)
        current_state: torch.Tensor,    # (batch, state_dim)
        characteristics: torch.Tensor,  # (batch, char_dim)
    ) -> dict:
        """
        Forward pass: state + characteristics -> memory address range.
        
        Returns dict with address info + internal hidden state.
        """
        # Encode states
        state_repr = self.state_encoder(past_state, current_state)  # (batch, d_model)
        
        # Encode characteristics
        char_repr = self.char_encoder(characteristics)  # (batch, d_model)
        
        # Cross-attention: state queries characteristics
        # Unsqueeze to sequence dim for attention
        state_seq = state_repr.unsqueeze(1)   # (batch, 1, d_model)
        char_seq = char_repr.unsqueeze(1)     # (batch, 1, d_model)
        
        enriched = self.cross_attn(state_seq, char_seq)  # (batch, 1, d_model)
        
        # Self-attention layers
        hidden = enriched
        for layer in self.transformer_layers:
            hidden = layer(hidden)
        
        hidden = self.final_norm(hidden)
        hidden = hidden.squeeze(1)  # (batch, d_model)
        
        # Produce address range
        addr_output = self.address_head(hidden)
        addr_output['hidden'] = hidden  # keep for BLM to use
        addr_output['slm_id'] = self.slm_id
        
        return addr_output
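

# --- Editor's usage sketch (illustrative, not part of the original module) ---
# What a single SLM consumes and emits: two state vectors plus a
# characteristics vector in, a memory address range (with logits, confidence,
# and its hidden state for the BLM) out. The function name is hypothetical.
def _demo_slm_forward():
    cfg = SLMConfig()
    slm = SmallLeWorldModel(cfg, slm_id=0)
    past = torch.randn(2, cfg.state_dim)
    curr = torch.randn(2, cfg.state_dim)
    chars = torch.randn(2, cfg.char_dim)
    out = slm(past, curr, chars)
    # out['start_addr'], out['end_addr']: (2,) integer addresses into memory
    # out['confidence']: (2,) read-confidence in [0, 1]
    # out['hidden']: (2, cfg.d_model) representation later consumed by the BLM
    return out['start_addr'], out['end_addr'], out['confidence']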


# =============================================================================
# Component 3: Big LeWorld Model (BLM)
# =============================================================================

class StraightThroughSigmoid(torch.autograd.Function):
    """
    Binary gate: hard 0/1 in forward, sigmoid gradient in backward.
    From literature: ST-GS (Jang et al. 2017) + Switch Transformer routing.
    """
    @staticmethod
    def forward(ctx, logits):
        probs = torch.sigmoid(logits)
        ctx.save_for_backward(probs)
        return (probs > 0.5).float()
    
    @staticmethod
    def backward(ctx, grad_output):
        probs, = ctx.saved_tensors
        # Sigmoid derivative: p * (1-p)
        return grad_output * probs * (1 - probs)
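

# --- Editor's sketch: behavior of the straight-through gate ------------------
# Illustrative only. This Function is kept as a reference; BLMRouter below uses
# the equivalent inline ST trick rather than calling it.
def _demo_straight_through_sigmoid():
    logits = torch.tensor([-2.0, 0.3, 1.5], requires_grad=True)
    gate = StraightThroughSigmoid.apply(logits)   # forward: hard 0/1 -> [0., 1., 1.]
    gate.sum().backward()                         # backward: sigmoid'(logits)
    return gate, logits.grad                      # grad = p * (1 - p), p = sigmoid(logits)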


class BLMRouter(nn.Module):
    """
    Routes/selects which SLMs to activate.
    Produces binary mask like [1, 0, 1].
    
    Uses Straight-Through Sigmoid for differentiable binary selection.
    Includes load-balancing loss to prevent degenerate routing.
    """
    
    def __init__(self, d_model: int, n_slms: int):
        super().__init__()
        self.n_slms = n_slms
        
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, n_slms)
        )
        
        # Temperature for annealing (start warm, cool down)
        self.register_buffer('temperature', torch.tensor(1.0))
    
    def forward(self, state_repr: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            state_repr: (batch, d_model) encoded current state
            
        Returns:
            binary_mask: (batch, n_slms) hard 0/1 selection
            routing_info: dict with probs, losses, etc.
        """
        logits = self.gate(state_repr)  # (batch, n_slms)
        
        # Scale by temperature
        scaled_logits = logits / self.temperature.clamp(min=0.1)
        
        probs = torch.sigmoid(scaled_logits)  # (batch, n_slms)
        
        # Straight-through binary: hard in forward, soft in backward
        hard_mask = (probs > 0.5).float()
        binary_mask = hard_mask - probs.detach() + probs  # THE ST TRICK
        
        # Ensure at least one SLM is selected (don't want all zeros)
        # If all zeros, force-select the highest probability SLM
        all_zero = (binary_mask.sum(dim=-1) == 0)  # (batch,)
        if all_zero.any():
            max_idx = probs[all_zero].argmax(dim=-1)
            forced = torch.zeros_like(probs[all_zero])
            forced.scatter_(1, max_idx.unsqueeze(1), 1.0)
            binary_mask[all_zero] = forced
        
        # Load balance loss: encourage roughly equal usage of SLMs
        usage_per_slm = binary_mask.mean(dim=0)  # (n_slms,)
        target_usage = 1.0 / self.n_slms
        balance_loss = ((usage_per_slm - target_usage) ** 2).sum()
        
        # Entropy loss: encourage decisive routing (not all ~0.5)
        entropy = -(probs * torch.log(probs + 1e-8) + 
                     (1 - probs) * torch.log(1 - probs + 1e-8))
        entropy_loss = entropy.mean()
        
        routing_info = {
            'probs': probs,
            'binary_mask': binary_mask,
            'balance_loss': balance_loss,
            'entropy_loss': entropy_loss,
            'logits': logits,
        }
        
        return binary_mask, routing_info
    
    def anneal_temperature(self, step: int, anneal_rate: float = 3e-5, min_temp: float = 0.1):
        """Anneal temperature: start warm (exploratory), cool down (decisive)."""
        new_temp = max(min_temp, math.exp(-anneal_rate * step))
        self.temperature.fill_(new_temp)
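

# --- Editor's usage sketch (illustrative, not part of the original module) ---
# The router turns an encoded state into a hard per-SLM on/off mask such as
# [1, 0, 1], while exposing soft probabilities and auxiliary losses.
def _demo_router():
    router = BLMRouter(d_model=384, n_slms=3)
    mask, info = router(torch.randn(4, 384))
    # mask: (4, 3) hard 0/1 selection; gradients flow through the ST trick
    # info['balance_loss'] nudges average usage toward 1/3 per SLM,
    # info['entropy_loss'] measures how indecisive the probabilities are.
    router.anneal_temperature(step=10_000)   # cool the gate as training proceeds
    return mask, info['probs'], info['balance_loss']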


class InfoRequestHead(nn.Module):
    """
    Produces a query vector representing "what information do I need next?"
    
    This is the key innovation: instead of passively receiving all SLM outputs,
    the BLM actively requests specific information. This query modulates which
    memory regions the SLMs should focus on in the NEXT timestep.
    """
    
    def __init__(self, d_model: int, query_dim: int):
        super().__init__()
        self.query_generator = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, query_dim),
            nn.LayerNorm(query_dim)
        )
    
    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden: (batch, d_model) BLM's internal state
        Returns:
            info_query: (batch, query_dim) "what do I need next?"
        """
        return self.query_generator(hidden)


class BigLeWorldModel(nn.Module):
    """
    BLM: Big LeWorld Model (~12M params)
    
    Two roles:
    1. ROUTER: Select which SLMs to activate (binary mask)
    2. PREDICTOR: Given selected memory contents, predict next state
    
    Plus: Info-Request Head that asks "what information is needed next?"
    
    Architecture:
    1. Encode current state -> routing decision
    2. Receive memory reads from selected SLMs
    3. Transformer processes (state + memories)
    4. Predict next state
    5. Generate info request for next timestep
    """
    
    def __init__(self, config: BLMConfig):
        super().__init__()
        self.config = config
        
        # State encoder (maps state_dim -> d_model)
        self.state_encoder = nn.Sequential(
            nn.Linear(config.state_dim, config.d_model),
            nn.GELU(),
            nn.LayerNorm(config.d_model)
        )
        
        # Memory read encoder (maps encoded memory -> d_model)
        self.memory_encoder = nn.Sequential(
            nn.Linear(128, config.d_model),  # 128 from ArtificialMemory bit_encoder
            nn.GELU(),
            nn.LayerNorm(config.d_model)
        )
        
        # SLM hidden state encoder (maps SLM hidden -> d_model)
        self.slm_hidden_encoder = nn.Sequential(
            nn.Linear(128, config.d_model),  # 128 = SLM d_model
            nn.GELU(),
            nn.LayerNorm(config.d_model)
        )
        
        # Router: selects which SLMs to use
        self.router = BLMRouter(config.d_model, config.n_slms)
        
        # Transformer backbone
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(config.d_model, config.n_heads, config.dropout)
            for _ in range(config.n_layers)
        ])
        self.final_norm = nn.LayerNorm(config.d_model)
        
        # Prediction heads
        self.next_state_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Linear(config.d_model, config.state_dim)
        )
        
        # Info request head: "what do I need next?"
        self.info_request = InfoRequestHead(config.d_model, config.info_query_dim)
        
        # Learnable tokens
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02)
        self.state_type_embed = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02)
        self.memory_type_embed = nn.Parameter(torch.randn(1, 1, config.d_model) * 0.02)
    
    def forward(
        self,
        past_state: torch.Tensor,           # (batch, state_dim)
        current_state: torch.Tensor,         # (batch, state_dim)
        slm_outputs: List[dict],             # list of SLM output dicts
        memory_reads: List[torch.Tensor],    # list of (batch, range, 128) encoded memory
        info_query_prev: Optional[torch.Tensor] = None,  # (batch, query_dim) from previous step (currently unused inside the BLM)
    ) -> dict:
        """
        Full BLM forward pass.
        
        Returns:
            dict with next_state, binary_mask, info_query, losses, etc.
        """
        batch_size = current_state.shape[0]
        
        # 1. Encode current state for routing decision
        state_enc = self.state_encoder(current_state)  # (batch, d_model)
        
        # 2. Route: select which SLMs to use
        binary_mask, routing_info = self.router(state_enc)  # (batch, n_slms)
        
        # 3. Aggregate selected memory reads
        # For each SLM, apply its binary gate and encode its memory read
        memory_tokens = []
        for i, (slm_out, mem_read) in enumerate(zip(slm_outputs, memory_reads)):
            gate = binary_mask[:, i:i+1]  # (batch, 1)
            
            # Gate the SLM's hidden representation
            slm_hidden = self.slm_hidden_encoder(slm_out['hidden'])  # (batch, d_model)
            slm_hidden = slm_hidden * gate  # zero if SLM not selected
            
            # Gate and encode the memory read
            # mem_read: (batch, range_len, 128)
            mem_enc = self.memory_encoder(mem_read)  # (batch, range_len, d_model)
            mem_enc = mem_enc * gate.unsqueeze(-1)  # zero if SLM not selected
            
            # Pool memory read to single token (mean pool over range)
            mem_pooled = mem_enc.mean(dim=1, keepdim=True)  # (batch, 1, d_model)
            
            memory_tokens.append(slm_hidden.unsqueeze(1))  # SLM hidden as token
            memory_tokens.append(mem_pooled)                 # memory content as token
        
        # 4. Build input sequence for transformer
        # [CLS] + [state] + [slm_0_hidden, slm_0_mem, slm_1_hidden, slm_1_mem, ...]
        cls = self.cls_token.expand(batch_size, -1, -1)
        state_token = state_enc.unsqueeze(1) + self.state_type_embed  # (batch, 1, d_model)
        
        # Add memory type embedding to memory tokens
        mem_sequence = torch.cat(memory_tokens, dim=1)  # (batch, 2*n_slms, d_model)
        mem_sequence = mem_sequence + self.memory_type_embed
        
        sequence = torch.cat([cls, state_token, mem_sequence], dim=1)
        # Shape: (batch, 1 + 1 + 2*n_slms, d_model)
        
        # 5. Transformer processing
        hidden = sequence
        for layer in self.transformer_layers:
            hidden = layer(hidden)
        hidden = self.final_norm(hidden)
        
        # 6. Extract predictions from CLS token
        cls_output = hidden[:, 0, :]  # (batch, d_model)
        
        # 7. Predict next state
        next_state_pred = self.next_state_head(cls_output)  # (batch, state_dim)
        
        # 8. Generate info request for next timestep
        info_query = self.info_request(cls_output)  # (batch, query_dim)
        
        return {
            'next_state': next_state_pred,
            'binary_mask': binary_mask,
            'info_query': info_query,
            'routing_info': routing_info,
            'cls_output': cls_output,
        }
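

# --- Editor's usage sketch (illustrative, not part of the original module) ---
# What the BLM expects when driven directly: one dict per SLM (only 'hidden' is
# consumed here) plus one encoded memory read per SLM. Names are hypothetical.
def _demo_blm_forward():
    cfg = BLMConfig()
    blm = BigLeWorldModel(cfg)
    past = torch.randn(2, cfg.state_dim)
    curr = torch.randn(2, cfg.state_dim)
    slm_outputs = [{'hidden': torch.randn(2, 128)} for _ in range(cfg.n_slms)]   # 128 = SLM d_model
    memory_reads = [torch.randn(2, 256, 128) for _ in range(cfg.n_slms)]         # (batch, range, 128)
    out = blm(past, curr, slm_outputs, memory_reads)
    # out['next_state']: (2, 64), out['binary_mask']: (2, 3), out['info_query']: (2, 128)
    return out['next_state'].shape, out['binary_mask'], out['info_query'].shape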


# =============================================================================
# Component 4: Full LeWorld System
# =============================================================================

class LeWorldSystem(nn.Module):
    """
    Complete LeWorld Memory Architecture.
    
    Orchestrates:
    - Artificial Memory (bit-level storage)
    - 3 SLMs (produce memory address ranges)
    - 1 BLM (selects SLMs, reads memory, predicts next state)
    
    Training loop:
    1. BLM sees current state -> routes to SLMs
    2. Selected SLMs produce address ranges
    3. Memory is read at those ranges
    4. BLM aggregates memory + state -> predicts next state
    5. BLM generates info-request for next step
    
    Losses:
    - next_state_loss: MSE between predicted and actual next state
    - routing_balance_loss: encourage balanced SLM usage
    - address_diversity_loss: encourage SLMs to read different memory regions
    - info_utility_loss: did the info request lead to useful retrievals? (not computed in this version)
    """
    
    def __init__(
        self,
        mem_config: MemoryConfig = MemoryConfig(),
        slm_config: SLMConfig = SLMConfig(),
        blm_config: BLMConfig = BLMConfig(),
    ):
        super().__init__()
        
        # Artificial Memory
        self.memory = ArtificialMemory(mem_config)
        
        # 3 SLMs
        self.slms = nn.ModuleList([
            SmallLeWorldModel(slm_config, slm_id=i)
            for i in range(blm_config.n_slms)
        ])
        
        # BLM
        self.blm = BigLeWorldModel(blm_config)
        
        # Info-query -> SLM modulation: the BLM's info request
        # influences what SLMs look for in the next timestep
        self.info_to_slm = nn.Linear(blm_config.info_query_dim, slm_config.state_dim)
        
        self.config = {
            'mem': mem_config,
            'slm': slm_config,
            'blm': blm_config,
        }
    
    def forward(
        self,
        past_state: torch.Tensor,          # (batch, state_dim)
        current_state: torch.Tensor,        # (batch, state_dim)
        characteristics: torch.Tensor,      # (batch, char_dim)
        next_state_target: Optional[torch.Tensor] = None,  # (batch, state_dim) for training
        info_query_prev: Optional[torch.Tensor] = None,    # from previous timestep
    ) -> dict:
        """
        Full system forward pass.
        """
        batch_size = current_state.shape[0]
        
        # If we have a previous info query, modulate the current state
        # This is how the BLM's "what do I need?" influences retrieval
        if info_query_prev is not None:
            info_modulation = self.info_to_slm(info_query_prev)  # (batch, state_dim)
            modulated_state = current_state + 0.1 * info_modulation  # gentle modulation
        else:
            modulated_state = current_state
        
        # 1. Run all 3 SLMs to get address ranges
        slm_outputs = []
        for slm in self.slms:
            out = slm(past_state, modulated_state, characteristics)
            slm_outputs.append(out)
        
        # 2. Read memory at each SLM's address range
        memory_reads = []
        for slm_out in slm_outputs:
            _, encoded, valid_mask = self.memory.read(
                slm_out['start_addr'], 
                slm_out['end_addr']
            )
            memory_reads.append(encoded)
        
        # 3. BLM processes everything
        blm_output = self.blm(
            past_state, current_state, 
            slm_outputs, memory_reads,
            info_query_prev
        )
        
        # 4. Compute losses if training
        losses = {}
        if next_state_target is not None:
            # Primary loss: next state prediction
            losses['next_state_loss'] = F.mse_loss(
                blm_output['next_state'], next_state_target
            )
            
            # Routing balance loss
            losses['balance_loss'] = blm_output['routing_info']['balance_loss']
            
            # Address diversity loss: penalize SLMs for reading the same regions
            # NOTE: start_addr comes from an argmax, so no gradient reaches the
            # SLMs through this term as written; a softmax-weighted expected
            # address over start_logits_high/low would make it trainable.
            addresses = torch.stack([
                slm_out['start_addr'].float() for slm_out in slm_outputs
            ], dim=1)  # (batch, n_slms)
            # Pairwise distance between SLM addresses (want to maximize)
            addr_diff = torch.cdist(addresses.unsqueeze(-1), addresses.unsqueeze(-1))
            diversity_loss = -addr_diff.mean()  # negative = encourage large distances
            losses['diversity_loss'] = diversity_loss
            
            # Total loss
            losses['total_loss'] = (
                losses['next_state_loss'] 
                + 0.01 * losses['balance_loss']
                + 0.001 * losses['diversity_loss']
            )
        
        return {
            'next_state': blm_output['next_state'],
            'binary_mask': blm_output['binary_mask'],
            'info_query': blm_output['info_query'],
            'slm_outputs': slm_outputs,
            'memory_reads': memory_reads,
            'losses': losses,
            'routing_info': blm_output['routing_info'],
        }
    
    def multi_step_forward(
        self,
        states: torch.Tensor,           # (batch, T, state_dim) sequence of states
        characteristics: torch.Tensor,   # (batch, char_dim) static
        n_steps: Optional[int] = None,
    ) -> dict:
        """
        Run the system over multiple timesteps.
        
        Training uses teacher forcing: the ground-truth state is fed at every step;
        only the info query is carried forward between steps.
        """
        batch_size, T, state_dim = states.shape
        if n_steps is None:
            n_steps = T - 1  # predict all future states
        
        all_predictions = []
        all_masks = []
        total_loss = None
        info_query = None
        
        for t in range(min(n_steps, T - 1)):
            past_state = states[:, max(0, t-1), :]
            current_state = states[:, t, :]
            next_state_target = states[:, t+1, :]
            
            output = self.forward(
                past_state, current_state, characteristics,
                next_state_target, info_query
            )
            
            all_predictions.append(output['next_state'])
            all_masks.append(output['binary_mask'])
            info_query = output['info_query']
            
            if output['losses']:
                if total_loss is None:
                    total_loss = output['losses']['total_loss']
                else:
                    total_loss = total_loss + output['losses']['total_loss']
        
        if total_loss is None:
            total_loss = torch.tensor(0.0, device=states.device)
        return {
            'predictions': torch.stack(all_predictions, dim=1),
            'masks': torch.stack(all_masks, dim=1),
            'total_loss': total_loss / max(1, min(n_steps, T - 1)),
            'final_info_query': info_query,
        }
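

# --- Editor's training-loop sketch (illustrative, not part of the module) ----
# A minimal outline of one optimization step on a batch of state sequences,
# assuming `states`, `characteristics`, and `optimizer` come from the caller.
# The function name and signature are hypothetical.
def _demo_training_step(system: LeWorldSystem, states, characteristics,
                        optimizer, global_step=0):
    # states: (batch, T, state_dim), characteristics: (batch, char_dim)
    out = system.multi_step_forward(states, characteristics)
    optimizer.zero_grad()
    out['total_loss'].backward()
    optimizer.step()
    # The router's temperature would typically be annealed alongside training.
    system.blm.router.anneal_temperature(step=global_step)
    return out['total_loss'].item()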


# =============================================================================
# Parameter Count Verification
# =============================================================================

def count_params(model, name="Model"):
    """Count and display parameter breakdown."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n{'='*60}")
    print(f"{name}: {total:,} total params ({trainable:,} trainable)")
    print(f"{'='*60}")
    
    for child_name, child in model.named_children():
        child_params = sum(p.numel() for p in child.parameters())
        if child_params > 0:
            print(f"  {child_name}: {child_params:,}")
    
    return total


# =============================================================================
# Demo / Test
# =============================================================================

if __name__ == "__main__":
    print("LeWorld Memory Architecture - Component Verification")
    print("=" * 60)
    
    # Configs
    mem_config = MemoryConfig()
    slm_config = SLMConfig()
    blm_config = BLMConfig()
    
    # Build system
    system = LeWorldSystem(mem_config, slm_config, blm_config)
    
    # Count parameters
    print("\n--- Parameter Counts ---")
    count_params(system.memory, "Artificial Memory")
    for i, slm in enumerate(system.slms):
        count_params(slm, f"SLM-{i}")
    count_params(system.blm, "BLM")
    count_params(system, "Full System")
    
    # Test forward pass
    print("\n--- Forward Pass Test ---")
    batch_size = 4
    state_dim = slm_config.state_dim
    char_dim = slm_config.char_dim
    
    past_state = torch.randn(batch_size, state_dim)
    current_state = torch.randn(batch_size, state_dim)
    characteristics = torch.randn(batch_size, char_dim)
    next_state = torch.randn(batch_size, state_dim)
    
    output = system(past_state, current_state, characteristics, next_state)
    
    print(f"Next state prediction shape: {output['next_state'].shape}")
    print(f"Binary mask (SLM selection): {output['binary_mask']}")
    print(f"Info query shape: {output['info_query'].shape}")
    print(f"Losses: {output['losses']}")
    
    # Test multi-step
    print("\n--- Multi-Step Test ---")
    T = 10
    states = torch.randn(batch_size, T, state_dim)
    
    ms_output = system.multi_step_forward(states, characteristics)
    print(f"Predictions shape: {ms_output['predictions'].shape}")
    print(f"Masks shape: {ms_output['masks'].shape}")
    print(f"Average loss: {ms_output['total_loss'].item():.4f}")
    
    # Show routing patterns over time
    print("\n--- Routing Patterns Over Time ---")
    masks = ms_output['masks'][0].detach()  # first batch element
    for t in range(masks.shape[0]):
        mask = masks[t].int().tolist()
        print(f"  Step {t}: SLMs selected = {mask}")
    
    print("\n✅ All components verified successfully!")