import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal
from functools import lru_cache
from contextlib import contextmanager

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

# Project-local fused kernels (quantization, GEMMs, sparse attention, Sinkhorn).
from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn

# Module-level globals; world_size/rank and the quantization defaults are
# overwritten in Transformer.__init__ before any layer is constructed.
world_size = 1
rank = 0
block_size = 128          # FP8 per-block quantization granularity
fp4_block_size = 32       # FP4 per-block quantization granularity (along K)
w4a16_group_size = 128    # GPTQ group size for W4A16 checkpoints
default_dtype = torch.bfloat16
scale_fmt = None
scale_dtype = torch.float32
w4a16_mode = False  # set in Transformer.__init__ when args.dtype == "w4a16"


def dequantize_w4a16(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    """Auto-round / auto_gptq W4A16 packing -> BF16 weight [out, in].
    qweight: int32 [in/8, out], LSB-first 4-bit packed along input dim
    qzeros : int32 [in/g, out/8], LSB-first 4-bit packed along output dim
    scales : bf16 [in/g, out]
    """
    in_packed, out_features = qweight.shape
    in_features = in_packed * 8  # 8 nibbles per int32
    n_groups = scales.shape[0]
    device = qweight.device
    # Bit offsets of the 8 packed 4-bit values inside each int32 (LSB first).
    shifts = torch.arange(0, 32, 4, device=device, dtype=torch.int32)
    w = (qweight.unsqueeze(1) >> shifts.view(1, 8, 1)) & 0xF  # [in/8, 8, out]
    w = w.reshape(in_features, out_features).to(torch.float32)
    z = (qzeros.unsqueeze(2) >> shifts.view(1, 1, 8)) & 0xF  # [in/g, out/8, 8]
    z = z.reshape(n_groups, out_features).to(torch.float32) + 1.0  # GPTQ stores zero - 1
    s = scales.to(torch.float32)
    # Dequantize per input-group: (w - zero) * scale.
    w = w.view(n_groups, group_size, out_features)
    deq = (w - z.unsqueeze(1)) * s.unsqueeze(1)
    deq = deq.view(in_features, out_features)
    # Transpose to the [out, in] layout expected by F.linear.
    return deq.t().contiguous().to(torch.bfloat16)


@contextmanager
def set_dtype(dtype):
    """Temporarily override torch default dtype, restoring it on exit (even if an exception occurs)."""
    prev = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(prev)


@dataclass
class ModelArgs:
    """Model hyperparameters.
    Field names match the config JSON keys."""
    max_batch_size: int = 4
    max_seq_len: int = 4096
    dtype: Literal["bf16", "fp8", "w4a16"] = "fp8"
    scale_fmt: Literal[None, "ue8m0"] = "ue8m0"
    expert_dtype: Literal[None, "fp4"] = None
    scale_dtype: Literal["fp32", "fp8"] = "fp8"
    vocab_size: int = 129280
    dim: int = 4096
    moe_inter_dim: int = 4096
    n_layers: int = 7
    n_hash_layers: int = 0
    n_mtp_layers: int = 1
    n_heads: int = 64
    # moe
    n_routed_experts: int = 8
    n_shared_experts: int = 1
    n_activated_experts: int = 2
    score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "sqrtsoftplus"
    route_scale: float = 1.
    swiglu_limit: float = 0.
    # mqa
    q_lora_rank: int = 1024
    head_dim: int = 512
    rope_head_dim: int = 64
    norm_eps: float = 1e-6
    o_groups: int = 8
    o_lora_rank: int = 1024
    window_size: int = 128
    # Per-layer KV compression ratio; 0 disables compression for that layer.
    compress_ratios: Tuple[int] = (0, 0, 4, 128, 4, 128, 4, 0)
    # yarn
    compress_rope_theta: float = 40000.0
    original_seq_len: int = 0
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    # index
    index_n_heads: int = 64
    index_head_dim: int = 128
    index_topk: int = 512
    # hc
    hc_mult: int = 4
    hc_sinkhorn_iters: int = 20
    hc_eps: float = 1e-6


class ParallelEmbedding(nn.Module):
    """Embedding sharded along the vocab dimension. Each rank holds vocab_size // world_size rows.
    Out-of-range indices are zero-masked before all_reduce to combine partial embeddings."""
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
        self.part_vocab_size = (vocab_size // world_size)
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if world_size > 1:
            # Tokens outside this rank's shard are looked up at index 0, then zeroed.
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            x = x - self.vocab_start_idx  # rebinds x, so caller's tensor is not mutated
            x[mask] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:
            y[mask] = 0
            dist.all_reduce(y)  # sum partial embeddings across ranks
        return y


def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Dispatches to fp4_gemm / fp8_gemm / F.linear based on weight dtype.
    For quantized weights, x is first quantized to FP8 via act_quant."""
    assert bias is None
    if weight.dtype == torch.float4_e2m1fn_x2:
        x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
        return fp4_gemm(x, s, weight, weight.scale, scale_dtype)
    elif weight.dtype == torch.float8_e4m3fn:
        x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
        return fp8_gemm(x, s, weight, weight.scale, scale_dtype)
    else:
        return F.linear(x, weight)


class Linear(nn.Module):
    """Linear layer supporting BF16, FP8, and FP4 weight formats with per-block scaling."""
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # In a W4A16 build every Linear becomes W4A16 regardless of the dtype the
        # original FP8/FP4 model wanted. The non-quant special cases (RMSNorm,
        # embed, attn_sink, etc.) are NOT instances of `Linear`, so they are
        # untouched.
        if w4a16_mode:
            dtype = "w4a16"
        else:
            dtype = dtype or default_dtype
        self.is_w4a16 = (dtype == "w4a16")
        if self.is_w4a16:
            assert in_features % 8 == 0 and in_features % w4a16_group_size == 0
            assert out_features % 8 == 0
            self.group_size = w4a16_group_size
            # GPTQ-packed tensors; see dequantize_w4a16 for the exact layout.
            self.qweight = nn.Parameter(
                torch.empty(in_features // 8, out_features, dtype=torch.int32),
                requires_grad=False,
            )
            self.qzeros = nn.Parameter(
                torch.empty(in_features // self.group_size, out_features // 8, dtype=torch.int32),
                requires_grad=False,
            )
            self.scales = nn.Parameter(
                torch.empty(in_features // self.group_size, out_features, dtype=torch.bfloat16),
                requires_grad=False,
            )
            self.register_parameter("weight", None)
            self.register_parameter("scale", None)
        elif dtype == torch.float4_e2m1fn_x2:
            # FP4: weight is [out, in//2] in float4_e2m1fn_x2, logically [out, in] in fp4
            # Scale is [out, in//32] in float8_e8m0fnu (1 scale per 32 fp4 elements along K)
            self.weight = nn.Parameter(torch.empty(out_features, in_features // 2, dtype=torch.float4_e2m1fn_x2))
            scale_out_features = out_features
            scale_in_features = in_features // fp4_block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
        elif dtype == torch.float8_e4m3fn:
            self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
            # One scale per block_size x block_size tile (ceil division covers ragged edges).
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
        else:
            self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
            self.register_parameter("scale", None)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def init_woq(self, QuantLinear):
        """Create a QuantLinear from loaded GPTQ parameters."""
        if not self.is_w4a16:
            return
        # Marlin requires out_features % 64 == 0; fall back to manual dequant
        if self.out_features % 64 != 0:
            self._woq = None
            return
        dev = self.qweight.device
        layer = QuantLinear(
            bits=4,
            group_size=self.group_size,
            in_features=self.in_features,
            out_features=self.out_features,
            bias=False,
            desc_act=False,
            sym=True,
            register_buffers=True,
        )
        layer = layer.to(dev)
        layer.qweight.copy_(self.qweight.data)
        layer.qzeros.copy_(self.qzeros.data)
        layer.scales.copy_(self.scales.to(layer.scales.dtype).data)
        # desc_act=False -> trivial (contiguous) group index.
        layer.g_idx.copy_(torch.arange(self.in_features, dtype=torch.int32, device=dev) // self.group_size)
        layer.post_init()
        self._woq = layer
        # Free original parameters to save memory
        self.qweight = None
        self.qzeros = None
        self.scales = None

    def get_weight(self) -> torch.Tensor:
        """Return the dequantised BF16 weight [out, in].
        For non-W4A16 modes returns ``self.weight`` unchanged. Used only for wo_a einsum path."""
        if self.is_w4a16:
            if self._woq is not None:
                return dequantize_w4a16(self._woq.qweight, self._woq.qzeros, self._woq.scales, self.group_size)
            return dequantize_w4a16(self.qweight, self.qzeros, self.scales, self.group_size)
        return self.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_w4a16:
            if hasattr(self, '_woq') and self._woq is not None:
                y = self._woq(x.to(torch.bfloat16))
            else:
                # No fused kernel available: dequantize on the fly (slow path).
                w = dequantize_w4a16(self.qweight, self.qzeros, self.scales, self.group_size)
                y = F.linear(x.to(w.dtype), w)
            if self.bias is not None:
                y = y + self.bias
            return y.type_as(x)
        return linear(x, self.weight, self.bias)


class ColumnParallelLinear(Linear):
    """Shards output dim across TP ranks.
    No all-reduce needed on output."""
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_w4a16:
            return Linear.forward(self, x)
        return linear(x, self.weight, self.bias)


class RowParallelLinear(Linear):
    """Shards input dim across TP ranks. All-reduce on output to sum partial results."""
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
        self.part_in_features = in_features // world_size
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_w4a16:
            if hasattr(self, '_woq') and self._woq is not None:
                y = self._woq(x.to(torch.bfloat16))
            else:
                w = dequantize_w4a16(self.qweight, self.qzeros, self.scales, self.group_size)
                y = F.linear(x.to(w.dtype), w)
        else:
            y = linear(x, self.weight, None)
        if world_size > 1:
            # Reduce in fp32 for numerical stability, cast back at the end.
            y = y.float()
            dist.all_reduce(y)
        if self.bias is not None:
            y += self.bias
        return y.type_as(x)


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # rmsnorm in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient.
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        dtype = x.dtype
        x = x.float()  # normalize in fp32, return in the input dtype
        var = x.square().mean(-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return (self.weight * x).to(dtype)


@lru_cache(2)
def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor:
    """Precomputes complex exponentials for rotary embeddings with YaRN scaling.
    When original_seq_len > 0, applies frequency interpolation with a smooth linear
    ramp between beta_fast and beta_slow correction ranges.
    NOTE(review): result is cached by lru_cache and created on the default device — confirm placement is handled by the caller."""
    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        # Dimension index at which the rotary period equals max_seq_len / num_rotations.
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim-1)

    def linear_ramp_factor(min, max, dim):
        if min == max:
            max += 0.001  # avoid division by zero
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if original_seq_len > 0:
        # YaRN: interpolate low-frequency dims, keep high-frequency dims intact.
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth
    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor:
    """Applies rotary positional embeddings in-place.
    Uses conjugate for inverse (de-rotation)."""
    y = x  # keep a handle to the original storage for the in-place copy below
    x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2)))
    if inverse:
        freqs_cis = freqs_cis.conj()
    # Broadcast freqs over batch (and heads in the 4-D case).
    if x.ndim == 3:
        freqs_cis = freqs_cis.view(1, x.size(1), x.size(-1))
    else:
        freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    x = torch.view_as_real(x * freqs_cis).flatten(-2)
    y.copy_(x)  # write the result back into the caller's tensor
    return y


def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Applies randomized Hadamard rotation to spread information across dims before FP8 quant."""
    assert x.dtype == torch.bfloat16
    from fast_hadamard_transform import hadamard_transform
    return hadamard_transform(x, scale=x.size(-1) ** -0.5)


@lru_cache(1)
def get_window_topk_idxs(window_size: int, bsz: int, seqlen: int, start_pos: int):
    """Positions of the sliding-window KV entries for each query; -1 marks invalid slots."""
    if start_pos >= window_size - 1:
        # Decode with a full ring buffer: indices of the window in cache order.
        start_pos %= window_size
        matrix = torch.cat([torch.arange(start_pos + 1, window_size), torch.arange(0, start_pos + 1)], dim=0)
    elif start_pos > 0:
        # Decode before the window filled up: pad the tail with -1.
        matrix = F.pad(torch.arange(start_pos + 1), (0, window_size - start_pos - 1), value=-1)
    else:
        # Prefill: per-query causal window of up to window_size positions.
        base = torch.arange(seqlen).unsqueeze(1)
        matrix = (base - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size))
        matrix = torch.where(matrix > base, -1, matrix)
    return matrix.unsqueeze(0).expand(bsz, -1, -1)


@lru_cache(2)
def get_compress_topk_idxs(ratio: int, bsz: int, seqlen: int, start_pos: int, offset: int):
    """Positions of all compressed-KV entries (dense selection, no learned index)."""
    if start_pos > 0:
        matrix = torch.arange(0, (start_pos + 1) // ratio) + offset
    else:
        matrix = torch.arange(seqlen // ratio).repeat(seqlen, 1)
        # A query may only attend to compressed blocks that end at or before it.
        mask = matrix >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
        matrix = torch.where(mask, -1, matrix + offset)
    return matrix.unsqueeze(0).expand(bsz, -1, -1)


class Compressor(nn.Module):
    """Compresses KV cache via learned gated pooling over `compress_ratio` consecutive tokens.
    When overlap=True (ratio==4), uses overlapping windows for smoother compression boundaries."""
    def __init__(self, args: ModelArgs, compress_ratio: int = 4, head_dim: int = 512, rotate: bool = False):
        super().__init__()
        self.dim = args.dim
        self.head_dim = head_dim
        self.rope_head_dim = args.rope_head_dim
        self.nope_head_dim = head_dim - args.rope_head_dim
        self.compress_ratio = compress_ratio
        self.overlap = compress_ratio == 4
        self.rotate = rotate
        # coff doubles feature width when overlapping windows are used (bool -> 0/1).
        coff = 1 + self.overlap
        # Learned additive position embedding over the positions within a window.
        self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
        # wkv and wgate in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient.
        # When overlap, the first half of dims is for overlapping compression, second half for normal.
        self.wkv = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
        self.wgate = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
        self.norm = RMSNorm(self.head_dim, args.norm_eps)
        self.kv_cache: torch.Tensor = None  # assigned lazily from Attention.kv_cache
        # State buffers for decode-phase incremental compression.
        # With overlap: state[:, :ratio] = overlapping window, state[:, ratio:] = current window.
        self.register_buffer("kv_state", torch.zeros(args.max_batch_size, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), persistent=False)
        # score_state starts at -inf so empty slots get zero softmax weight.
        self.register_buffer("score_state", torch.full((args.max_batch_size, coff * compress_ratio, coff * self.head_dim), float("-inf"), dtype=torch.float32), persistent=False)
        self.freqs_cis: torch.Tensor = None

    def overlap_transform(self, tensor: torch.Tensor, value=0):
        # tensor: [b,s,r,2d]
        # Rearranges [overlap-half | normal-half] features into [b,s,2r,d]:
        # rows [ratio:] take the current window (second feature half), rows [:ratio]
        # take the previous window's overlap half, padded with `value` at s=0.
        b, s, _, _ = tensor.size()
        ratio, d = self.compress_ratio, self.head_dim
        new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
        new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
        new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
        return new_tensor

    def forward(self, x: torch.Tensor, start_pos: int):
        assert self.kv_cache is not None
        bsz, seqlen, _ = x.size()
        ratio, overlap, d, rd = self.compress_ratio, self.overlap, self.head_dim, self.rope_head_dim
        dtype = x.dtype
        # compression need fp32
        x = x.float()
        kv = self.wkv(x)
        score = self.wgate(x)
        if start_pos == 0:
            # Prefill: compress whole windows at once; the ragged tail is parked
            # in the state buffers for the decode phase to finish.
            should_compress = seqlen >= ratio
            remainder = seqlen % ratio
            cutoff = seqlen - remainder
            offset = ratio if overlap else 0
            if overlap and cutoff >= ratio:
                # Seed the overlap half of the state with the last full window.
                self.kv_state[:bsz, :ratio] = kv[:, cutoff-ratio : cutoff]
                self.score_state[:bsz, :ratio] = score[:, cutoff-ratio : cutoff] + self.ape
            if remainder > 0:
                kv, self.kv_state[:bsz, offset : offset+remainder] = kv.split([cutoff, remainder], dim=1)
                self.score_state[:bsz, offset : offset+remainder] = score[:, cutoff:] + self.ape[:remainder]
                score = score[:, :cutoff]
            kv = kv.unflatten(1, (-1, ratio))
            score = score.unflatten(1, (-1, ratio)) + self.ape
            if overlap:
                kv = self.overlap_transform(kv, 0)
                score = self.overlap_transform(score, float("-inf"))
            # Gated pooling: softmax over the window dimension.
            kv = (kv * score.softmax(dim=2)).sum(dim=2)
        else:
            # Decode: accumulate one token per step; emit a compressed entry
            # every `ratio` steps.
            should_compress = (start_pos + 1) % self.compress_ratio == 0
            score += self.ape[start_pos % ratio]
            if overlap:
                self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1)
                self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1)
                if should_compress:
                    # Combine overlap half (previous window) with the normal half
                    # (current window), then pool.
                    kv_state = torch.cat([self.kv_state[:bsz, :ratio, :d], self.kv_state[:bsz, ratio:, d:]], dim=1)
                    score_state = torch.cat([self.score_state[:bsz, :ratio, :d], self.score_state[:bsz, ratio:, d:]], dim=1)
                    kv = (kv_state * score_state.softmax(dim=1)).sum(dim=1, keepdim=True)
                    # Shift: current window becomes the next step's overlap window.
                    self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:]
                    self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:]
            else:
                self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1)
                self.score_state[:bsz, start_pos % ratio] = score.squeeze(1)
                if should_compress:
                    kv = (self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True)
        if not should_compress:
            return
        kv = self.norm(kv.to(dtype))
        if start_pos == 0:
            # One rotary position per compressed window (its first token).
            freqs_cis = self.freqs_cis[:cutoff:ratio]
        else:
            freqs_cis = self.freqs_cis[start_pos + 1 - self.compress_ratio].unsqueeze(0)
        apply_rotary_emb(kv[..., -rd:], freqs_cis)
        if self.rotate:
            kv = rotate_activation(kv)
            fp4_act_quant(kv, fp4_block_size, True)
        else:
            act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
        if start_pos == 0:
            self.kv_cache[:bsz, :seqlen // ratio] = kv
        else:
            self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1)
        return kv


class Indexer(torch.nn.Module):
    """Selects top-k compressed KV positions for sparse attention via learned scoring.
    Has its own Compressor (with Hadamard rotation) to build compressed KV for scoring."""
    def __init__(self, args: ModelArgs, compress_ratio: int = 4):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.index_n_heads
        self.n_local_heads = args.index_n_heads // world_size
        self.head_dim = args.index_head_dim
        self.rope_head_dim = args.rope_head_dim
        self.index_topk = args.index_topk
        self.q_lora_rank = args.q_lora_rank
        self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
        self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
        self.softmax_scale = self.head_dim ** -0.5
        self.compress_ratio = compress_ratio
        # rotate=True: indexer's compressed KV goes through Hadamard + FP4 simulation.
        self.compressor = Compressor(args, compress_ratio, self.head_dim, True)
        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim), persistent=False)
        self.freqs_cis = None

    def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int):
        bsz, seqlen, _ = x.size()
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        ratio = self.compress_ratio
        rd = self.rope_head_dim
        end_pos = start_pos + seqlen
        if self.compressor.kv_cache is None:
            # Lazy wiring: the compressor writes directly into this module's cache.
            self.compressor.kv_cache = self.kv_cache
            self.compressor.freqs_cis = self.freqs_cis
        q = self.wq_b(qr)
        q = q.unflatten(-1, (self.n_local_heads, self.head_dim))
        apply_rotary_emb(q[..., -rd:], freqs_cis)
        q = rotate_activation(q)
        # use fp4 simulation for q and kv in indexer
        fp4_act_quant(q, fp4_block_size, True)
        self.compressor(x, start_pos)
        weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5)
        # We performed QAT here, kv could also use fp8 format, though current implementation uses bf16
        index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio])
        # ReLU then head-weighted sum -> per-(query, compressed-position) score.
        index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
        if world_size > 1:
            dist.all_reduce(index_score)
        if start_pos == 0:
            # Causal mask over compressed blocks during prefill.
            mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
            index_score += torch.where(mask, float("-inf"), 0)
        topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1]
        if start_pos == 0:
            # Re-mask: topk may have picked -inf (future) slots; mark them -1.
            mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
            topk_idxs = torch.where(mask, -1, topk_idxs + offset)
        else:
            topk_idxs += offset
        return topk_idxs


class Attention(nn.Module):
    """Multi-head Latent Attention (MLA) with sliding window + optional KV compression.
    Uses low-rank Q projection (wq_a -> q_norm -> wq_b) and grouped low-rank O projection."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.o_lora_rank = args.o_lora_rank
        self.head_dim = args.head_dim
        self.rope_head_dim = args.rope_head_dim
        self.nope_head_dim = args.head_dim - args.rope_head_dim
        self.n_groups = args.o_groups
        self.n_local_groups = self.n_groups // world_size
        self.window_size = args.window_size
        self.compress_ratio = args.compress_ratios[layer_id]
        self.eps = args.norm_eps
        self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
        self.wq_a = Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
        self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
        self.wkv = Linear(self.dim, self.head_dim)
        self.kv_norm = RMSNorm(self.head_dim, self.eps)
        self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups, self.n_groups * args.o_lora_rank, dtype=torch.bfloat16)
        self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)
        self.softmax_scale = self.head_dim ** -0.5
        if self.compress_ratio:
            self.compressor = Compressor(args, self.compress_ratio, self.head_dim)
            # Learned indexer only for ratio 4; other ratios select all blocks densely.
            if self.compress_ratio == 4:
                self.indexer = Indexer(args, self.compress_ratio)
            else:
                self.indexer = None
        # Cache layout: [0, window) sliding window ring buffer, [window, ...) compressed KV.
        kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0)
        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim), persistent=False)
        if self.compress_ratio:
            original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta
        else:
            # disable YaRN and use base rope_theta in pure sliding-window attention
            original_seq_len, rope_theta = 0, args.rope_theta
        freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len, rope_theta, args.rope_factor, args.beta_fast, args.beta_slow)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int):
        bsz, seqlen, _ = x.size()
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        win = self.window_size
        ratio = self.compress_ratio
        rd = self.rope_head_dim
        if self.compress_ratio and self.compressor.kv_cache is None:
            # Lazy wiring: compressor writes into the compressed tail of kv_cache.
            self.compressor.kv_cache = self.kv_cache[:, win:]
            self.compressor.freqs_cis = self.freqs_cis
            if self.indexer is not None:
                self.indexer.freqs_cis = self.freqs_cis
        # q
        qr = q = self.q_norm(self.wq_a(x))
        q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim))
        # Per-head RMS normalization of q (no learned weight).
        q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)
        apply_rotary_emb(q[..., -rd:], freqs_cis)
        # win kv & topk_idxs
        kv = self.wkv(x)
        kv = self.kv_norm(kv)
        apply_rotary_emb(kv[..., -rd:], freqs_cis)
        # FP8-simulate non-rope dims to match QAT; rope dims stay bf16 for positional precision
        act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
        topk_idxs = get_window_topk_idxs(win, bsz, seqlen, start_pos)
        if self.compress_ratio:
            # Compressed entries sit after the window entries in the kv tensor passed
            # to sparse_attn (prefill) or after the ring buffer in kv_cache (decode).
            offset = kv.size(1) if start_pos == 0 else win
            if self.indexer is not None:
                compress_topk_idxs = self.indexer(x, qr, start_pos, offset)
            else:
                compress_topk_idxs = get_compress_topk_idxs(ratio, bsz, seqlen, start_pos, offset)
            topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1)
        topk_idxs = topk_idxs.int()
        # compress kv & attn
        if start_pos == 0:
            if seqlen <= win:
                self.kv_cache[:bsz, :seqlen] = kv
            else:
                # Fill the ring buffer so decode can continue at start_pos % win.
                cutoff = seqlen % win
                self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1)
            if self.compress_ratio:
                if (kv_compress := self.compressor(x, start_pos)) is not None:
                    kv = torch.cat([kv, kv_compress], dim=1)
            # We performed QAT here, kv could also use fp8 format, though current implementation uses bf16
            o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale)
        else:
            self.kv_cache[:bsz, start_pos % win] = kv.squeeze(1)
            if self.compress_ratio:
                self.compressor(x, start_pos)
            o = sparse_attn(q, self.kv_cache[:bsz], self.attn_sink, topk_idxs, self.softmax_scale)
        apply_rotary_emb(o[..., -rd:], freqs_cis, True)
        # o: apply wo_a per-group projection then wo_b
        # Flatten groups into the feature dim, call wo_a as a normal linear, then reshape back.
        # Equivalent to the per-group einsum when wo_a weight is block-diagonal across groups
        # (always true here since n_local_groups = n_groups/world_size = 1 for 8-GPU deploy).
        o = o.view(bsz, seqlen, self.n_local_groups, -1)
        o = self.wo_a(o.flatten(2)).view(bsz, seqlen, self.n_local_groups, self.o_lora_rank)
        x = self.wo_b(o.flatten(2))
        return x


class Gate(nn.Module):
    """MoE gating: computes expert routing scores and selects top-k experts.
    Supports hash-based routing (first n_hash_layers) where expert indices are predetermined
    per token ID, and score-based routing (remaining layers)."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.hash = layer_id < args.n_hash_layers
        self.is_w4a16 = w4a16_mode
        if self.is_w4a16:
            # Gate is not a `Linear` subclass, so it carries its own GPTQ tensors.
            in_f, out_f = args.dim, args.n_routed_experts
            assert in_f % w4a16_group_size == 0 and out_f % 8 == 0
            self.group_size = w4a16_group_size
            self.qweight = nn.Parameter(
                torch.empty(in_f // 8, out_f, dtype=torch.int32), requires_grad=False)
            self.qzeros = nn.Parameter(
                torch.empty(in_f // self.group_size, out_f // 8, dtype=torch.int32), requires_grad=False)
            self.scales = nn.Parameter(
                torch.empty(in_f // self.group_size, out_f, dtype=torch.bfloat16), requires_grad=False)
            self.register_parameter("weight", None)
        else:
            self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        if self.hash:
            # Precomputed token-id -> expert-ids table; no selection bias needed.
            self.tid2eid = nn.Parameter(torch.empty(args.vocab_size, args.n_activated_experts, dtype=torch.int32), requires_grad=False)
            self.bias = None
        else:
            self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32))

    def init_woq(self, QuantLinear):
        """Create a QuantLinear from loaded GPTQ parameters (mirrors Linear.init_woq)."""
        if not self.is_w4a16:
            return
        dev = self.qweight.device
        in_f, out_f = self.dim, self.qweight.shape[1]
        # Marlin requires out_features % 64 == 0; fall back to manual dequant otherwise.
        if out_f % 64 != 0:
            self._woq = None
            return
        layer = QuantLinear(
            bits=4,
            group_size=self.group_size,
            in_features=in_f,
            out_features=out_f,
            bias=False,
            desc_act=False,
            sym=True,
            register_buffers=True,
        )
        layer = layer.to(dev)
        layer.qweight.copy_(self.qweight.data)
        layer.qzeros.copy_(self.qzeros.data)
        layer.scales.copy_(self.scales.to(layer.scales.dtype).data)
        layer.g_idx.copy_(torch.arange(in_f, dtype=torch.int32, device=dev) // self.group_size)
        layer.post_init()
        self._woq = layer
        self.qweight = None
        self.qzeros = None
        self.scales = None

    def forward(self, x: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.is_w4a16:
            if hasattr(self, '_woq') and self._woq is not None:
                scores = self._woq(x.to(torch.bfloat16)).float()
            else:
                w = dequantize_w4a16(self.qweight, self.qzeros, self.scales, self.group_size)
                scores = F.linear(x.to(w.dtype), w).float()
        else:
            scores = linear(x.float(), self.weight.float())
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1)
        elif self.score_func == "sigmoid":
            scores = scores.sigmoid()
        else:
            scores = F.softplus(scores).sqrt()
        original_scores = scores
        # Bias shifts scores for expert selection (topk) but does not affect routing weights.
        if self.bias is not None:
            scores = scores + self.bias
        if self.hash:
            # NOTE(review): tid2eid is int32 while torch.gather below expects int64
            # indices — confirm an upstream cast exists or add one here.
            indices = self.tid2eid[input_ids]
        else:
            indices = scores.topk(self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func != "softmax":
            # Non-softmax scores are unnormalized; renormalize over the selected experts.
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights, indices


class Expert(nn.Module):
    """Single MoE expert: SwiGLU FFN (w1, w2, w3). Computation in float32 for stability."""
    def __init__(self, dim: int, inter_dim: int, dtype=None, swiglu_limit=0):
        super().__init__()
        self.w1 = Linear(dim, inter_dim, dtype=dtype)
        self.w2 = Linear(inter_dim, dim, dtype=dtype)
        self.w3 = Linear(dim, inter_dim, dtype=dtype)
        self.swiglu_limit = swiglu_limit

    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        dtype = x.dtype
        gate = self.w1(x).float()
        up = self.w3(x).float()
        if self.swiglu_limit > 0:
            # Clamp activations (gpt-oss-style swiglu limit); gate is clamped only from above.
            up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit)
            gate = torch.clamp(gate, max=self.swiglu_limit)
        x = F.silu(gate) * up
        if weights is not None:
            # Fold the routing weight in before the down projection.
            x = weights * x
        return self.w2(x.to(dtype))


class MoE(nn.Module):
    """Mixture-of-Experts: gate routes each token to top-k routed experts + 1 shared expert.
    Experts are sharded across TP ranks; each rank handles n_routed_experts // world_size experts."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(layer_id, args)
        expert_dtype = torch.float4_e2m1fn_x2 if args.expert_dtype == "fp4" else None
        # Only this rank's expert slots are instantiated; the rest stay None.
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit)
                                      if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        assert args.n_shared_experts == 1
        # no swiglu_limit
        self.shared_experts = Expert(args.dim, args.moe_inter_dim, swiglu_limit=args.swiglu_limit)

    def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x, input_ids.flatten())
        # Accumulate expert outputs in fp32; reduced across ranks before the shared expert.
        y = torch.zeros_like(x, dtype=torch.float32)
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx], weights[idx, top, None])
        if world_size > 1:
            dist.all_reduce(y)
        y += self.shared_experts(x)
        return y.type_as(x).view(shape)


class Block(nn.Module):
    """Transformer block with Hyper-Connections (HC) mixing.
    Instead of a simple residual, HC maintains `hc_mult` copies of the hidden state.
    hc_pre: reduces hc copies -> 1 via learned weighted sum (pre-weights from Sinkhorn).
    hc_post: expands 1 -> hc copies via learned post-weights + combination matrix."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.norm_eps = args.norm_eps
        self.attn = Attention(layer_id, args)
        self.ffn = MoE(layer_id, args)
        self.attn_norm = RMSNorm(args.dim, self.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, self.norm_eps)
        self.hc_mult = hc_mult = args.hc_mult
        self.hc_sinkhorn_iters = args.hc_sinkhorn_iters
        self.hc_eps = args.hc_eps
        # mix_hc rows: hc pre-weights + hc post-weights + hc x hc combination matrix.
        mix_hc = (2 + hc_mult) * hc_mult
        hc_dim = hc_mult * args.dim
        # HC parameters are kept in fp32 (see set_dtype).
        with set_dtype(torch.float32):
            self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
            self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
            self.hc_attn_base = nn.Parameter(torch.empty(mix_hc))
            self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc))
            self.hc_attn_scale = nn.Parameter(torch.empty(3))
            self.hc_ffn_scale = nn.Parameter(torch.empty(3))

    def hc_pre(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
        # x: [b,s,hc,d], hc_fn: [mix_hc,hc*d], hc_scale: [3], hc_base: [mix_hc], y: [b,s,hc,d]
        shape, dtype = x.size(), x.dtype
        x = x.flatten(2).float()
        # Dynamic mixing weights from the RMS-normalized flattened state.
        rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
        mixes = F.linear(x, hc_fn) * rsqrt
        pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps)
        # Collapse the hc copies into a single stream using the pre-weights.
        y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
        return y.to(dtype), post, comb

    def hc_post(self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor):
        # x: [b,s,d], residual: [b,s,hc,d], post: [b,s,hc], comb: [b,s,hc,hc], y: [b,s,hc,d]
        y = post.unsqueeze(-1) * x.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2)
        return y.type_as(x)

    def forward(self, x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor]) -> torch.Tensor:
        residual = x
        x, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
        x = self.attn_norm(x)
        x = self.attn(x, start_pos)
        x = self.hc_post(x, residual, post, comb)
        residual = x
        x, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
        x = self.ffn_norm(x)
        x = self.ffn(x, input_ids)
        x = self.hc_post(x, residual, post, comb)
        return x


class ParallelHead(nn.Module):
    def __init__(self, vocab_size: int, dim: int, norm_eps: float = 1e-6, hc_eps: float = 1e-6):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.norm_eps = norm_eps
        self.hc_eps = hc_eps
        self.part_vocab_size = (vocab_size // world_size)
        # lm_head is always stored as bf16 (even in W4A16 checkpoints); use fp32 for logit precision
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))

    def get_logits(self, x):
        # Only the last position's logits are needed for generation.
        return F.linear(x[:, -1].float(), self.weight)

    def forward(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, norm: RMSNorm):
        # x: [b,s,hc,d]
        x = self.hc_head(x, hc_fn, hc_scale, hc_base)
        logits = self.get_logits(norm(x))
        if world_size > 1:
            # Each rank holds a vocab shard; gather then concat along vocab dim.
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits

    def hc_head(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
        # Final HC reduction: sigmoid-gated weighted sum of the hc copies (no Sinkhorn).
        shape, dtype = x.size(), x.dtype
        x = x.flatten(2).float()
        rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
        mixes = F.linear(x, hc_fn) * rsqrt
        pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps
        y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
        return y.to(dtype)


class MTPBlock(Block):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__(layer_id, args)
        self.e_proj = Linear(args.dim, args.dim)
        self.h_proj = Linear(args.dim, args.dim)
        self.enorm = RMSNorm(args.dim, args.norm_eps)
        self.hnorm = RMSNorm(args.dim, args.norm_eps)
        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.hc_mult = hc_mult = args.hc_mult
        hc_dim = hc_mult * args.dim
        with set_dtype(torch.float32):
            self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
            self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
            self.hc_head_scale = nn.Parameter(torch.empty(1))
        # Shared with the main model; assigned externally after construction.
        self.embed: ParallelEmbedding = None
        self.head: ParallelHead = None

    @torch.inference_mode()
    def forward(self, x: torch.Tensor, start_pos: int, input_ids: torch.Tensor) -> torch.Tensor:
        # x: [b,s,hc,d]
        assert self.embed is not None and self.head is not None
        e = self.embed(input_ids)
        e = self.enorm(e)
        x = self.hnorm(x)
        # Combine next-token embedding with the current hidden state (MTP draft input).
        x = self.e_proj(e).unsqueeze(2) + self.h_proj(x)
        x = super().forward(x, start_pos, input_ids)
        logits = self.head(x, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
        return logits


class Transformer(nn.Module):
    """Full DeepSeek-V4 model: embed -> HC-expand -> N blocks -> HC-head -> logits.
    Sets global state (world_size, rank, default_dtype, scale_fmt, scale_dtype) in __init__."""
    def __init__(self, args: ModelArgs):
        global world_size, rank, default_dtype, scale_fmt, scale_dtype, w4a16_mode
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        w4a16_mode = (args.dtype == "w4a16")
        if w4a16_mode:
            default_dtype = torch.bfloat16
            scale_fmt = None
            scale_dtype = torch.float32
        else:
            default_dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
            scale_fmt = "ue8m0" if args.scale_dtype == "fp8" else args.scale_fmt
            scale_dtype = torch.float8_e8m0fnu if args.scale_dtype == "fp8" else torch.float32
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.norm_eps = args.norm_eps
        self.hc_eps = args.hc_eps
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim, self.norm_eps)
        self.head = ParallelHead(args.vocab_size, args.dim, self.norm_eps, self.hc_eps)
        self.mtp = torch.nn.ModuleList()
        for layer_id in range(args.n_mtp_layers):
self.mtp.append(MTPBlock(args.n_layers + layer_id, args)) self.mtp[-1].embed = self.embed self.mtp[-1].head = self.head self.hc_mult = hc_mult = args.hc_mult hc_dim = hc_mult * args.dim with set_dtype(torch.float32): self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim)) self.hc_head_base = nn.Parameter(torch.empty(hc_mult)) self.hc_head_scale = nn.Parameter(torch.empty(1)) def init_woq_layers(self): """After load_model(), convert all W4A16 parameters into QuantLinear layers.""" # from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear as QuantLinear from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear as QuantLinear for module in self.modules(): if hasattr(module, 'init_woq') and module is not self: module.init_woq(QuantLinear) torch.cuda.empty_cache() @torch.inference_mode() def forward(self, input_ids: torch.Tensor, start_pos: int = 0): h = self.embed(input_ids) # Expand to hc_mult copies for Hyper-Connections h = h.unsqueeze(2).repeat(1, 1, self.hc_mult, 1) for layer in self.layers: h = layer(h, start_pos, input_ids) logits = self.head(h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm) return logits if __name__ == "__main__": torch.set_default_dtype(torch.bfloat16) torch.set_default_device("cuda") torch.manual_seed(0) args = ModelArgs(n_hash_layers=0) x = torch.randint(0, args.vocab_size, (2, 128)) model = Transformer(args) print(model(x).size()) for i in range(128, 150): print(i, model(x[:, 0:1], i).size()) h = torch.randn(2, 128, args.hc_mult, args.dim) mtp = model.mtp[0] print(mtp(h, 0, x).size()) print(mtp(h[:, 0:1], 1, x[:, 0:1]).size())