|
|
| """ |
| Addressed State Attention (ASA) - Training Harness |
| |
| Efficient implementation optimized for language model training. |
| For mechanistic analysis and interventions, use asm_analysis.py instead. |
| |
| Repository: https://github.com/DigitalDaimyo/AddressedStateAttention |
| Paper: https://github.com/DigitalDaimyo/AddressedStateAttention/paper_drafts |
| """ |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Optional, Dict, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
# Public API of this module. Internal helpers (RotaryEmbedding, apply_rope,
# alibi_slopes, _inv_softplus) are deliberately left out of the export list.
__all__ = [
    "AddressedStateAttention",
    "ASMBlock",
    "ASMLanguageModel",
    "ASMTrainConfig",
    "build_model_from_cfg",
]
|
|
| |
| |
| |
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: |
| x1 = x[..., ::2] |
| x2 = x[..., 1::2] |
| return torch.stack((-x2, x1), dim=-1).flatten(-2) |
|
|
class RotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings (RoPE).

    ``get_cos_sin(T, ...)`` returns two tensors of shape ``(1, 1, T, dim)``.
    Along the last axis they use the ``cat([freqs, freqs])`` layout, ready to
    broadcast against ``(B, H, T, dim)`` inputs.
    """

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        assert dim % 2 == 0, "RoPE requires even dim"
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent: derived purely from (dim, base), so not saved in state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._cos_cached = None
        self._sin_cached = None
        self._t_cached = None
        self._device_cached = None

    def get_cos_sin(self, T: int, device, dtype):
        """Return ``(cos, sin)`` for positions ``0..T-1``, each ``(1, 1, T, dim)`` in ``dtype``.

        The table is cached for the most recently requested ``(T, device)``.

        It is built in at least float32. If the module has been cast to a low
        precision (e.g. ``model.bfloat16()``), ``inv_freq`` is cast along with
        the module. bfloat16 cannot represent integer positions above 256
        exactly, so building the table in that dtype would silently corrupt
        the rotation angles.
        """
        if (
            self._cos_cached is not None
            and self._t_cached == T
            and self._device_cached == device
        ):
            return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)

        # float32 unless the buffer is already wider (e.g. float64 stays float64).
        compute_dtype = torch.promote_types(self.inv_freq.dtype, torch.float32)
        inv_freq = self.inv_freq.to(device=device, dtype=compute_dtype)
        t = torch.arange(T, device=device, dtype=compute_dtype)
        freqs = torch.einsum("t,f->tf", t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]

        self._t_cached = T
        self._device_cached = device
        self._cos_cached = cos
        self._sin_cached = sin
        return cos.to(dtype=dtype), sin.to(dtype=dtype)
|
|
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply a rotary embedding to ``x``: ``x * cos + rotate_half(x) * sin``.

    ``cos`` and ``sin`` must broadcast against ``x``; for example, the
    ``(1, 1, T, D)`` tables returned by ``RotaryEmbedding.get_cos_sin``.
    """
    rotated = _rotate_half(x)
    return x * cos + rotated * sin
|
|
| |
| |
| |
def alibi_slopes(num_heads: int, device=None, dtype=torch.float32) -> torch.Tensor:
    """Return the per-head ALiBi slopes as a 1-D tensor of length ``num_heads``.

    For a power-of-two head count ``n``, the slopes are the geometric sequence
    ``r, r**2, ..., r**n`` with ``r = 2 ** (-8 / n)``.

    Otherwise, the code takes the slopes for the nearest lower power of two,
    then tops them up with every other slope from the sequence for twice that
    power.
    """

    def _geometric(n: int) -> list:
        first = 2.0 ** (-(2.0 ** -(math.log2(n) - 3)))
        return [first * first ** k for k in range(n)]

    def _slopes(n: int) -> list:
        exponent = math.log2(n)
        if exponent.is_integer():
            return _geometric(n)
        base = 2 ** math.floor(exponent)
        extra = _slopes(2 * base)[::2][: n - base]
        return _geometric(base) + extra

    return torch.tensor(_slopes(num_heads), device=device, dtype=dtype)
|
|
| def _inv_softplus(y: torch.Tensor) -> torch.Tensor: |
| return torch.log(torch.expm1(y)) |
|
|
| class AddressedStateAttention(nn.Module): |
| """ |
| ASA with integral slotspace refine fused into the compiled chunk kernel. |
| Fixes included: |
| (1) pad slotspace RoPE cos/sin to CH (identity on padded positions) |
| (2) build valid_mask_c even when attention_mask is None (padding-only) |
| (3) pad write logits with -inf (so padded positions contribute zero to scan) |
| """ |
|
|
| def __init__( |
| self, |
| embed_dim: int, |
| num_heads: int = 12, |
| num_slots: int = 16, |
| dropout: float = 0.1, |
| |
| |
| read_temperature: float = 1.0, |
| write_temperature: float = 1.0, |
| state_fp32: bool = True, |
| slot_dropout: float = 0.0, |
| normalize_k: bool = False, |
| |
| |
| use_rope_keys: bool = True, |
| rope_base: float = 10000.0, |
| |
| |
| use_alibi_write: bool = True, |
| alibi_strength_init: float = 0.1, |
| learn_alibi_strength: bool = True, |
| min_strength: float = 0.0, |
| |
| |
| use_content_read: bool = True, |
| content_read_init: float = -4.0, |
| content_read_max_gamma: float = 3.0, |
| |
| |
| use_slotspace_refine: bool = True, |
| slotspace_dim: int = 8, |
| slotspace_gate_init: float = -4.0, |
| slotspace_dropout: float = 0.05, |
| slotspace_signed_weights: bool = True, |
| |
| |
| use_rope_slotspace: bool = True, |
| rope_base_slotspace: float = 100000.0, |
| |
| |
| write_chunk_size: int = 1024, |
| enable_compiled: bool = True, |
| ): |
| super().__init__() |
| assert embed_dim % num_heads == 0 |
| assert (slotspace_dim % 2) == 0, "slotspace_dim must be even if RoPE enabled" |
|
|
| self.embed_dim = embed_dim |
| self.num_heads = num_heads |
| self.num_slots = num_slots |
| self.head_dim = embed_dim // num_heads |
|
|
| self.dropout = nn.Dropout(dropout) |
|
|
| self.read_temperature = float(read_temperature) |
| self.write_temperature = float(write_temperature) |
| self.state_fp32 = bool(state_fp32) |
| self.slot_dropout = float(slot_dropout) |
| self.normalize_k = bool(normalize_k) |
|
|
| self.use_rope_keys = bool(use_rope_keys) |
| self.use_alibi_write = bool(use_alibi_write) |
| self.learn_alibi_strength = bool(learn_alibi_strength) |
| self.min_strength = float(min_strength) |
|
|
| self.use_content_read = bool(use_content_read) |
| self.content_read_max_gamma = float(content_read_max_gamma) |
|
|
| self.slotspace_dim = int(slotspace_dim) |
| self.slotspace_dropout = nn.Dropout(float(slotspace_dropout)) |
| self.slotspace_signed_weights = bool(slotspace_signed_weights) |
|
|
| self.use_rope_slotspace = bool(use_rope_slotspace) |
| self.write_chunk_size = int(write_chunk_size) |
|
|
| H, K, d = self.num_heads, self.num_slots, self.head_dim |
| M = self.slotspace_dim |
|
|
| self.slot_keys = nn.Parameter(torch.randn(H, K, d) / math.sqrt(d)) |
|
|
| self.Wk_write = nn.Linear(embed_dim, embed_dim, bias=False) |
| self.Wv_write = nn.Linear(embed_dim, embed_dim, bias=False) |
| self.Wq_read = nn.Linear(embed_dim, embed_dim, bias=False) |
| self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) |
|
|
| self.rope = RotaryEmbedding(d, base=rope_base) if self.use_rope_keys else None |
|
|
| if self.use_alibi_write: |
| self.register_buffer("_alibi_slopes", alibi_slopes(H), persistent=False) |
| else: |
| self.register_buffer("_alibi_slopes", torch.zeros(H), persistent=False) |
|
|
| if self.use_alibi_write and self.learn_alibi_strength: |
| init = torch.tensor(float(alibi_strength_init) - self.min_strength).clamp_min(1e-8) |
| self._alibi_strength_param = nn.Parameter(_inv_softplus(init)) |
| else: |
| self._alibi_strength_param = None |
| self.alibi_strength = float(alibi_strength_init) |
|
|
| if self.use_content_read: |
| self._content_read_gamma_raw = nn.Parameter(torch.tensor(float(content_read_init))) |
| else: |
| self._content_read_gamma_raw = None |
|
|
| self.slot_in = nn.Linear(K, M, bias=False) |
| self.slot_q = nn.Linear(M, M, bias=False) |
| self.slot_k = nn.Linear(M, M, bias=False) |
| self.slot_v = nn.Linear(M, M, bias=False) |
| self.slot_out = nn.Linear(M, K, bias=False) |
|
|
| self._slotspace_gate_raw = nn.Parameter(torch.tensor(float(slotspace_gate_init))) |
|
|
| self.rope_slotspace = RotaryEmbedding(M, base=float(rope_base_slotspace)) if self.use_rope_slotspace else None |
|
|
| self._compiled = None |
| if enable_compiled: |
| self.enable_compiled_kernel() |
|
|
| def enable_compiled_kernel(self): |
| if self._compiled is None: |
| self._compiled = torch.compile(self._asa_chunk_fused, dynamic=False, fullgraph=False) |
|
|
| def _alibi_strength(self, dtype, device) -> torch.Tensor: |
| if not (self.use_alibi_write and self.learn_alibi_strength): |
| return torch.tensor(getattr(self, "alibi_strength", 0.0), dtype=dtype, device=device) |
| return (F.softplus(self._alibi_strength_param) + self.min_strength).to(dtype=dtype, device=device) |
|
|
| def _content_read_gamma(self, dtype, device) -> torch.Tensor: |
| if not self.use_content_read: |
| return torch.tensor(0.0, dtype=dtype, device=device) |
| g = F.softplus(self._content_read_gamma_raw) |
| if self.content_read_max_gamma is not None and self.content_read_max_gamma > 0: |
| g = g.clamp(max=self.content_read_max_gamma) |
| return g.to(dtype=dtype, device=device) |
|
|
| def _slotspace_gate(self, dtype, device) -> torch.Tensor: |
| return F.softplus(self._slotspace_gate_raw).to(dtype=dtype, device=device) |
|
|
| @staticmethod |
| def _safe_exp_sub_max(s: torch.Tensor, m: torch.Tensor) -> torch.Tensor: |
| diff = s - m |
| diff = diff.masked_fill(~torch.isfinite(m), float("-inf")) |
| return torch.exp(diff) |
|
|
| @staticmethod |
| def _phi(x: torch.Tensor) -> torch.Tensor: |
| return F.elu(x) + 1.0 |
|
|
| @staticmethod |
| def _pad_time_slice(x: torch.Tensor, t0: int, L: int, CH: int, dim: int): |
| sl = x.narrow(dim, t0, L) |
| if L == CH: |
| return sl, None |
| pad_shape = list(sl.shape) |
| pad_shape[dim] = CH - L |
| pad = torch.zeros(pad_shape, device=sl.device, dtype=sl.dtype) |
| xpad = torch.cat([sl, pad], dim=dim) |
| mask = torch.zeros((CH,), device=sl.device, dtype=torch.bool) |
| mask[:L] = True |
| return xpad, mask |
|
|
| def _asa_chunk_fused( |
| self, |
| wlog_c: torch.Tensor, |
| v_c: torch.Tensor, |
| q_c: torch.Tensor, |
| slot_keys_dk: torch.Tensor, |
| pos_cos_s: Optional[torch.Tensor], |
| pos_sin_s: Optional[torch.Tensor], |
| content_gamma: torch.Tensor, |
| rtemp_t: torch.Tensor, |
| gate_t: torch.Tensor, |
| m_state: torch.Tensor, |
| denom_state: torch.Tensor, |
| numer_state: torch.Tensor, |
| S_state: torch.Tensor, |
| Z_state: torch.Tensor, |
| valid_mask_c: Optional[torch.Tensor], |
| do_dropout: bool, |
| dropout_p: float, |
| signed_slot_w: bool, |
| ): |
| B, H, K, CH = wlog_c.shape |
| d = numer_state.shape[-1] |
| M = S_state.shape[-1] |
| inv_sqrt_d = 1.0 / math.sqrt(d) |
|
|
| |
| m_c, _ = torch.cummax(wlog_c, dim=-1) |
| m_new = torch.maximum(m_state.unsqueeze(-1), m_c) |
| scale = torch.exp(m_state.unsqueeze(-1) - m_new) |
|
|
| denom_c = denom_state.unsqueeze(-1) * scale |
| numer_c = numer_state.unsqueeze(-2) * scale.unsqueeze(-1) |
|
|
| w_new = self._safe_exp_sub_max(wlog_c, m_new) |
| denom_c = denom_c + torch.cumsum(w_new, dim=-1) |
| numer_c = numer_c + torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2) |
|
|
| |
| read_logits_key = torch.matmul(q_c, slot_keys_dk) * inv_sqrt_d |
|
|
| if self.use_content_read: |
| numer_for_dot = numer_c.to(q_c.dtype).permute(0, 1, 3, 2, 4) |
| denom_for_div = denom_c.to(q_c.dtype).permute(0, 1, 3, 2) |
| read_logits_content = (q_c.unsqueeze(-2) * numer_for_dot).sum(dim=-1) * inv_sqrt_d |
| read_logits_content = read_logits_content / denom_for_div.clamp_min(1e-8) |
| read_logits = read_logits_key + content_gamma.to(read_logits_key.dtype) * read_logits_content |
| else: |
| read_logits = read_logits_key |
|
|
| read_w = torch.softmax(read_logits / rtemp_t, dim=-1) |
|
|
| |
| inv_denom = (1.0 / denom_c.clamp_min(1e-8)).to(numer_c.dtype) |
| w_scaled = read_w.to(numer_c.dtype).permute(0, 1, 3, 2) * inv_denom |
| out_base = (w_scaled.unsqueeze(-1) * numer_c).sum(dim=2) |
|
|
| |
| u = self.slot_in(read_w.to(out_base.dtype)) |
| q_s = self.slot_q(u) |
| k_s = self.slot_k(u) |
| v_s = self.slot_v(u) |
|
|
| if self.use_rope_slotspace and (pos_cos_s is not None) and (pos_sin_s is not None): |
| q_s = apply_rope(q_s, pos_cos_s, pos_sin_s) |
| k_s = apply_rope(k_s, pos_cos_s, pos_sin_s) |
|
|
| if valid_mask_c is not None: |
| q_s = q_s * valid_mask_c |
| k_s = k_s * valid_mask_c |
| v_s = v_s * valid_mask_c |
|
|
| qf = self._phi(q_s) |
| kf = self._phi(k_s) |
|
|
| kv = kf.unsqueeze(-1) * v_s.unsqueeze(-2) |
| S_c = torch.cumsum(kv, dim=2) + S_state.unsqueeze(2) |
| Z_c = torch.cumsum(kf, dim=2) + Z_state.unsqueeze(2) |
| Z_c = Z_c.clamp_min(1e-8) |
|
|
| num = torch.matmul(qf.unsqueeze(-2), S_c).squeeze(-2) |
| den = (qf * Z_c).sum(dim=-1, keepdim=True).clamp_min(1e-8) |
| u2 = num / den |
|
|
| S_state_new = S_c[:, :, -1, :, :] |
| Z_state_new = Z_c[:, :, -1, :] |
|
|
| if do_dropout and dropout_p > 0.0: |
| keep = (torch.rand_like(u2) > dropout_p).to(u2.dtype) / (1.0 - dropout_p) |
| u2 = u2 * keep |
|
|
| slot_w = self.slot_out(u2) |
| if signed_slot_w: |
| slot_w = torch.tanh(slot_w) |
| else: |
| slot_w = torch.softmax(slot_w, dim=-1) |
|
|
| slot_w_scaled = slot_w.to(numer_c.dtype).permute(0, 1, 3, 2) * inv_denom |
| delta = (slot_w_scaled.unsqueeze(-1) * numer_c).sum(dim=2) |
|
|
| out = out_base + gate_t.to(out_base.dtype) * delta |
|
|
| m_state_new = m_new[:, :, :, -1] |
| denom_state_new = denom_c[:, :, :, -1] |
| numer_state_new = numer_c[:, :, :, -1, :] |
|
|
| return out, read_w, m_state_new, denom_state_new, numer_state_new, S_state_new, Z_state_new |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| return_info: bool = False, |
| return_light_stats: bool = False, |
| ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]: |
|
|
| B, T, C = x.shape |
| H, K, d = self.num_heads, self.num_slots, self.head_dim |
| M = self.slotspace_dim |
|
|
| k_write = self.Wk_write(x).reshape(B, T, H, d).transpose(1, 2) |
| v_write = self.Wv_write(x).reshape(B, T, H, d).transpose(1, 2) |
| q_read = self.Wq_read(x).reshape(B, T, H, d).transpose(1, 2) |
|
|
| if self.normalize_k: |
| k_write = F.normalize(k_write, dim=-1, eps=1e-8) |
|
|
| if self.use_rope_keys: |
| cos, sin = self.rope.get_cos_sin(T, device=x.device, dtype=k_write.dtype) |
| k_write = apply_rope(k_write, cos, sin) |
|
|
| slot_keys = self.slot_keys |
| if self.training and self.slot_dropout > 0.0: |
| drop = (torch.rand((H, K), device=x.device) < self.slot_dropout) |
| slot_keys = slot_keys * (~drop).to(slot_keys.dtype).unsqueeze(-1) |
|
|
| slot_keys_dk = slot_keys.transpose(-1, -2).unsqueeze(0).to(q_read.dtype) |
|
|
| write_logits_raw = torch.matmul(k_write.to(q_read.dtype), slot_keys_dk).permute(0, 1, 3, 2) / math.sqrt(d) |
|
|
| state_dtype = torch.float32 if (self.state_fp32 and x.dtype != torch.float32) else x.dtype |
| write_logits = write_logits_raw.to(state_dtype) |
|
|
| wtemp = max(1e-6, self.write_temperature) |
| write_logits = write_logits / wtemp |
|
|
| if self.use_alibi_write: |
| strength = self._alibi_strength(dtype=state_dtype, device=x.device) |
| slopes = self._alibi_slopes.to(device=x.device, dtype=state_dtype) * strength |
| pos = torch.arange(T, device=x.device, dtype=state_dtype) |
| write_logits = write_logits + slopes.view(1, H, 1, 1) * pos.view(1, 1, 1, T) |
|
|
| valid = None |
| if attention_mask is not None: |
| valid = attention_mask.to(dtype=torch.bool) |
| write_logits = write_logits.masked_fill(~valid.view(B, 1, 1, T), float("-inf")) |
|
|
| content_gamma = self._content_read_gamma(dtype=q_read.dtype, device=x.device) |
| rtemp_t = torch.tensor(max(1e-6, self.read_temperature), device=x.device, dtype=q_read.dtype) |
| gate_t = self._slotspace_gate(dtype=state_dtype, device=x.device) |
|
|
| denom_state = torch.zeros((B, H, K), device=x.device, dtype=state_dtype) |
| numer_state = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype) |
| m_state = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype) |
|
|
| S_state = torch.zeros((B, H, M, M), device=x.device, dtype=state_dtype) |
| Z_state = torch.zeros((B, H, M), device=x.device, dtype=state_dtype) |
|
|
| out_h = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype) |
|
|
| if self.use_rope_slotspace: |
| cos_s_full, sin_s_full = self.rope_slotspace.get_cos_sin(T, device=x.device, dtype=state_dtype) |
| else: |
| cos_s_full = sin_s_full = None |
|
|
| CH = self.write_chunk_size |
| kernel = self._compiled if self._compiled is not None else self._asa_chunk_fused |
|
|
| do_dropout = bool(self.training and self.slotspace_dropout.p > 0.0) |
| dropout_p = float(self.slotspace_dropout.p) |
| signed_slot_w = bool(self.slotspace_signed_weights) |
|
|
| for t0 in range(0, T, CH): |
| t1 = min(T, t0 + CH) |
| L = t1 - t0 |
|
|
| wlog_c, mask = self._pad_time_slice(write_logits, t0, L, CH, dim=3) |
| v_c, _ = self._pad_time_slice(v_write.to(state_dtype), t0, L, CH, dim=2) |
| q_c, _ = self._pad_time_slice(q_read, t0, L, CH, dim=2) |
|
|
| |
| if mask is not None: |
| wlog_c = wlog_c.clone() |
| wlog_c[:, :, :, L:] = float("-inf") |
|
|
| |
| valid_mask_c = None |
| if (valid is not None) or (mask is not None): |
| if valid is None: |
| vm_pad = mask.view(1, CH).expand(B, CH) |
| else: |
| if mask is None: |
| vm_pad = valid[:, t0:t1] |
| else: |
| vm = valid[:, t0:t1] |
| vm_pad = torch.zeros((B, CH), device=x.device, dtype=torch.bool) |
| vm_pad[:, :L] = vm |
| valid_mask_c = vm_pad.view(B, 1, CH, 1).to(state_dtype) |
|
|
| |
| if self.use_rope_slotspace: |
| cos_slice = cos_s_full[:, :, t0:t1, :] |
| sin_slice = sin_s_full[:, :, t0:t1, :] |
| if L == CH: |
| cos_s, sin_s = cos_slice, sin_slice |
| else: |
| cos_s = torch.ones((1, 1, CH, M), device=x.device, dtype=state_dtype) |
| sin_s = torch.zeros((1, 1, CH, M), device=x.device, dtype=state_dtype) |
| cos_s[:, :, :L, :] = cos_slice |
| sin_s[:, :, :L, :] = sin_slice |
| else: |
| cos_s = sin_s = None |
|
|
| out_c, read_w_c, m_state, denom_state, numer_state, S_state, Z_state = kernel( |
| wlog_c, v_c, q_c, slot_keys_dk, |
| cos_s, sin_s, |
| content_gamma, rtemp_t, gate_t, |
| m_state, denom_state, numer_state, |
| S_state, Z_state, |
| valid_mask_c, |
| do_dropout, dropout_p, |
| signed_slot_w, |
| ) |
|
|
| if mask is not None: |
| out_c = out_c * mask.view(1, 1, CH, 1).to(out_c.dtype) |
|
|
| out_h[:, :, t0:t1, :] = out_c[:, :, :L, :] |
|
|
| out = out_h.transpose(1, 2).reshape(B, T, C) |
| out = self.out_proj(out) |
| out = self.dropout(out) |
|
|
| info = None |
| if return_info or return_light_stats: |
| info = { |
| "content_read_gamma": content_gamma.detach().to(torch.float32).cpu(), |
| "slotspace_gate": gate_t.detach().to(torch.float32).cpu(), |
| } |
| return out, info |
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
@dataclass
class ASMTrainConfig:
    """Default settings for an ASM language-model training run.

    Fields are grouped by concern. ``@dataclass`` generates the constructor
    from field order, so the order below is part of the positional call
    signature and must not be rearranged.
    """

    # -- data source and tokenizer --
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-103-raw-v1"
    tokenizer_name: str = "gpt2"

    # -- windowing, seeding, micro-batching and sample budgets --
    max_seq_len: int = 256
    stride_frac_val: float = 0.5
    seed: int = 1337
    micro_batch_size: int = 2
    grad_accum_steps: int = 8
    train_samples_target: int = 100_000_000
    val_samples_target: int = 25_000

    # -- optimizer and schedule --
    batch_size: int = 64
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.95)
    grad_clip: float = 1.0
    warmup_steps: int = 1_000
    total_steps: int = 75_000
    eval_interval: int = 1_000
    log_interval: int = 100

    # -- model size --
    vocab_size: int = 50257
    embed_dim: int = 384
    num_layers: int = 23
    num_heads: int = 8
    num_slots: int = 32
    mlp_ratio: float = 4.0
    dropout: float = 0.1
    tie_weights: bool = True

    # -- ASA read/write basics --
    read_temperature: float = 1.0
    write_temperature: float = 1.0
    slot_dropout: float = 0.05
    state_fp32: bool = True
    normalize_k: bool = False

    # -- positional information (absolute embeddings, RoPE keys, ALiBi writes) --
    use_abs_pos: bool = False
    use_rope_keys: bool = True
    rope_base: float = 10000.0
    use_alibi_write: bool = True
    alibi_strength_init: float = 0.1
    learn_alibi_strength: bool = True
    min_strength: float = 0.0

    # -- content-dependent read term --
    use_content_read: bool = True
    content_read_init: float = -4.0
    content_read_max_gamma: float = 3.0

    # -- slotspace refine --
    use_slotspace_refine: bool = True
    slotspace_dim: int = 64
    slotspace_gate_init: float = -4.0
    slotspace_dropout: float = 0.05
    slotspace_signed_weights: bool = True

    # -- RoPE inside slotspace --
    use_rope_slotspace: bool = True
    rope_base_slotspace: float = 100000.0

    # -- kernel execution --
    write_chunk_size: int = 128
    enable_compiled: bool = True

    # -- evaluation / analytics --
    eval_max_batches: int = 150
    analytics_last_k: int = 4

    # -- output and cache locations --
    output_dir: str = "./drive/MyDrive/asm_outputs"
    tag: str = "asm_wikitext"
    cache_dir: str = "./drive/MyDrive/asm_caches/fineweb/1B"
    val_windows_cache: str = "./drive/MyDrive/asm_val_cache_windows_1024.pkl"
|
|
|
|
| |
| |
| |
class ASMBlock(nn.Module):
    """Pre-LayerNorm residual block: ``x + ASA(LN(x))``, then ``x + MLP(LN(x))``.

    All ASA-related arguments are forwarded unchanged to
    ``AddressedStateAttention``. The MLP is Linear -> GELU -> Dropout ->
    Linear -> Dropout, with hidden width ``int(embed_dim * mlp_ratio)`` and no
    biases.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_slots: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        # ASA read/write basics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.0,
        normalize_k: bool = False,
        # positional options
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,
        # content read
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,
        # slotspace refine
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,
        # execution
        write_chunk_size: int = 128,
        enable_compiled: bool = False,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)

        asa_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_slots=num_slots,
            dropout=dropout,
            read_temperature=read_temperature,
            write_temperature=write_temperature,
            state_fp32=state_fp32,
            slot_dropout=slot_dropout,
            normalize_k=normalize_k,
            use_rope_keys=use_rope_keys,
            rope_base=rope_base,
            use_alibi_write=use_alibi_write,
            alibi_strength_init=alibi_strength_init,
            learn_alibi_strength=learn_alibi_strength,
            min_strength=min_strength,
            use_content_read=use_content_read,
            content_read_init=content_read_init,
            content_read_max_gamma=content_read_max_gamma,
            use_slotspace_refine=use_slotspace_refine,
            slotspace_dim=slotspace_dim,
            slotspace_gate_init=slotspace_gate_init,
            slotspace_dropout=slotspace_dropout,
            slotspace_signed_weights=slotspace_signed_weights,
            use_rope_slotspace=use_rope_slotspace,
            rope_base_slotspace=rope_base_slotspace,
            write_chunk_size=write_chunk_size,
            enable_compiled=enable_compiled,
        )
        self.asa = AddressedStateAttention(**asa_kwargs)

        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim, bias=False),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, return_info: bool = False, return_light_stats: Optional[bool] = None):
        """Apply the block and return ``(x_out, info)``, where ``info`` comes from the ASA layer."""
        attn_out, info = self.asa(
            self.norm1(x),
            attention_mask=attention_mask,
            return_info=return_info,
            return_light_stats=return_light_stats,
        )
        hidden = x + attn_out
        return hidden + self.mlp(self.norm2(hidden)), info
|
|
|
|
| |
| |
| |
class ASMLanguageModel(nn.Module):
    """Decoder-only language model built from a stack of ``ASMBlock`` layers.

    The pipeline is:
      1. token embedding, plus an optional learned absolute position embedding;
      2. dropout;
      3. ``num_layers`` blocks;
      4. final LayerNorm;
      5. a bias-free LM head, whose weight is shared with the token embedding
         when ``tie_weights`` is set.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 384,
        num_layers: int = 6,
        num_heads: int = 8,
        num_slots: int = 8,
        max_seq_len: int = 1024,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        # ASA read/write basics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.05,
        normalize_k: bool = False,
        tie_weights: bool = True,
        use_abs_pos: bool = False,
        # positional options
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,
        # content read
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,
        # slotspace refine
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,
        # execution
        write_chunk_size: int = 128,
        enable_compiled: bool = False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.use_abs_pos = bool(use_abs_pos)

        self.tok = nn.Embedding(vocab_size, embed_dim)
        self.pos = nn.Embedding(max_seq_len, embed_dim) if self.use_abs_pos else None
        self.drop = nn.Dropout(dropout)

        # Every layer receives the same configuration.
        block_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_slots=num_slots,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            read_temperature=read_temperature,
            write_temperature=write_temperature,
            state_fp32=state_fp32,
            slot_dropout=slot_dropout,
            normalize_k=normalize_k,
            use_rope_keys=use_rope_keys,
            rope_base=rope_base,
            use_alibi_write=use_alibi_write,
            alibi_strength_init=alibi_strength_init,
            learn_alibi_strength=learn_alibi_strength,
            min_strength=min_strength,
            use_content_read=use_content_read,
            content_read_init=content_read_init,
            content_read_max_gamma=content_read_max_gamma,
            use_slotspace_refine=use_slotspace_refine,
            slotspace_dim=slotspace_dim,
            slotspace_gate_init=slotspace_gate_init,
            slotspace_dropout=slotspace_dropout,
            slotspace_signed_weights=slotspace_signed_weights,
            use_rope_slotspace=use_rope_slotspace,
            rope_base_slotspace=rope_base_slotspace,
            write_chunk_size=write_chunk_size,
            enable_compiled=enable_compiled,
        )
        self.blocks = nn.ModuleList([ASMBlock(**block_kwargs) for _ in range(num_layers)])

        self.norm = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        if tie_weights:
            self.lm_head.weight = self.tok.weight

        self.apply(self._init)

    def _init(self, m):
        """Initialise Linear/Embedding weights from N(0, 0.02); LayerNorm to identity (weight 1, bias 0)."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_info: bool = False,
        return_light_stats: Optional[bool] = None,
    ):
        """Map ``(B, T)`` token ids to ``(B, T, vocab_size)`` logits.

        If ``return_info`` is set, returns ``(logits, infos)``, where ``infos``
        holds one entry per block. Otherwise returns ``logits`` only.
        """
        B, T = input_ids.shape
        assert T <= self.max_seq_len, f"T={T} exceeds max_seq_len={self.max_seq_len}"

        hidden = self.tok(input_ids)
        if self.use_abs_pos:
            positions = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, -1)
            hidden = hidden + self.pos(positions)
        hidden = self.drop(hidden)

        infos = []
        for block in self.blocks:
            hidden, info = block(
                hidden,
                attention_mask=attention_mask,
                return_info=return_info,
                return_light_stats=return_light_stats,
            )
            if return_info:
                infos.append(info)

        logits = self.lm_head(self.norm(hidden))
        if return_info:
            return logits, infos
        return logits
|
|
|
|
| |
| |
| |
def build_model_from_cfg(cfg: ASMTrainConfig) -> ASMLanguageModel:
    """Construct an ``ASMLanguageModel`` from the same-named fields of ``cfg``.

    Only model-architecture fields are read. Data, optimizer, and path
    settings on ``cfg`` are ignored here.
    """
    model_field_names = (
        "vocab_size",
        "embed_dim",
        "num_layers",
        "num_heads",
        "num_slots",
        "max_seq_len",
        "mlp_ratio",
        "dropout",
        "read_temperature",
        "write_temperature",
        "state_fp32",
        "slot_dropout",
        "normalize_k",
        "tie_weights",
        "use_abs_pos",
        "use_rope_keys",
        "rope_base",
        "use_alibi_write",
        "alibi_strength_init",
        "learn_alibi_strength",
        "min_strength",
        "use_content_read",
        "content_read_init",
        "content_read_max_gamma",
        "use_slotspace_refine",
        "slotspace_dim",
        "slotspace_gate_init",
        "slotspace_dropout",
        "slotspace_signed_weights",
        "use_rope_slotspace",
        "rope_base_slotspace",
        "write_chunk_size",
        "enable_compiled",
    )
    model_kwargs = {name: getattr(cfg, name) for name in model_field_names}
    return ASMLanguageModel(**model_kwargs)
|
|