| """ |
| Resonance 200M β Content + RRPRAM dual attention transformer. |
| Low-rank RRPRAM (Wr = Wr_a @ Wr_b), SwiGLU MLP, RMSNorm, RoPE. |
| Content attention uses FlashAttention via F.scaled_dot_product_attention. |
| |
| Architecture matches resonance-bpe.c (with low-rank extension). |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: scale by 1/RMS(x), no mean-centering, learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialized to identity.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Mean of squares over the feature (last) dimension; eps keeps rsqrt finite.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_sq + self.eps)
        return normed * self.weight
|
|
|
|
class ResonanceBlock(nn.Module):
    """
    Dual attention block: Content (QKV + RoPE + FlashAttn) + RRPRAM (low-rank Wr) + SwiGLU MLP.

    Pre-norm residual layout:
        x = x + Wo(mix(content_attn, rrpram_attn))
        x = x + SwiGLU_MLP(norm(x))
    The two attention branches share V and are blended per head by a learned
    sigmoid gate.
    """

    def __init__(self, config):
        """Build one block from the model config dict.

        Required keys: n_embd, n_head, head_dim, rrpram_rank, context_len,
        ffn_dim, n_layer.
        """
        super().__init__()
        E = config['n_embd']
        H = config['n_head']
        D = config['head_dim']
        R = config['rrpram_rank']
        T = config['context_len']
        M = config['ffn_dim']

        # forward() merges heads with .view(B, T, E), which silently requires
        # the heads to tile the embedding exactly — check it up front.
        if H * D != E:
            raise ValueError(
                f"n_head * head_dim ({H} * {D}) must equal n_embd ({E})")

        self.n_head = H
        self.head_dim = D
        self.n_embd = E

        self.norm1 = RMSNorm(E)

        # Content-attention projections (bias-free).
        self.wq = nn.Linear(E, H * D, bias=False)
        self.wk = nn.Linear(E, H * D, bias=False)
        self.wv = nn.Linear(E, H * D, bias=False)

        # Low-rank RRPRAM factors: Wr[h] = wr_a[h] @ wr_b[h], mapping content
        # (E) to absolute positions (T) through rank R. He-style init scales.
        self.wr_a = nn.Parameter(torch.randn(H, E, R) * (2.0 / E) ** 0.5)
        self.wr_b = nn.Parameter(torch.randn(H, R, T) * (2.0 / R) ** 0.5)

        # Per-head mixing logit; zeros -> sigmoid 0.5 -> equal blend at init.
        self.gate = nn.Parameter(torch.zeros(H))

        self.wo = nn.Linear(E, E, bias=False)

        self.norm2 = RMSNorm(E)

        # SwiGLU MLP: down(silu(gate(x)) * up(x)).
        self.mlp_gate = nn.Linear(E, M, bias=False)
        self.mlp_up = nn.Linear(E, M, bias=False)
        self.mlp_down = nn.Linear(M, E, bias=False)

        # GPT-2-style depth-scaled init on the residual-writing projections.
        n_layer = config['n_layer']
        nn.init.normal_(self.wo.weight, std=0.02 / math.sqrt(2 * n_layer))
        nn.init.normal_(self.mlp_down.weight, std=0.02 / math.sqrt(2 * n_layer))

    def forward(self, x, rope_cos, rope_sin, mask):
        """Run the block.

        Args:
            x: [B, T, E] residual stream.
            rope_cos, rope_sin: RoPE tables broadcastable to [B, H, T, D/2].
            mask: [T, T] bool, True where attention is disallowed (future
                positions); used only by the RRPRAM branch.

        Returns:
            [B, T, E] updated residual stream.
        """
        B, T, E = x.shape
        H = self.n_head
        D = self.head_dim

        xn = self.norm1(x)

        # Project and split heads: [B, T, H*D] -> [B, H, T, D].
        q = self.wq(xn).view(B, T, H, D).transpose(1, 2)
        k = self.wk(xn).view(B, T, H, D).transpose(1, 2)
        v = self.wv(xn).view(B, T, H, D).transpose(1, 2)

        q = _apply_rope(q, rope_cos, rope_sin)
        k = _apply_rope(k, rope_cos, rope_sin)

        # Content branch: is_causal=True builds the causal mask internally,
        # keeping the FlashAttention fast path available.
        c_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # RRPRAM branch: scores = xn @ (wr_a @ wr_b) per head, Wr never
        # materialized. einsum broadcasts xn over the head axis, so no
        # explicit unsqueeze/expand is needed.
        temp = torch.einsum('bie,her->bhir', xn, self.wr_a)
        r_attn = torch.einsum('bhir,hrj->bhij', temp, self.wr_b[:, :, :T])
        r_attn = r_attn * (D ** -0.5)
        r_attn = r_attn.masked_fill(mask, float('-inf'))
        r_attn = F.softmax(r_attn, dim=-1)
        r_out = r_attn @ v

        # Per-head learned blend of the two branches.
        g = torch.sigmoid(self.gate).view(1, H, 1, 1)
        attn_out = g * c_out + (1 - g) * r_out

        # Merge heads and write back to the residual stream.
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, E)
        x = x + self.wo(attn_out)

        # SwiGLU MLP.
        xn = self.norm2(x)
        gate = F.silu(self.mlp_gate(xn))
        up = self.mlp_up(xn)
        x = x + self.mlp_down(gate * up)

        return x
|
|
|
|
| def _apply_rope(x, cos, sin): |
| """Apply RoPE to tensor x: [B, H, T, D].""" |
| x1 = x[..., ::2] |
| x2 = x[..., 1::2] |
| out = torch.stack([ |
| x1 * cos - x2 * sin, |
| x1 * sin + x2 * cos, |
| ], dim=-1).flatten(-2) |
| return out |
|
|
|
|
class Resonance(nn.Module):
    """
    Resonance 200M: dual attention (Content + RRPRAM) transformer.

    Token embedding -> n_layer ResonanceBlocks -> final RMSNorm -> untied LM
    head. RoPE tables and the causal mask are precomputed for the full
    context length and sliced per forward call.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        V = config['vocab_size']
        E = config['n_embd']
        T = config['context_len']
        D = config['head_dim']

        # Gradient checkpointing is off by default; initialized here so
        # forward() can read the flag unconditionally.
        self._grad_ckpt = False

        self.tok_emb = nn.Embedding(V, E)
        nn.init.normal_(self.tok_emb.weight, std=0.02)

        self.blocks = nn.ModuleList([
            ResonanceBlock(config) for _ in range(config['n_layer'])
        ])

        self.norm_f = RMSNorm(E)
        # Output head is NOT tied to tok_emb (counted twice in the budget).
        self.out_head = nn.Linear(E, V, bias=False)
        nn.init.normal_(self.out_head.weight, std=0.02)

        # RoPE angle tables, shaped [1, 1, T, D/2] to broadcast over
        # (batch, head) inside the blocks. Base 10000 as in the original RoPE.
        freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))
        t = torch.arange(T).float()
        angles = torch.outer(t, freqs)
        self.register_buffer('rope_cos', angles.cos().unsqueeze(0).unsqueeze(0))
        self.register_buffer('rope_sin', angles.sin().unsqueeze(0).unsqueeze(0))

        # True above the diagonal = future positions masked out (RRPRAM branch).
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        self.register_buffer('causal_mask', mask)

        n_params = sum(p.numel() for p in self.parameters())
        print(f" [Resonance] {n_params:,} parameters")
        self._report_balance()

    def _report_balance(self):
        """Print how the parameter budget splits across components.

        Counts are recomputed from the config (not from live parameters) so
        the breakdown reflects the architecture formulae.
        """
        cfg = self.config
        E, H, D = cfg['n_embd'], cfg['n_head'], cfg['head_dim']
        R, T, M = cfg['rrpram_rank'], cfg['context_len'], cfg['ffn_dim']
        V, L = cfg['vocab_size'], cfg['n_layer']

        emb = V * E * 2                            # token embedding + untied head
        qkv = L * (3 * E * H * D)                  # wq/wk/wv per layer
        rrpram = L * (H * E * R + H * R * T + H)   # wr_a + wr_b + gate
        wo = L * E * E                             # attention output proj
        mlp = L * (3 * E * M)                      # SwiGLU gate/up/down
        norms = L * 2 * E + E                      # two per block + final

        total = emb + qkv + rrpram + wo + mlp + norms
        print(f" [Resonance] Budget: emb={emb/total*100:.1f}% qkv={qkv/total*100:.1f}% "
              f"rrpram={rrpram/total*100:.1f}% wo={wo/total*100:.1f}% "
              f"mlp={mlp/total*100:.1f}% norms={norms/total*100:.1f}%")

    def set_gradient_checkpointing(self, enable=True):
        """Enable/disable activation checkpointing in forward (training only)."""
        self._grad_ckpt = enable

    def forward(self, idx, targets=None):
        """Compute logits (and loss if targets given).

        Args:
            idx: [B, T] int token ids, T <= context_len.
            targets: optional [B, T] int token ids for cross-entropy.

        Returns:
            (logits [B, T, vocab_size], loss or None)
        """
        B, T = idx.shape
        # Fail with a clear message instead of an opaque broadcast error
        # deeper in the blocks when the sequence is too long.
        max_T = self.causal_mask.size(0)
        if T > max_T:
            raise ValueError(f"sequence length {T} exceeds context_len {max_T}")

        x = self.tok_emb(idx)

        cos = self.rope_cos[:, :, :T, :]
        sin = self.rope_sin[:, :, :T, :]
        mask = self.causal_mask[:T, :T]

        for block in self.blocks:
            if self._grad_ckpt and self.training:
                # use_reentrant=False: recommended mode; does not require all
                # inputs to need grad.
                x = torch.utils.checkpoint.checkpoint(
                    block, x, cos, sin, mask, use_reentrant=False)
            else:
                x = block(x, cos, sin, mask)

        logits = self.out_head(self.norm_f(x))

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss
|
|
|
|
| |
# Configuration for the ~200M-parameter Resonance model.
RESONANCE_200M = {
    'n_embd': 768,        # residual / embedding width (must equal n_head * head_dim)
    'n_head': 12,         # attention heads
    'head_dim': 64,       # per-head dimension (12 * 64 == 768)
    'n_layer': 20,        # number of ResonanceBlocks
    'rrpram_rank': 48,    # low-rank factor R in Wr = wr_a @ wr_b
    'context_len': 2048,  # max sequence length (sizes RoPE tables, mask, wr_b)
    'ffn_dim': 2048,      # SwiGLU hidden width
    'vocab_size': 16384,  # tokenizer vocabulary size
}
|
|