""" SeqCond model — self-contained HuggingFace implementation. All model code is embedded here so that trust_remote_code=True works without any dependency on the original seqcond package. Architecture: - Hybrid recurrent-transformer: every (seqcond_ratio+1)-th block (1-indexed) is a standard Transformer decoder block; the rest are SeqCond blocks. - SeqCond blocks use complex-exponential accumulators (den_acc, re_acc, im_acc) for O(1) per-token autoregressive decoding. - Transformer blocks use GQA with RoPE and KV-cache for autoregressive decoding. """ import math from typing import List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast from .configuration_seqcond import SeqCondConfig # --------------------------------------------------------------------------- # Optional Triton kernels (accelerates SeqCond step, not required) # --------------------------------------------------------------------------- try: from .triton_kernels import ( gated_rmsnorm_triton, seqcond_step_triton, TRITON_AVAILABLE, ) except ImportError: gated_rmsnorm_triton = None TRITON_AVAILABLE = False seqcond_step_triton = None # --------------------------------------------------------------------------- # Normalisation layers # --------------------------------------------------------------------------- class RMSNorm(nn.Module): def __init__(self, hidden_size: int, epsilon: float = 1e-5): super().__init__() self.epsilon = epsilon self.scale = nn.Parameter(torch.ones(hidden_size)) def forward(self, x: torch.Tensor) -> torch.Tensor: orig = x.dtype x = x.float() x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.epsilon) return (x * self.scale.float()).to(orig) class GatedRMSNorm(nn.Module): """RMSNorm with SiLU gating: rmsnorm(x * silu(residual)).""" def __init__(self, hidden_size: int, epsilon: float = 1e-6): super().__init__() self.epsilon = epsilon self.weight = nn.Parameter(torch.ones(hidden_size)) def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: orig = x.dtype x = x.float() * F.silu(residual.float()) x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.epsilon) return (x * self.weight.float()).to(orig) # --------------------------------------------------------------------------- # Rotary Position Embedding # --------------------------------------------------------------------------- def precompute_freqs(maxlen: int, head_dim: int) -> Tuple[torch.Tensor, torch.Tensor]: half_d = head_dim // 2 pos = np.arange(maxlen)[:, None] dim = np.arange(half_d)[None, :] angles = pos * (1.0 / (10000 ** (dim / half_d))) cos = torch.from_numpy(np.cos(angles).astype(np.float32)) sin = torch.from_numpy(np.sin(angles).astype(np.float32)) return cos, sin def apply_rope(tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: dim = tensor.shape[-1] // 2 cos = cos[..., :dim] sin = sin[..., :dim] x1, x2 = tensor[..., :dim], tensor[..., dim:] return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).view(tensor.shape) # --------------------------------------------------------------------------- # Transformer decoder block (GQA + RoPE) # --------------------------------------------------------------------------- class RotarySelfAttention(nn.Module): def __init__( self, d_model: int, num_heads: int, num_kv_heads: Optional[int] = None, dropout: float = 0.0, qk_norm: bool = False, qk_norm_eps: float = 1e-6, ): super().__init__() self.d_model = d_model self.num_heads = num_heads self._num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads self.num_groups = num_heads // self._num_kv_heads self.head_dim = d_model // num_heads self.dropout = dropout self.qk_norm = qk_norm self.qk_norm_eps = qk_norm_eps self.q_proj = nn.Linear(d_model, d_model, bias=False) self.k_proj = nn.Linear(d_model, self._num_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(d_model, self._num_kv_heads * self.head_dim, bias=False) self.out_proj = nn.Linear(d_model, d_model, bias=False) def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor: if self.num_groups == 1: return x b, l = x.shape[:2] extra = x.shape[2:] x = x.view(b, l, self._num_kv_heads, 1, *extra[1:]) x = x.expand(b, l, self._num_kv_heads, self.num_groups, *extra[1:]) return x.reshape(b, l, self.num_heads, *extra[1:]) def forward( self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, mask: Optional[torch.Tensor] = None, return_state: bool = False, ): b, l = x.shape[0], x.shape[1] q = self.q_proj(x).reshape(b, l, self.num_heads, self.head_dim) k = self.k_proj(x).reshape(b, l, self._num_kv_heads, self.head_dim) v = self.v_proj(x).reshape(b, l, self._num_kv_heads, self.head_dim) q = apply_rope(q, cos, sin) cos_kv = cos[:, :, : self._num_kv_heads, :] if self._num_kv_heads < self.num_heads else cos sin_kv = sin[:, :, : self._num_kv_heads, :] if self._num_kv_heads < self.num_heads else sin k = apply_rope(k, cos_kv, sin_kv) if self.qk_norm: q_f = q.float(); k_f = k.float() q = (q_f * torch.rsqrt(q_f.pow(2).mean(-1, keepdim=True) + self.qk_norm_eps)).to(q.dtype) k = (k_f * torch.rsqrt(k_f.pow(2).mean(-1, keepdim=True) + self.qk_norm_eps)).to(k.dtype) k_cache = k; v_cache = v k = self._repeat_kv(k); v = self._repeat_kv(v) scale = 1.0 / math.sqrt(self.head_dim) scores = torch.einsum("blhd,bmhd->bhlm", q, k) * scale causal = torch.tril(torch.ones(l, l, dtype=torch.bool, device=x.device)).unsqueeze(0).unsqueeze(0) scores = torch.where(causal, scores, torch.full_like(scores, -1e4)) attn = F.softmax(scores.float(), dim=-1).to(v.dtype) if self.dropout > 0 and self.training: attn = F.dropout(attn, p=self.dropout) out = torch.einsum("bhql,blhd->bqhd", attn, v).reshape(b, l, self.d_model).to(x.dtype) if return_state: return self.out_proj(out), (k_cache, v_cache) return self.out_proj(out) def step( self, x_t: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor], pos: torch.Tensor, cos_t: torch.Tensor, sin_t: torch.Tensor, seq_len: Optional[int] = None, ) -> Tuple[torch.Tensor, Tuple]: b = x_t.shape[0] q = self.q_proj(x_t).reshape(b, 1, self.num_heads, self.head_dim) k_new = self.k_proj(x_t).reshape(b, 1, self._num_kv_heads, self.head_dim) v_new = self.v_proj(x_t).reshape(b, 1, self._num_kv_heads, self.head_dim) q = apply_rope(q, cos_t, sin_t) cos_kv = cos_t[:, :, : self._num_kv_heads, :] if self._num_kv_heads < self.num_heads else cos_t sin_kv = sin_t[:, :, : self._num_kv_heads, :] if self._num_kv_heads < self.num_heads else sin_t k_new = apply_rope(k_new, cos_kv, sin_kv) if self.qk_norm: q_f = q.float(); k_f = k_new.float() q = (q_f * torch.rsqrt(q_f.pow(2).mean(-1, keepdim=True) + self.qk_norm_eps)).to(q.dtype) k_new = (k_f * torch.rsqrt(k_f.pow(2).mean(-1, keepdim=True) + self.qk_norm_eps)).to(k_new.dtype) k_cache, v_cache = kv_cache pos_idx = pos.long().view(b, 1, 1, 1).expand(-1, 1, k_new.size(2), k_new.size(3)) k_cache.scatter_(1, pos_idx, k_new.to(k_cache.dtype)) v_cache.scatter_(1, pos_idx, v_new.to(v_cache.dtype)) if seq_len is not None: k_slice, v_slice = k_cache[:, :seq_len], v_cache[:, :seq_len]; L = seq_len else: k_slice, v_slice = k_cache, v_cache; L = k_cache.shape[1] k_r = self._repeat_kv(k_slice); v_r = self._repeat_kv(v_slice) mask = torch.arange(L, device=k_cache.device).view(1, 1, 1, L) > pos.long().view(b, 1, 1, 1) scale = 1.0 / math.sqrt(self.head_dim) scores = torch.einsum("bqhd,bkhd->bhqk", q, k_r) * scale scores = scores.masked_fill(mask, float("-inf")) attn = F.softmax(scores.float(), dim=-1).to(v_r.dtype) out = torch.einsum("bhqk,bkhd->bqhd", attn, v_r).reshape(b, self.d_model).to(x_t.dtype) return self.out_proj(out), (k_cache, v_cache) class TransformerDecoderBlock(nn.Module): def __init__( self, d_model: int, num_heads: int, d_ff: int, num_kv_heads: Optional[int] = None, dropout: float = 0.0, norm_eps: float = 1e-6, qk_norm: bool = False, qk_norm_eps: float = 1e-6, ): super().__init__() self.norm1 = RMSNorm(d_model, epsilon=norm_eps) self.attn = RotarySelfAttention(d_model, num_heads, num_kv_heads, dropout, qk_norm, qk_norm_eps) self.norm2 = RMSNorm(d_model, epsilon=norm_eps) self.ff_in = nn.Linear(d_model, 2 * d_ff, bias=True) self.ff_out = nn.Linear(d_ff, d_model, bias=True) self.dropout = dropout def forward(self, x, cos, sin, mask=None, return_state=False): y = self.norm1(x) if return_state: y, kv = self.attn(y, cos=cos, sin=sin, mask=mask, return_state=True) else: y = self.attn(y, cos=cos, sin=sin, mask=mask) if self.dropout > 0 and self.training: y = F.dropout(y, p=self.dropout) x = x + y y = self.norm2(x) u, v = self.ff_in(y).chunk(2, dim=-1) y = self.ff_out(F.silu(v) * u) if self.dropout > 0 and self.training: y = F.dropout(y, p=self.dropout) out = x + y return (out, kv) if return_state else out def step(self, x_t, kv_cache, pos, cos_t, sin_t, seq_len=None): y = self.norm1(x_t) y, new_kv = self.attn.step(y, kv_cache, pos, cos_t, sin_t, seq_len=seq_len) x_t = x_t + y y = self.norm2(x_t) u, v = self.ff_in(y).chunk(2, dim=-1) return x_t + self.ff_out(F.silu(v) * u), new_kv # --------------------------------------------------------------------------- # SeqCond attention block # --------------------------------------------------------------------------- class SeqCondAttention(nn.Module): def __init__( self, d_model: int, num_heads: int = 12, num_query_heads: int = 6, num_anchor_heads: int = 0, num_thetas: int = 1, conv_kernel_size: int = 4, expand_factor: int = 1, out_expand_factor: int = 3, dropout: float = 0.0, maxlen: Optional[int] = None, **kwargs, ): super().__init__() assert num_heads % num_query_heads == 0 self.d_model = d_model self.K = num_heads self.K_q = num_query_heads self.n_rep = num_heads // num_query_heads self.M = num_thetas self.num_decay_heads = num_heads - num_anchor_heads self.num_anchor_heads = num_anchor_heads self.conv_kernel_size = conv_kernel_size self.dropout_rate = dropout self.maxlen = maxlen d_inner = int(d_model * expand_factor) self.H = max(1, d_inner // (self.K * self.M)) self.dim_memory = self.K * self.H self.dim_query_head = self.H * self.M * 2 self.dim_query_total = self.K_q * self.dim_query_head self.dim_expand = self.H * out_expand_factor self.dim_swiglu_head = self.dim_expand * 2 self.dim_swiglu_total = self.K * self.dim_swiglu_head self.dim_mem_total = self.dim_memory + self.K self.dim_conv_total = self.dim_mem_total + self.dim_query_total self.in_proj = nn.Linear(d_model, self.dim_conv_total, bias=False) self.conv_weight = nn.Parameter(torch.empty(self.dim_conv_total, 1, conv_kernel_size)) nn.init.kaiming_normal_(self.conv_weight) # Cached buffers (computed lazily) self.register_buffer("_conv_kernel_t", None) self.register_buffer("_theta_cached", None) self.register_buffer("_w_int_cached", None) self.register_buffer("_decay_slopes_cached", None) self.register_buffer("_anchor_slopes_cached", None) self.register_buffer("_phase_scale_b", None) self.register_buffer("_score_scale_b", None) self.register_buffer("_score_bias_b", None) self._triton_out_re_buffer = None self._triton_out_im_buffer = None self._triton_norm_buffer = None if self.M == 1: init_theta = np.geomspace(0.001, 3.0, self.K).reshape(1, 1, self.K, 1, 1) init_theta = np.tile(init_theta, (1, 1, 1, self.H, 1)) x = np.clip((init_theta - 0.001) / 2.999, 1e-4, 1 - 1e-4) self.theta_raw = nn.Parameter(torch.from_numpy((np.log(x) - np.log(1 - x)).astype(np.float32))) self.w_int_raw = nn.Parameter(torch.zeros(1, 1, self.K_q, self.n_rep, self.H, 1)) else: init_vals = np.geomspace(0.001, 3.0, self.M).reshape(1, 1, 1, 1, self.M) init_vals = np.tile(init_vals, (1, 1, self.K, self.H, 1)) self.theta_d_raw = nn.Parameter(torch.from_numpy(np.log(np.exp(init_vals) - 1.0 + 1e-4).astype(np.float32))) self.w_int_raw = nn.Parameter(torch.zeros(1, 1, self.K_q, self.n_rep, self.H, self.M)) if self.num_decay_heads > 0: self.decay_slopes = nn.Parameter( torch.from_numpy(np.log(np.exp(np.geomspace(0.001, 0.1, self.num_decay_heads)) - 1).astype(np.float32)) ) if self.num_anchor_heads > 0: self.anchor_slopes = nn.Parameter( torch.from_numpy(np.log(np.exp(np.geomspace(0.01, 0.1, self.num_anchor_heads)) - 1).astype(np.float32)) ) self.score_scale = nn.Parameter(torch.ones(self.K)) self.score_bias = nn.Parameter(torch.zeros(self.K)) self.phase_scale = nn.Parameter(torch.ones(self.K)) self.gate_proj = nn.Linear(d_model, self.K * 2 * self.H, bias=False) self.gated_norm = GatedRMSNorm(self.K * 2 * self.H) self.W_readout = nn.Parameter(torch.empty(self.K, 2 * self.H, self.dim_swiglu_head)) nn.init.xavier_uniform_(self.W_readout) self.out_proj = nn.Linear(self.dim_swiglu_total // 2, d_model, bias=False) def forward(self, x: torch.Tensor, mask=None, return_state: bool = False): B, L, D = x.shape z_conv = self.in_proj(x) z_conv_t = F.pad(z_conv.transpose(1, 2), (self.conv_kernel_size - 1, 0)) z_conv = F.silu(F.conv1d(z_conv_t, self.conv_weight, groups=self.dim_conv_total).transpose(1, 2)) z_mem = z_conv[..., : self.dim_mem_total] q_raw = z_conv[..., self.dim_mem_total :] k_val = z_mem[..., : self.dim_memory].reshape(B, L, self.K, self.H) s_raw = z_mem[..., self.dim_memory :] q_raw = q_raw.reshape(B, L, self.K_q, 1, self.H, self.M, 2) q_re, q_im = q_raw[..., 0], q_raw[..., 1] if self.M == 1: theta = 0.001 + 2.999 * torch.sigmoid(self.theta_raw) else: theta_d = F.softplus(self.theta_d_raw) + 1e-4 theta_accum = torch.cumsum(theta_d, dim=-1) theta = 0.001 + (theta_accum / theta_accum[..., -1:]) * 2.999 w_int = torch.exp(self.w_int_raw) w_int = w_int / (w_int.sum(dim=-1, keepdim=True) + 1e-6) pos = torch.arange(L, dtype=torch.float32, device=x.device) log_w_list = [] if self.num_decay_heads > 0: slopes = F.softplus(self.decay_slopes).view(1, 1, -1) dist = torch.clamp((self.maxlen or L) - 1 - pos, min=0.0).view(1, L, 1) log_w_list.append(-slopes * dist) if self.num_anchor_heads > 0: log_w_list.append(-F.softplus(self.anchor_slopes).view(1, 1, -1) * pos.view(1, L, 1)) log_tw = torch.cat(log_w_list, dim=2) if log_w_list else torch.zeros(1, L, self.K, device=x.device) score_raw = self.score_scale.view(1, 1, -1) * s_raw.float() + self.score_bias.view(1, 1, -1) p_w = (F.softplus(score_raw) * torch.exp(log_tw)).clamp(1e-4, 5000.0) k_f32 = k_val.float().unsqueeze(-1) p_w_b = p_w.unsqueeze(-1).unsqueeze(-1) phase_scale_b = self.phase_scale.view(1, 1, self.K, 1, 1) k_scaled = k_f32 * phase_scale_b phi = (k_scaled / (1.0 + k_scaled.abs())) * theta kvw = k_f32 * p_w_b re = kvw * torch.cos(phi) im = kvw * torch.sin(phi) flat_size = self.K * self.H * self.M stack = torch.cat([p_w.float(), re.reshape(B, L, -1), im.reshape(B, L, -1)], dim=-1) cumsum = torch.cumsum(stack, dim=1) den_acc = cumsum[..., : self.K] re_acc = cumsum[..., self.K : self.K + flat_size].reshape(B, L, self.K, self.H, self.M) im_acc = cumsum[..., self.K + flat_size :].reshape(B, L, self.K, self.H, self.M) inv_den = (1.0 / torch.clamp(den_acc, min=1e-4)).unsqueeze(-1).unsqueeze(-1) state_re_g = (re_acc * inv_den).reshape(B, L, self.K_q, self.n_rep, self.H, self.M) state_im_g = (im_acc * inv_den).reshape(B, L, self.K_q, self.n_rep, self.H, self.M) scale = 1.0 / (self.H ** 0.5) match_re = ((state_re_g * q_re + state_im_g * q_im) * scale).float() match_im = ((state_im_g * q_re - state_re_g * q_im) * scale).float() out_re = ((match_re * w_int.float()).sum(dim=-1)).reshape(B, L, self.K, self.H).to(x.dtype) out_im = ((match_im * w_int.float()).sum(dim=-1)).reshape(B, L, self.K, self.H).to(x.dtype) out_complex = self.gated_norm(torch.cat([out_re, out_im], dim=-1).reshape(B, L, -1), self.gate_proj(x)) out_complex = out_complex.reshape(B, L, self.K, 2 * self.H) y_raw = torch.einsum("blkf,kfn->blkn", out_complex, self.W_readout.to(out_complex.dtype)) y_val, y_gate = y_raw.chunk(2, dim=-1) output = self.out_proj((y_val * torch.sigmoid(y_gate)).reshape(B, L, -1).to(x.dtype)) if return_state: z_pre = self.in_proj(x) buf_sz = self.conv_kernel_size - 1 conv_buf = z_pre[:, -buf_sz:] if L >= buf_sz else torch.cat([ torch.zeros(B, buf_sz - L, self.dim_conv_total, device=x.device, dtype=z_pre.dtype), z_pre], dim=1) state = ( p_w.sum(dim=1), re_acc[:, -1], im_acc[:, -1], torch.full((B,), L, dtype=torch.float32, device=x.device), conv_buf, ) return output, state return output def step(self, x_t: torch.Tensor, state: Tuple, use_triton: bool = False) -> Tuple: B, D = x_t.shape den_acc, re_acc, im_acc, pos, conv_buffer = state z_conv = self.in_proj(x_t) if self._conv_kernel_t is None or self._conv_kernel_t.device != z_conv.device: self._conv_kernel_t = self.conv_weight[:, 0, :].t().contiguous() conv_input = torch.cat([conv_buffer, z_conv.unsqueeze(1)], dim=1) z_conv_act = F.silu((conv_input * self._conv_kernel_t).sum(dim=1)) z_mem = z_conv_act[..., : self.dim_mem_total] q_raw = z_conv_act[..., self.dim_mem_total :] k_val = z_mem[..., : self.dim_memory].reshape(B, self.K, self.H) s_raw = z_mem[..., self.dim_memory :] q_raw = q_raw.reshape(B, self.K_q, 1, self.H, self.M, 2) q_re, q_im = q_raw[..., 0], q_raw[..., 1] if self._theta_cached is None: if self.M == 1: self._theta_cached = (0.001 + 2.999 * torch.sigmoid(self.theta_raw))[0, 0] else: theta_d = F.softplus(self.theta_d_raw) + 1e-4 theta_accum = torch.cumsum(theta_d, dim=-1) self._theta_cached = (0.001 + (theta_accum / theta_accum[..., -1:]) * 2.999)[0, 0] w = torch.exp(self.w_int_raw) self._w_int_cached = w / (w.sum(dim=-1, keepdim=True) + 1e-6) self._w_int_cached = self._w_int_cached[0, 0] theta = self._theta_cached w_int = self._w_int_cached if self._decay_slopes_cached is None and self.num_decay_heads > 0: self._decay_slopes_cached = F.softplus(self.decay_slopes).view(1, -1) if self._anchor_slopes_cached is None and self.num_anchor_heads > 0: self._anchor_slopes_cached = F.softplus(self.anchor_slopes).view(1, -1) if self._score_scale_b is None: self._score_scale_b = self.score_scale.view(1, -1) self._score_bias_b = self.score_bias.view(1, -1) self._phase_scale_b = self.phase_scale.view(1, self.K, 1, 1) log_w_list = [] if self.num_decay_heads > 0: dist = (self.maxlen or 2048) - 1 - pos.unsqueeze(-1) log_w_list.append(-self._decay_slopes_cached * dist.clamp(min=0.0)) if self.num_anchor_heads > 0: log_w_list.append(-self._anchor_slopes_cached * pos.unsqueeze(-1)) log_tw = torch.cat(log_w_list, dim=1) if log_w_list else torch.zeros(B, self.K, device=x_t.device) if ( use_triton and x_t.is_cuda and self.n_rep == 1 and TRITON_AVAILABLE and seqcond_step_triton is not None ): if ( self._triton_out_re_buffer is None or self._triton_out_re_buffer.shape != (B, self.K, self.H) or self._triton_out_re_buffer.device != x_t.device ): self._triton_out_re_buffer = torch.empty( B, self.K, self.H, device=x_t.device, dtype=torch.float32 ) self._triton_out_im_buffer = torch.empty_like( self._triton_out_re_buffer ) out_re, out_im = seqcond_step_triton( k_val, s_raw, q_re.squeeze(2), q_im.squeeze(2), re_acc, im_acc, den_acc, theta, w_int, self.phase_scale, self.score_scale, self.score_bias, log_tw, out_re_buffer=self._triton_out_re_buffer, out_im_buffer=self._triton_out_im_buffer, ) out_complex = torch.cat([out_re, out_im], dim=-1) else: score_raw = self._score_scale_b * s_raw.float() + self._score_bias_b p_w = (F.softplus(score_raw) * torch.exp(log_tw)).clamp(1e-4, 5000.0) k_f32 = k_val.float().unsqueeze(-1) k_scaled = k_f32 * self._phase_scale_b phi = (k_scaled / (1.0 + k_scaled.abs())) * theta kvw = k_f32 * p_w.unsqueeze(-1).unsqueeze(-1) re = kvw * torch.cos(phi) im = kvw * torch.sin(phi) den_acc.add_(p_w); re_acc.add_(re); im_acc.add_(im) inv_den = (1.0 / torch.clamp(den_acc, min=1e-4)).unsqueeze(-1).unsqueeze(-1) state_re_g = (re_acc * inv_den).reshape(B, self.K_q, self.n_rep, self.H, self.M) state_im_g = (im_acc * inv_den).reshape(B, self.K_q, self.n_rep, self.H, self.M) scale = 1.0 / (self.H ** 0.5) match_re = ((state_re_g * q_re + state_im_g * q_im) * scale).float() match_im = ((state_im_g * q_re - state_re_g * q_im) * scale).float() out_re = ((match_re * w_int.float()).sum(-1)).reshape(B, self.K, self.H).to(x_t.dtype) out_im = ((match_im * w_int.float()).sum(-1)).reshape(B, self.K, self.H).to(x_t.dtype) out_complex = torch.cat([out_re, out_im], dim=-1) out_complex = out_complex.reshape(B, self.K, 2 * self.H) out_complex_flat = out_complex.reshape(B, -1) gate_for_norm = self.gate_proj(x_t) if use_triton and x_t.is_cuda and gated_rmsnorm_triton is not None: if ( self._triton_norm_buffer is None or self._triton_norm_buffer.shape != out_complex_flat.shape or self._triton_norm_buffer.device != x_t.device ): self._triton_norm_buffer = torch.empty( out_complex_flat.shape, device=x_t.device, dtype=torch.float32, ) out_flat = gated_rmsnorm_triton( out_complex_flat, gate_for_norm, self.gated_norm.weight, self.gated_norm.epsilon, out_buffer=self._triton_norm_buffer, ) else: out_flat = self.gated_norm(out_complex_flat, gate_for_norm) out_complex = out_flat.to(x_t.dtype).reshape(B, self.K, 2 * self.H) y_raw = torch.einsum("bkf,kfn->bkn", out_complex, self.W_readout.to(out_complex.dtype)) y_val, y_gate = y_raw.chunk(2, dim=-1) out = self.out_proj((y_val * torch.sigmoid(y_gate)).reshape(B, -1).to(x_t.dtype)) pos.add_(1).clamp_(max=(self.maxlen or 2048) - 1) if self.conv_kernel_size > 1: if self.conv_kernel_size > 2: conv_buffer[:, :-1, :].copy_(conv_buffer[:, 1:, :].clone()) conv_buffer[:, -1, :].copy_(z_conv) return out, (den_acc, re_acc, im_acc, pos, conv_buffer) class SeqCondBlock(nn.Module): def __init__(self, d_model: int, norm_eps: float = 1e-6, **kwargs): super().__init__() self.norm = RMSNorm(d_model, epsilon=norm_eps) self.attn = SeqCondAttention(d_model=d_model, **kwargs) def forward(self, x, mask=None, return_state=False): if return_state: out, state = self.attn(self.norm(x), mask=mask, return_state=True) return x + out, state return x + self.attn(self.norm(x), mask=mask) def step(self, x_t, state, use_triton=False): out, new_state = self.attn.step(self.norm(x_t), state, use_triton=use_triton) return x_t + out, new_state # --------------------------------------------------------------------------- # Core SeqCond language model # --------------------------------------------------------------------------- class SeqCondModel(nn.Module): """Core SeqCond model (no HF wrapper). Used internally by SeqCondForCausalLM.""" def __init__(self, config: SeqCondConfig): super().__init__() self.d_model = config.d_model self.d_ff = config.d_ff self.num_layers = config.num_layers self.vocab_size = config.vocab_size self.maxlen = config.maxlen self.num_heads = config.num_heads self.num_kv_heads = config.num_kv_heads if config.num_kv_heads is not None else config.num_heads self.seqcond_ratio = config.seqcond_ratio self.embedding = nn.Embedding(config.vocab_size, config.d_model) self.use_positional_embedding = config.use_positional_embedding if config.use_positional_embedding: self.position_embedding = nn.Embedding(config.maxlen, config.d_model) head_dim = config.d_model // config.num_heads cos, sin = precompute_freqs(config.maxlen, head_dim) self.register_buffer("cos_emb", cos) self.register_buffer("sin_emb", sin) self.blocks = nn.ModuleList() self.block_types = [] for i in range(config.num_layers): if (i + 1) % (config.seqcond_ratio + 1) == 0: block = TransformerDecoderBlock( d_model=config.d_model, num_heads=config.num_heads, d_ff=config.d_ff, num_kv_heads=self.num_kv_heads, dropout=config.dropout, qk_norm=config.qk_norm, qk_norm_eps=config.qk_norm_eps, ) self.block_types.append("transformer") else: block = SeqCondBlock( d_model=config.d_model, num_heads=config.seqcond_heads, num_query_heads=config.num_query_heads, num_anchor_heads=config.num_anchor_heads, num_thetas=config.num_thetas, conv_kernel_size=config.conv_kernel_size, expand_factor=config.expand_factor, out_expand_factor=config.out_expand_factor, dropout=config.dropout, maxlen=config.maxlen, ) self.block_types.append("seqcond") self.blocks.append(block) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) if config.tie_weights: self.lm_head.weight = self.embedding.weight def forward(self, input_ids: torch.Tensor) -> torch.Tensor: B, L = input_ids.shape x = self.embedding(input_ids) if self.use_positional_embedding: x = x + self.position_embedding(torch.arange(L, device=input_ids.device)) cos = self.cos_emb[:L].unsqueeze(0).unsqueeze(2).expand(B, L, self.num_heads, -1) sin = self.sin_emb[:L].unsqueeze(0).unsqueeze(2).expand(B, L, self.num_heads, -1) for block, bt in zip(self.blocks, self.block_types): x = block(x, cos, sin) if bt == "transformer" else block(x) return self.lm_head(x) def prefill(self, input_ids: torch.Tensor, return_all_logits: bool = False): B, L = input_ids.shape device = input_ids.device x = self.embedding(input_ids) if self.use_positional_embedding: x = x + self.position_embedding(torch.arange(L, device=device)) cos = self.cos_emb[:L].unsqueeze(0).unsqueeze(2).expand(B, L, self.num_heads, -1) sin = self.sin_emb[:L].unsqueeze(0).unsqueeze(2).expand(B, L, self.num_heads, -1) states = [] for block, bt in zip(self.blocks, self.block_types): if bt == "transformer": x, kv = block(x, cos, sin, return_state=True) k, v = kv k_cache = torch.zeros(B, self.maxlen, self.num_kv_heads, self.d_model // self.num_heads, device=device, dtype=k.dtype) v_cache = torch.zeros_like(k_cache) k_cache[:, :L] = k; v_cache[:, :L] = v states.append((k_cache, v_cache)) else: x, state = block(x, return_state=True) states.append(state) logits = self.lm_head(x) if return_all_logits: return logits, states return logits[:, -1:, :], states def init_state(self, batch_size: int, device: torch.device) -> List: states = [] for block, bt in zip(self.blocks, self.block_types): if bt == "transformer": k = torch.zeros(batch_size, self.maxlen, self.num_kv_heads, self.d_model // self.num_heads, device=device) states.append((k, torch.zeros_like(k))) else: a = block.attn states.append(( torch.zeros(batch_size, a.K, device=device), torch.zeros(batch_size, a.K, a.H, a.M, device=device), torch.zeros(batch_size, a.K, a.H, a.M, device=device), torch.zeros(batch_size, device=device), torch.zeros(batch_size, a.conv_kernel_size - 1, a.dim_conv_total, device=device), )) return states def step(self, token_id: torch.Tensor, states: List, pos=None, seq_len=None, use_triton=False): B = token_id.size(0) if pos is None: for state, bt in zip(states, self.block_types): if bt == "seqcond": pos = state[3]; break if pos is None: pos = torch.zeros(B, device=token_id.device, dtype=torch.long) x = self.embedding(token_id).squeeze(1) pos = pos.clamp(max=self.maxlen - 1) if self.use_positional_embedding: x = x + torch.index_select(self.position_embedding.weight, 0, pos.long()) pos_idx = pos.long() cos_t = torch.index_select(self.cos_emb, 0, pos_idx).unsqueeze(1).unsqueeze(1).expand(B, 1, self.num_heads, -1) sin_t = torch.index_select(self.sin_emb, 0, pos_idx).unsqueeze(1).unsqueeze(1).expand(B, 1, self.num_heads, -1) new_states = [] for block, bt, state in zip(self.blocks, self.block_types, states): if bt == "transformer": x, ns = block.step(x, state, pos, cos_t, sin_t, seq_len=seq_len) else: x, ns = block.step(x, state, use_triton=use_triton) new_states.append(ns) return self.lm_head(x), new_states # --------------------------------------------------------------------------- # HuggingFace wrapper # --------------------------------------------------------------------------- class SeqCondPreTrainedModel(PreTrainedModel): config_class = SeqCondConfig base_model_prefix = "model" supports_gradient_checkpointing = False def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=0.02) class SeqCondForCausalLM(SeqCondPreTrainedModel): """ SeqCond causal language model, HuggingFace-compatible. Supports: - Standard HF forward() for training / perplexity evaluation. - Custom generate() using state-based O(1) decoding. - generate_batch() for batched generation with per-sample early stopping. - precompute() / use_cuda_graph=True for CUDA-graph-accelerated decoding. """ _CUDA_GRAPH_SEQ_LENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] def __init__(self, config: SeqCondConfig): super().__init__(config) self.model = SeqCondModel(config) self.post_init() # CUDA graph state self._cg_graphs: dict = {} self._cg_logits: dict = {} self._cg_token: Optional[torch.Tensor] = None self._cg_states: Optional[list] = None self._cg_use_triton: bool = False self._cg_ready: bool = False # True after precompute() has been called # ------------------------------------------------------------------ # CUDA graph helpers # ------------------------------------------------------------------ def _cg_get_seq_len(self, pos: int) -> int: for s in self._CUDA_GRAPH_SEQ_LENS: if s >= pos + 1: return s return self._CUDA_GRAPH_SEQ_LENS[-1] def _cg_copy_states(self, src, dst): for s, d in zip(src, dst): for st, dt in zip(s, d): dt.copy_(st) def _cg_capture(self, seq_len: int): saved = self.model.init_state(1, device=self._cg_token.device) self._cg_copy_states(self._cg_states, saved) stream = torch.cuda.Stream() stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(stream): for _ in range(3): self.model.step(self._cg_token, self._cg_states, seq_len=seq_len, use_triton=self._cg_use_triton) torch.cuda.current_stream().wait_stream(stream) self._cg_copy_states(saved, self._cg_states) graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): logits, _ = self.model.step(self._cg_token, self._cg_states, seq_len=seq_len, use_triton=self._cg_use_triton) self._cg_copy_states(saved, self._cg_states) self._cg_graphs[seq_len] = graph self._cg_logits[seq_len] = logits @torch.no_grad() def precompute(self, max_seq_len: int = 2048, use_triton: bool = False): """Pre-capture CUDA graphs up to max_seq_len. Call once after loading.""" if not torch.cuda.is_available(): return if self._cg_use_triton != use_triton: self._cg_graphs = {} self._cg_logits = {} self._cg_use_triton = use_triton device = next(self.parameters()).device self._cg_token = torch.zeros((1, 1), dtype=torch.long, device=device) self._cg_states = self.model.init_state(1, device=device) for s in self._CUDA_GRAPH_SEQ_LENS: if s > max_seq_len: break self._cg_capture(s) self._cg_ready = True print(f"Pre-captured {len(self._cg_graphs)} CUDA graphs (triton={use_triton}).") def get_input_embeddings(self): return self.model.embedding def set_input_embeddings(self, value): self.model.embedding = value def get_output_embeddings(self): return self.model.lm_head def set_output_embeddings(self, value): self.model.lm_head = value def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, **kwargs, ) -> CausalLMOutputWithPast: """ Standard forward pass (used for training / perplexity). Note: attention_mask is accepted for API compatibility but is not used in the forward pass — SeqCond is always causal. """ logits = self.model(input_ids) loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = F.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ) return CausalLMOutputWithPast(loss=loss, logits=logits) @staticmethod def _detect_triton() -> bool: try: import triton # noqa: F401 return True except ImportError: return False @torch.no_grad() def generate( self, input_ids: torch.LongTensor, max_new_tokens: int = 1024, temperature: float = 0.15, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.1, eos_token_id: Optional[int] = None, acceleration: str = "auto", use_triton: Optional[bool] = None, use_cuda_graph: Optional[bool] = None, **kwargs, ) -> torch.LongTensor: """ Autoregressive generation with state-based O(1) decoding. Args: acceleration: One of ``"auto"`` (default), ``"cuda_graph"``, ``"triton"`` (cuda_graph + triton), or ``"none"``. ``"auto"`` uses CUDA graphs when a GPU is available, and adds Triton kernels automatically if the triton package is installed. Explicit ``use_triton`` / ``use_cuda_graph`` kwargs override this. Returns the full sequence (prompt + generated tokens) as a LongTensor. """ # ------------------------------------------------------------------ # Resolve acceleration mode # ------------------------------------------------------------------ on_cuda = torch.cuda.is_available() and input_ids.device.type == "cuda" if acceleration == "auto": _use_cuda_graph = on_cuda _use_triton = on_cuda and self._detect_triton() elif acceleration == "triton": _use_cuda_graph = on_cuda _use_triton = on_cuda elif acceleration == "cuda_graph": _use_cuda_graph = on_cuda _use_triton = False else: # "none" _use_cuda_graph = False _use_triton = False # Legacy kwargs override if use_cuda_graph is not None: _use_cuda_graph = use_cuda_graph and on_cuda if use_triton is not None: _use_triton = use_triton and on_cuda # Lazy precompute on first generate() call if _use_cuda_graph and not self._cg_ready: self.precompute(max_seq_len=2048, use_triton=_use_triton) elif _use_cuda_graph and self._cg_use_triton != _use_triton: self.precompute(max_seq_len=2048, use_triton=_use_triton) use_triton = _use_triton use_cuda_graph = _use_cuda_graph if eos_token_id is None: eos_token_id = self.config.eos_token_id device = input_ids.device B = input_ids.size(0) # Prefill logits, states = self.model.prefill(input_ids) logits = logits.squeeze(1) # (B, vocab) generated = input_ids.tolist() finished = [False] * B token_buf = torch.zeros((B, 1), dtype=torch.long, device=device) seq_len = input_ids.size(1) # CUDA graph: sync prefill states into static buffer once before decode loop if use_cuda_graph and torch.cuda.is_available() and B == 1: if self._cg_token is None: self._cg_use_triton = use_triton self._cg_token = torch.zeros((1, 1), dtype=torch.long, device=device) self._cg_states = self.model.init_state(1, device=device) self._cg_copy_states(states, self._cg_states) states = self._cg_states for _ in range(max_new_tokens): # Temperature scaling if temperature > 0: ls = logits / temperature else: ls = logits.clone() # Repetition penalty if repetition_penalty != 1.0: for bi, toks in enumerate(generated): for t in set(toks): if 0 <= t < self.config.vocab_size: ls[bi, t] /= repetition_penalty # Sampling if temperature == 0: next_tokens = torch.argmax(ls, dim=-1) else: if top_k > 0: kth = torch.topk(ls, top_k, dim=-1).values[:, -1:] ls = ls.masked_fill(ls < kth, float("-inf")) if top_p < 1.0: sorted_ls, sorted_idx = torch.sort(ls, dim=-1, descending=True) cum_probs = torch.cumsum(F.softmax(sorted_ls, dim=-1), dim=-1) sorted_remove = cum_probs > top_p sorted_remove[:, 1:] = sorted_remove[:, :-1].clone() sorted_remove[:, 0] = False remove = torch.zeros_like(sorted_remove) remove.scatter_(1, sorted_idx, sorted_remove) ls = ls.masked_fill(remove, float("-inf")) probs = F.softmax(ls, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1) for bi in range(B): tok = next_tokens[bi].item() generated[bi].append(tok) if eos_token_id is not None and tok == eos_token_id: finished[bi] = True token_buf[bi, 0] = tok if all(finished): break seq_len += 1 if use_cuda_graph and torch.cuda.is_available() and B == 1: cg_sl = self._cg_get_seq_len(seq_len - 1) if cg_sl not in self._cg_graphs: self._cg_capture(cg_sl) self._cg_token.copy_(token_buf) self._cg_graphs[cg_sl].replay() logits = self._cg_logits[cg_sl] else: logits, states = self.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton) max_len = max(len(g) for g in generated) pad_id = self.config.pad_token_id or 0 out = torch.full((B, max_len), pad_id, dtype=torch.long, device=device) for bi, g in enumerate(generated): out[bi, : len(g)] = torch.tensor(g, dtype=torch.long, device=device) return out @torch.no_grad() def generate_batch( self, input_ids_list: List[torch.LongTensor], max_new_tokens: int = 1024, temperature: float = 0.7, eos_token_id: Optional[int] = None, use_triton: bool = False, ) -> List[List[int]]: """ Batched generation: each prompt is prefilled independently, then decoded in lockstep with per-sample early stopping. Args: input_ids_list: List of 1D LongTensors, one per prompt. Returns: List of generated token id lists (completion only, EOS stripped). """ if eos_token_id is None: eos_token_id = self.config.eos_token_id device = input_ids_list[0].device B = len(input_ids_list) # Per-sample prefill all_logits, all_states = [], [] for ids in input_ids_list: lg, st = self.model.prefill(ids.unsqueeze(0)) all_logits.append(lg.squeeze(1)) all_states.append(st) logits = torch.cat(all_logits, dim=0) # Stack states num_blocks = len(all_states[0]) states = [ tuple(torch.cat([s[i][j] for s in all_states], dim=0) for j in range(len(all_states[0][i]))) for i in range(num_blocks) ] generated = [[] for _ in range(B)] finished = [False] * B active_map = list(range(B)) token_buf = torch.zeros((B, 1), dtype=torch.long, device=device) seq_len = max(ids.size(0) for ids in input_ids_list) for _ in range(max_new_tokens): B_cur = len(active_map) if B_cur == 0: break if temperature == 0: next_tokens = torch.argmax(logits, dim=-1) else: probs = F.softmax(logits / temperature, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1) newly_done = set() for bi in range(B_cur): oi = active_map[bi] tok = next_tokens[bi].item() generated[oi].append(tok) if eos_token_id is not None and tok == eos_token_id: finished[oi] = True newly_done.add(bi) else: token_buf[bi, 0] = tok if all(finished): break if newly_done: keep = [bi for bi in range(B_cur) if bi not in newly_done] if not keep: break keep_idx = torch.tensor(keep, device=device) token_buf = token_buf[keep_idx].contiguous() states = [tuple(s[keep_idx].contiguous() for s in st) for st in states] active_map = [active_map[bi] for bi in keep] seq_len += 1 logits, states = self.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton) results = [] for toks in generated: if toks and toks[-1] == eos_token_id: toks = toks[:-1] results.append(toks) return results