# deepseek-v4-mini-3B-init / code / deepseek_v4 / modeling_deepseek_v4.py
# Uploader: kshitijthakkar
# Commit message: Re-upload as 4x2GB sharded safetensors for parallel transfer
# Commit: e1d16ef (verified)
"""DeepSeek-V4 modeling code (faithful small-scale replica).
Parameter naming mirrors the official DeepSeek-V4 safetensors index so that
weights can later be transferred / sliced from real V4-Pro / V4-Flash
checkpoints. Top-level layout (flat, no ``model.`` prefix):
embed.weight
layers.{i}.attn_norm.weight
layers.{i}.ffn_norm.weight
layers.{i}.hc_attn_{base,fn,scale}
layers.{i}.hc_ffn_{base,fn,scale}
layers.{i}.attn.{wq_a, wq_b, wkv, wo_a, wo_b, q_norm, kv_norm, attn_sink}
layers.{i}.attn.compressor.{wkv, wgate, ape, norm} # CSA / HCA only
layers.{i}.attn.indexer.{wq_b, weights_proj, compressor.*} # CSA only
layers.{i}.ffn.gate.{weight, bias} # routed MoE
layers.{i}.ffn.gate.tid2eid # hash MoE
layers.{i}.ffn.experts.{j}.{w1, w2, w3}.weight
layers.{i}.ffn.shared_experts.{w1, w2, w3}.weight
norm.weight
head.weight
hc_head_{base, fn, scale}
mtp.{k}.{...} # one per MTP step
"""
from __future__ import annotations
import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from .configuration_deepseek_v4 import DeepseekV4Config
# =============================================================================
# Norms, RoPE, utilities
# =============================================================================
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-channel gain.

    Statistics are taken over the last dimension in float32; the result is
    cast back to the dtype of the input.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x_fp32 = x.float()
        mean_sq = x_fp32.pow(2).mean(-1, keepdim=True)
        normed = x_fp32 * torch.rsqrt(mean_sq + self.eps)
        # The gain is applied before the down-cast.
        return (self.weight * normed).to(orig_dtype)
def fixed_rmsnorm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Gain-free RMSNorm over the last dimension.

    The computation runs in float32 and the result is returned in the dtype
    of ``x``.
    """
    x_fp32 = x.float()
    inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)
    return (x_fp32 * inv_rms).to(x.dtype)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def build_rope_cache(seq_len: int, dim: int, base: float, device, dtype):
    """Precompute rotary cos/sin tables for positions ``0 .. seq_len - 1``.

    Returns a ``(cos, sin)`` pair. The angles are computed in float32 from
    ``base ** (-2k / dim)`` frequencies, duplicated along the last axis, and
    cast to ``dtype``. For ``dim <= 0`` two empty ``[seq_len, 0]`` tensors
    are returned.
    """
    if dim <= 0:
        empty_cos = torch.zeros(seq_len, 0, device=device, dtype=dtype)
        empty_sin = torch.zeros(seq_len, 0, device=device, dtype=dtype)
        return empty_cos, empty_sin

    exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
    inv_freq = 1.0 / (base ** exponents)
    steps = torch.arange(seq_len, device=device, dtype=torch.float32)
    angles = torch.outer(steps, inv_freq)
    # Both halves of the rotated slice share the same frequencies.
    angles = torch.cat([angles, angles], dim=-1)
    return angles.cos().to(dtype), angles.sin().to(dtype)
def apply_partial_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                       rope_dim: int, positions: torch.Tensor) -> torch.Tensor:
    """Rotate the trailing ``rope_dim`` channels of ``x``; pass the rest through.

    x: [..., L, D]; cos/sin: [P, rope_dim]; positions: [L] long indices into
    the cos/sin tables. When ``rope_dim <= 0`` the input is returned as is.
    """
    if rope_dim <= 0:
        return x

    cos_at = cos[positions]
    sin_at = sin[positions]
    # Left-pad the tables with singleton axes until they have x's rank.
    lead = (None,) * (x.dim() - cos_at.dim())
    cos_at, sin_at = cos_at[lead], sin_at[lead]

    untouched = x[..., :-rope_dim]
    rotated = x[..., -rope_dim:]
    rotated = (rotated * cos_at) + (_rotate_half(rotated) * sin_at)
    return torch.cat([untouched, rotated], dim=-1)
# =============================================================================
# Manifold-Constrained Hyper-Connections (mHC)
# =============================================================================
class MHC(nn.Module):
    """Manifold-Constrained Hyper-Connections (stateless helper).

    This module owns no parameters; the per-layer tensors are passed in:
      - ``fn``    [mix_hc, n*d]  generator applied to the flattened stream
      - ``base``  [mix_hc]       static biases
      - ``scale`` [3]            one scalar gate per pre / post / comb part
    with ``mix_hc = (2 + n) * n``.

    Generation (all in float32):
        flat  = X.flatten(-2)                               # [B,S,n*d]
        mixes = (flat @ fn.T) * rsqrt(mean(flat^2) + eps)   # [B,S,mix_hc]
        pre   = sigmoid(scale[0] * mixes[:n]   + base[:n]) + eps
        post  = 2 * sigmoid(scale[1] * mixes[n:2n] + base[n:2n])
        comb  = softmax(scale[2] * mixes[2n:] + base[2n:], rows) + eps,
                then column-normalised, then (iters - 1) rounds of
                row-normalise / column-normalise (each divides by sum + eps).

    Application:
        sublayer_in = sum_i pre[i] * X[i]
        new_X[i]    = post[i] * F_out + sum_j comb[i, j] * X[j]
    """

    def __init__(self, hidden_size: int, n_hc: int, sinkhorn_iters: int = 20,
                 eps: float = 1e-6):
        super().__init__()
        self.d = hidden_size
        self.n = n_hc
        self.iters = sinkhorn_iters
        self.eps = eps
        self.flat = n_hc * hidden_size
        self.mix_hc = (2 + n_hc) * n_hc

    def _normalized_mixes(self, X: torch.Tensor, fn: torch.Tensor) -> torch.Tensor:
        """Flatten ``X`` [B,S,n,d], project with ``fn`` and rescale by 1/RMS (fp32)."""
        batch, seq, n, d = X.shape
        flat = X.reshape(batch, seq, n * d).float()
        inv_rms = torch.rsqrt(flat.square().mean(-1, keepdim=True) + self.eps)
        return F.linear(flat, fn.float()) * inv_rms

    def split_and_construct(self, mixes: torch.Tensor, base: torch.Tensor,
                            scale: torch.Tensor):
        """Turn raw ``mixes`` [..., mix_hc] into ``(pre, post, comb)``.

        ``base`` [mix_hc] and ``scale`` [3] are cast to ``mixes.dtype`` first.
        Returns pre [..., n], post [..., n] and comb [..., n, n].
        """
        n = self.n
        eps = self.eps
        base = base.to(mixes.dtype)
        scale = scale.to(mixes.dtype)

        # Layout of the last axis: n pre logits, n post logits, then n*n comb
        # logits in row-major order.
        lead_shape = mixes.shape[:-1]
        comb_raw = mixes[..., 2 * n:].reshape(*lead_shape, n, n)
        comb_bias = base[2 * n:].view(n, n)

        pre = torch.sigmoid(scale[0] * mixes[..., :n] + base[:n]) + eps
        post = 2.0 * torch.sigmoid(scale[1] * mixes[..., n:2 * n] + base[n:2 * n])

        comb = F.softmax(scale[2] * comb_raw + comb_bias, dim=-1) + eps
        comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
        for _ in range(self.iters - 1):
            # One round: normalise rows, then columns.
            for axis in (-1, -2):
                comb = comb / (comb.sum(dim=axis, keepdim=True) + eps)
        return pre, post, comb

    def gen_params(self, X: torch.Tensor, base: torch.Tensor, fn: torch.Tensor,
                   scale: torch.Tensor):
        """X: [B,S,n,d]. Returns (pre [B,S,n], post [B,S,n], comb [B,S,n,n])."""
        return self.split_and_construct(self._normalized_mixes(X, fn), base, scale)

    @staticmethod
    def hc_pre(X: torch.Tensor, pre: torch.Tensor) -> torch.Tensor:
        """Collapse X [B,S,n,d] to [B,S,d] as the ``pre``-weighted sum over n."""
        weights = pre.unsqueeze(-1).to(X.dtype)
        return (weights * X).sum(dim=-2)

    @staticmethod
    def hc_post(new_x: torch.Tensor, residual: torch.Tensor,
                post: torch.Tensor, comb: torch.Tensor) -> torch.Tensor:
        """Expand the sub-layer output back onto the n residual streams.

        new_x: [B,S,d], residual: [B,S,n,d], post: [B,S,n], comb: [B,S,n,n].
        out[i] = post[i] * new_x + sum_j comb[i, j] * residual[j]
        """
        injected = post.unsqueeze(-1).to(new_x.dtype) * new_x.unsqueeze(-2)
        mixed = comb.to(residual.dtype) @ residual
        return injected + mixed

    def gen_head_pre(self, X: torch.Tensor, fn: torch.Tensor,
                     base: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """Head-side variant: only the ``pre`` weights, no Sinkhorn step.

        fn: [n, n*d]; base: [n]; scale: [1] or scalar (an empty ``scale``
        falls back to a gate of 1.0). Returns pre [B,S,n].
        """
        mixes = self._normalized_mixes(X, fn)
        scale = scale.float()
        base = base.float()
        gate = scale.view(-1)[0] if scale.numel() else 1.0
        return torch.sigmoid(gate * mixes + base) + self.eps
# =============================================================================
# Attention sink helper
# =============================================================================
def sink_softmax(logits: torch.Tensor, sink: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Softmax over ``dim`` with an extra sink logit added to the denominator only.

    ``sink`` must already be shaped so that it broadcasts against ``logits``
    reduced over ``dim`` (keepdim). Because of the sink term the returned
    weights sum to less than one; a fully ``-inf`` row yields all zeros.
    """
    # Shift by the larger of the row maximum and the sink for stability.
    peak = torch.maximum(logits.amax(dim=dim, keepdim=True), sink)
    weights = torch.exp(logits - peak)
    denom = weights.sum(dim=dim, keepdim=True) + torch.exp(sink - peak)
    return weights / denom
# =============================================================================
# Compressor (token-level pooling); shared for HCA, CSA, and indexer keys
# =============================================================================
class Compressor(nn.Module):
    """Pool every ``m`` consecutive hidden states into one entry.

    Each block is reduced with a per-channel softmax (scores from ``wgate``
    plus the learned ``ape`` bias) over its positions, weighting the ``wkv``
    projections, followed by an RMSNorm.

    - overlap=True: ``wkv`` / ``wgate`` emit ``2 * head_dim`` channels (first
      half = overlap stream, second half = current stream) and ``ape`` is
      ``[m, 2 * head_dim]``; ``_overlap_transform`` rearranges
      [b, nb, m, 2c] -> [b, nb, 2m, c] before the softmax.
    - overlap=False: ``head_dim`` channels, ``ape`` is ``[m, head_dim]``.

    Tensor names: ``norm.weight``, ``wkv.weight``, ``wgate.weight``, ``ape``.
    """

    def __init__(self, hidden_size: int, head_dim: int, m: int, overlap: bool):
        super().__init__()
        self.m = m
        self.overlap = overlap
        self.head_dim = head_dim
        self.coff = 2 if overlap else 1
        width = self.coff * head_dim
        self.norm = RMSNorm(head_dim)
        self.wkv = nn.Linear(hidden_size, width, bias=False)
        self.wgate = nn.Linear(hidden_size, width, bias=False)
        self.ape = nn.Parameter(torch.zeros(m, width))

    @staticmethod
    def _overlap_transform(t: torch.Tensor, head_dim: int, fill_value) -> torch.Tensor:
        """t: [b, nb, m, 2*head_dim] -> [b, nb, 2*m, head_dim].

        Rows ``[:m]`` of block k hold the overlap half of block k-1 (block 0
        keeps ``fill_value`` there); rows ``[m:]`` hold block k's current half.
        """
        b, nb, m, _ = t.shape
        overlap_half = t[..., :head_dim]
        current_half = t[..., head_dim:]
        out = t.new_full((b, nb, 2 * m, head_dim), fill_value)
        out[:, 1:, :m] = overlap_half[:, :-1]  # previous block, shifted by one
        out[:, :, m:] = current_half
        return out

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """h: [B, n, D]. Returns the compressed sequence [B, ceil(n/m), head_dim]."""
        batch, seq_len, _ = h.shape
        m = self.m

        # Project in the weights' dtype, then pool in float32.
        x = h.to(self.wkv.weight.dtype)
        kv = self.wkv(x).float()
        score = self.wgate(x).float()

        # Zero-pad the sequence axis up to a multiple of m.
        pad = (-seq_len) % m
        if pad:
            kv, score = (F.pad(t, (0, 0, 0, pad)) for t in (kv, score))

        n_blocks = kv.size(1) // m
        kv = kv.view(batch, n_blocks, m, -1)
        score = score.view(batch, n_blocks, m, -1) + self.ape.float()

        if self.overlap:
            kv = self._overlap_transform(kv, self.head_dim, 0.0)
            score = self._overlap_transform(score, self.head_dim, float("-inf"))

        # Softmax over the m (or 2m) positions of each block, then weighted sum.
        pooled = (kv * score.softmax(dim=2)).sum(dim=2)
        return self.norm(pooled.to(h.dtype))
# =============================================================================
# Lightning Indexer
# =============================================================================
class LightningIndexer(nn.Module):
    """Scores compressed blocks per query and returns the top-k block indices.

    Names:
        indexer.compressor.*          (separate Compressor for indexer keys)
        indexer.wq_b.weight           (q-up from shared cQ -> H_I * head_dim)
        indexer.weights_proj.weight   (per-head weight w_t,h)
    """

    def __init__(self, hidden_size: int, q_lora_rank: int,
                 index_n_heads: int, index_head_dim: int,
                 m: int, overlap: bool):
        super().__init__()
        self.n_heads = index_n_heads
        self.head_dim = index_head_dim
        self.compressor = Compressor(hidden_size, index_head_dim, m, overlap=overlap)
        self.wq_b = nn.Linear(q_lora_rank, index_n_heads * index_head_dim, bias=False)
        self.weights_proj = nn.Linear(hidden_size, index_n_heads, bias=False)
        # Combined score scaling: head_dim ** -0.5 times n_heads ** -0.5.
        self.score_scale = (index_head_dim ** -0.5) * (index_n_heads ** -0.5)

    def keys(self, h: torch.Tensor) -> torch.Tensor:
        """Indexer keys for hidden states ``h``: [B, nb, head_dim]."""
        return self.compressor(h)

    def select(self, h: torch.Tensor, cQ: torch.Tensor, K: torch.Tensor,
               positions: torch.Tensor, m: int, top_k: int):
        """Pick up to ``top_k`` causally visible blocks for every query.

        Returns ``(idx [B,Lq,k], mask [B,Lq,k] bool)`` with
        ``k = min(top_k, nb)``. ``mask`` is False wherever the chosen score is
        not finite, i.e. the slot points at a block that was masked out.
        """
        batch, q_len, _ = h.shape
        n_blocks = K.size(1)

        queries = self.wq_b(cQ).view(batch, q_len, self.n_heads, self.head_dim)
        head_w = self.weights_proj(h) * self.score_scale             # [B,Lq,H_I]
        sim = F.relu(torch.einsum("blhd,bsd->blhs", queries, K))
        scores = (head_w.unsqueeze(-1) * sim).sum(dim=2)             # [B,Lq,nb]

        # Block s is visible to a query at position t iff (s+1)*m - 1 < t,
        # which is equivalent to s < t // m.
        block_ids = torch.arange(n_blocks, device=h.device)
        visible = block_ids.unsqueeze(0) < (positions.unsqueeze(-1) // m)  # [Lq,nb]
        scores = scores.masked_fill(~visible.unsqueeze(0), float("-inf"))

        k = min(top_k, n_blocks)
        if k <= 0:
            no_idx = torch.zeros(batch, q_len, 0, dtype=torch.long, device=h.device)
            return no_idx, no_idx.bool()

        best_values, best_indices = scores.topk(k, dim=-1)
        return best_indices, torch.isfinite(best_values)
# =============================================================================
# Attention layer (CSA / HCA / pure sliding-window)
# =============================================================================
class DeepseekV4Attention(nn.Module):
    """One attention layer with a single shared KV head.

    The mode is derived from ``compress_ratio``:
      - ``0``      -> "sw":  sliding-window attention only
      - ``< 16``   -> "csa": sliding window + top-k compressed blocks chosen
                             per query by the LightningIndexer
      - ``>= 16``  -> "hca": sliding window + all causally visible compressed
                             blocks

    All modes share a per-head learnable sink logit (``attn_sink``) in the
    softmax denominator, a low-rank query path (``wq_a`` -> ``q_norm`` ->
    ``wq_b``) and a grouped low-rank output projection (``wo_a`` -> ``wo_b``).
    """

    def __init__(self, config: DeepseekV4Config, compress_ratio: int):
        super().__init__()
        self.config = config
        self.compress_ratio = compress_ratio
        d = config.hidden_size
        H = config.num_attention_heads
        c = config.head_dim
        self.H = H
        self.c = c
        self.q_lora_rank = config.q_lora_rank
        self.o_groups = config.o_groups
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if H % self.o_groups != 0:
            raise ValueError(
                f"num_attention_heads ({H}) must be divisible by "
                f"o_groups ({self.o_groups})"
            )
        self.heads_per_group = H // self.o_groups
        self.d_g = config.o_lora_rank
        self.rope_dim = config.qk_rope_head_dim
        self.window = config.sliding_window

        if compress_ratio == 0:
            self.mode = "sw"
        elif compress_ratio < 16:
            self.mode = "csa"
        else:
            self.mode = "hca"

        # Query path: low-rank projection with a norm on the latent; the
        # per-head weight-free RMS normalisation is applied in forward().
        self.wq_a = nn.Linear(d, config.q_lora_rank, bias=False)
        self.q_norm = RMSNorm(config.q_lora_rank, eps=config.rms_norm_eps)
        self.wq_b = nn.Linear(config.q_lora_rank, H * c, bias=False)

        # Sliding-window KV: always present, one head shared by all queries.
        self.wkv = nn.Linear(d, c, bias=False)
        self.kv_norm = RMSNorm(c, eps=config.rms_norm_eps)

        # Output projection: wo_a stores o_groups separate sub-matrices in one
        # Linear weight; _output_proj() reshapes it and applies it per group.
        self.wo_a = nn.Linear(self.heads_per_group * c,
                              self.o_groups * self.d_g, bias=False)
        self.wo_b = nn.Linear(self.o_groups * self.d_g, d, bias=False)
        self.attn_sink = nn.Parameter(torch.zeros(H))

        if self.mode in ("csa", "hca"):
            self.compressor = Compressor(d, c, compress_ratio,
                                         overlap=(self.mode == "csa"))
        if self.mode == "csa":
            self.indexer = LightningIndexer(d, config.q_lora_rank,
                                            config.index_n_heads, config.index_head_dim,
                                            m=compress_ratio, overlap=True)

    def _output_proj(self, attn_out: torch.Tensor) -> torch.Tensor:
        """attn_out: [B, S, H, c]. Returns [B, S, d].

        The ``wo_a`` weight ([o_groups * d_g, heads_per_group * c]) is viewed
        as [o_groups, d_g, heads_per_group * c] so every head group gets its
        own projection; the concatenated group outputs then go through
        ``wo_b``.
        """
        B, S, H, c = attn_out.shape
        out_g = attn_out.reshape(B, S, self.o_groups, self.heads_per_group * c)
        wo_a = self.wo_a.weight.view(self.o_groups, self.d_g, self.heads_per_group * c)
        out = torch.einsum("bsgd,grd->bsgr", out_g, wo_a)  # [B,S,g,d_g]
        out = out.reshape(B, S, self.o_groups * self.d_g)
        return self.wo_b(out)

    def _apply_output_rope(self, out: torch.Tensor, rope_cos, rope_sin,
                           positions) -> torch.Tensor:
        """Rotate the last ``rope_dim`` channels of ``out`` [B,S,H,c] by the
        negated angle of each query position (the sine table is negated).

        Returns ``out`` unchanged when ``rope_dim <= 0``.
        """
        if self.rope_dim <= 0:
            return out
        cos = rope_cos[positions].unsqueeze(0).unsqueeze(2)     # [1,S,1,rope_dim]
        sin = (-rope_sin[positions]).unsqueeze(0).unsqueeze(2)  # negated -> inverse rotation
        out_pass, out_rot = out[..., :-self.rope_dim], out[..., -self.rope_dim:]
        out_rot = (out_rot * cos) + (_rotate_half(out_rot) * sin)
        return torch.cat([out_pass, out_rot], dim=-1)

    def forward(self, x: torch.Tensor, positions: torch.Tensor,
                rope_cos: torch.Tensor, rope_sin: torch.Tensor,
                rope_cos_c: torch.Tensor, rope_sin_c: torch.Tensor,
                pad_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Run attention over a full sequence.

        x: [B,S,d]; positions: [S] long; rope_cos / rope_sin: tables used for
        queries, sliding-window keys and the output rotation;
        rope_cos_c / rope_sin_c: tables used for the compressed entries;
        pad_mask: [B,S] bool (True = valid key) or None. Returns [B,S,d].
        """
        Bsz, S, _ = x.shape
        H, c, m = self.H, self.c, self.compress_ratio

        # ---------------- queries ----------------
        cQ = self.q_norm(self.wq_a(x))                           # [B,S,q_lora]
        q = self.wq_b(cQ).view(Bsz, S, H, c)
        # Per-head RMS normalisation without a learnable weight.
        q = q * torch.rsqrt(q.float().square().mean(-1, keepdim=True) +
                            self.config.rms_norm_eps).to(q.dtype)
        q = apply_partial_rope(q.transpose(1, 2), rope_cos, rope_sin,
                               self.rope_dim, positions).transpose(1, 2)  # [B,S,H,c]

        # ---------------- sliding-window KV ----------------
        kv_sw = self.kv_norm(self.wkv(x))                        # [B,S,c]
        kv_sw = apply_partial_rope(kv_sw, rope_cos, rope_sin, self.rope_dim, positions)

        # Causal + window mask on positions: key j is visible to query i iff
        # i - window < j <= i.
        i = positions.unsqueeze(-1)
        j = positions.unsqueeze(0)
        sw_mask = (j <= i) & (j > i - self.window)               # [S,S]
        sw_mask = sw_mask.unsqueeze(0).expand(Bsz, -1, -1)       # [B,S,S]
        if pad_mask is not None:
            sw_mask = sw_mask & pad_mask.unsqueeze(1)            # drop padded keys

        # ---------------- compressed entries (csa / hca) ----------------
        kv_comp = None
        nb = 0
        if self.mode in ("csa", "hca"):
            # The compressor applies its own RMSNorm; self.kv_norm is only for
            # the sliding-window path.
            kv_comp = self.compressor(x)                         # [B,nb,c]
            nb = kv_comp.size(1)
            # Each block is rotated at the position of its last token, clamped
            # to the size of the compressed RoPE table.
            block_end = torch.arange(nb, device=x.device) * m + (m - 1)
            comp_pos = block_end.clamp(max=rope_cos_c.size(0) - 1)
            kv_comp = apply_partial_rope(kv_comp, rope_cos_c, rope_sin_c,
                                         self.rope_dim, comp_pos)

        # ---------------- shared pieces of the core attention ----------------
        scale = 1.0 / math.sqrt(c)
        sink = self.attn_sink.view(1, 1, -1, 1)
        neg_inf = float("-inf")
        # q: [B,S,H,c], kv_sw: [B,S,c] -> logits [B,S,H,S] (shared KV head).
        sw_logits = torch.einsum("bthd,bjd->bthj", q, kv_sw) * scale
        sw_logits = sw_logits.masked_fill(~sw_mask.unsqueeze(2), neg_inf)

        if self.mode == "sw":
            probs = sink_softmax(sw_logits, sink, dim=-1)
            # Contract against the shared [B,S,c] values directly, as the
            # csa / hca branches do, instead of first expanding them to
            # [B,S,S,c].
            out = torch.einsum("bthj,bjd->bthd", probs, kv_sw)   # [B,S,H,c]
        elif self.mode == "hca":
            # A block is visible once its last token lies strictly before the query.
            comp_mask = block_end.unsqueeze(0) < positions.unsqueeze(-1)  # [S,nb]
            comp_mask = comp_mask.unsqueeze(0).expand(Bsz, -1, -1)        # [B,S,nb]
            comp_logits = torch.einsum("bthd,bjd->bthj", q, kv_comp) * scale
            comp_logits = comp_logits.masked_fill(~comp_mask.unsqueeze(2), neg_inf)
            # One softmax over [compressed blocks (nb)] + [window keys (S)].
            logits = torch.cat([comp_logits, sw_logits], dim=-1)
            probs = sink_softmax(logits, sink, dim=-1)
            p_comp, p_sw = probs.split([nb, S], dim=-1)
            out = (
                torch.einsum("bthj,bjd->bthd", p_comp, kv_comp) +
                torch.einsum("bthj,bjd->bthd", p_sw, kv_sw)
            )                                                    # [B,S,H,c]
        else:  # csa
            K_idx = self.indexer.keys(x)                         # [B,nb,idx_head_dim]
            idx, sel_mask = self.indexer.select(x, cQ, K_idx, positions, m,
                                                self.config.index_topk)
            kk = idx.size(-1)
            if kk == 0:
                kv_sel = kv_comp.new_zeros(Bsz, S, 0, c)
                sel_mask = sel_mask.new_zeros(Bsz, S, 0, dtype=torch.bool)
            else:
                # Gather the selected compressed entries for every query.
                idx_safe = idx.clamp(min=0)
                kv_comp_exp = kv_comp.unsqueeze(1).expand(-1, S, -1, -1)  # [B,S,nb,c]
                kv_sel = torch.gather(
                    kv_comp_exp, 2,
                    idx_safe.unsqueeze(-1).expand(-1, -1, -1, c)
                )                                                # [B,S,kk,c]
            sel_logits = torch.einsum("bthd,btjd->bthj", q, kv_sel) * scale  # [B,S,H,kk]
            if kk > 0:
                sel_logits = sel_logits.masked_fill(~sel_mask.unsqueeze(2), neg_inf)
            # One softmax over [selected blocks (kk)] + [window keys (S)].
            logits = torch.cat([sel_logits, sw_logits], dim=-1)
            probs = sink_softmax(logits, sink, dim=-1)
            p_sel, p_sw = probs.split([kk, S], dim=-1)
            out = (
                (torch.einsum("bthj,btjd->bthd", p_sel, kv_sel) if kk > 0 else 0) +
                torch.einsum("bthj,bjd->bthd", p_sw, kv_sw)
            )

        out = self._apply_output_rope(out, rope_cos, rope_sin, positions)
        return self._output_proj(out)
# =============================================================================
# Clamped SwiGLU expert
# =============================================================================
class SwiGLUExpert(nn.Module):
    """Clamped SwiGLU feed-forward expert: w1 = gate, w3 = up, w2 = down."""

    def __init__(self, hidden_size: int, intermediate_size: int, limit: float):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.limit = limit

    def forward(self, x):
        gate = self.w1(x)
        up = self.w3(x)
        # Clamping: the linear (up) branch is limited to [-limit, limit],
        # the gate branch is only capped from above at `limit`.
        up = up.clamp(-self.limit, self.limit)
        gate = torch.minimum(gate, torch.full_like(gate, self.limit))
        hidden = F.silu(gate) * up
        return self.w2(hidden)
# =============================================================================
# DeepseekMoE (sqrt-softplus routing, aux-loss-free) + Hash variant
# =============================================================================
class MoEGate(nn.Module):
    """Parameter container for the MoE router.

    Always present:
      - ``weight`` [n_routed_experts, hidden_size]: produces the routing
        scores, in hash and non-hash layers alike.
    Conditional:
      - ``bias`` [n_routed_experts] (non-hash only): float32 routing bias;
        set to ``None`` in hash layers.
      - ``tid2eid`` [vocab_size, top_k] (hash only): non-trainable int64
        table mapping token id -> expert indices.
    """

    def __init__(self, hidden_size: int, num_experts: int, vocab_size: int,
                 hash_routing: bool, top_k: int):
        super().__init__()
        self.hash_routing = hash_routing
        # The score matrix exists in both variants; only the way expert
        # indices are chosen differs between them.
        self.weight = nn.Parameter(torch.zeros(num_experts, hidden_size))

        if not hash_routing:
            self.bias = nn.Parameter(torch.zeros(num_experts, dtype=torch.float32))
            return

        # Hash layers: a frozen lookup table and no bias.
        lookup = torch.zeros(vocab_size, top_k, dtype=torch.long)
        self.tid2eid = nn.Parameter(lookup, requires_grad=False)
        self.bias = None
class DeepseekV4MoE(nn.Module):
    """Mixture-of-experts FFN: routed SwiGLU experts plus optional shared experts.

    Expert indices come either from a top-k over biased router scores or,
    for hash layers, from the gate's ``tid2eid`` table. In both cases the
    mixing weights are gathered from the (unbiased) router scores.
    """

    def __init__(self, config: DeepseekV4Config, hash_routing: bool):
        super().__init__()
        self.config = config
        self.hash_routing = hash_routing
        self.num_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling = config.routed_scaling_factor

        hidden = config.hidden_size
        inter = config.moe_intermediate_size
        limit = config.swiglu_limit
        self.gate = MoEGate(hidden, self.num_experts, config.vocab_size,
                            hash_routing=hash_routing, top_k=self.top_k)
        self.experts = nn.ModuleList(
            [SwiGLUExpert(hidden, inter, limit) for _ in range(self.num_experts)]
        )
        # All shared experts are fused into one wider SwiGLU expert.
        self.shared_experts = (
            SwiGLUExpert(hidden, inter * config.n_shared_experts, limit)
            if config.n_shared_experts > 0 else None
        )

    def _routed_indices(self, x_flat: torch.Tensor, token_ids_flat: torch.Tensor):
        """Return ``(idx [N,K], weights [N,K])`` for flattened tokens ``x_flat`` [N,D].

        Scores are computed in float32. Weights are re-normalised over the
        chosen experts unless the scoring function is ``"softmax"``, then
        multiplied by the routed scaling factor and cast to ``x_flat.dtype``.
        """
        scoring = self.config.scoring_func
        logits = F.linear(x_flat.float(), self.gate.weight.float())  # [N, E]
        if scoring == "softmax":
            scores = F.softmax(logits, dim=-1)
        elif scoring == "sigmoid":
            scores = logits.sigmoid()
        else:
            # Any other value is treated as sqrt(softplus(.)).
            scores = torch.sqrt(F.softplus(logits))

        if self.hash_routing:
            chosen = self.gate.tid2eid[token_ids_flat].long()        # [N, K]
        else:
            # The bias only influences which experts are picked, not the weights.
            ranked = torch.topk(scores + self.gate.bias.float(), self.top_k, dim=-1)
            chosen = ranked.indices

        weights = torch.gather(scores, -1, chosen)
        if scoring != "softmax":
            weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-9)
        weights = weights * self.routed_scaling
        return chosen, weights.to(x_flat.dtype)

    def forward(self, x: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
        """x: [B,S,D]; token_ids: [B,S] (only read by hash layers). Returns [B,S,D]."""
        batch, seq, dim = x.shape
        n_tokens = batch * seq
        x_flat = x.reshape(n_tokens, dim)
        idx, w = self._routed_indices(x_flat, token_ids.reshape(n_tokens))

        # One "slot" per (token, chosen expert) pair, laid out token-major.
        slot_expert = idx.reshape(-1)
        slot_weight = w.reshape(-1)
        slot_token = torch.arange(n_tokens, device=x.device).repeat_interleave(self.top_k)

        out = torch.zeros_like(x_flat)
        for expert_id, expert in enumerate(self.experts):
            slots = (slot_expert == expert_id).nonzero(as_tuple=True)[0]
            if slots.numel() == 0:
                continue
            tokens = slot_token[slots]
            contribution = expert(x_flat[tokens]) * slot_weight[slots].unsqueeze(-1)
            out.index_add_(0, tokens, contribution)

        if self.shared_experts is not None:
            out = out + self.shared_experts(x_flat)
        return out.reshape(batch, seq, dim)
# =============================================================================
# Decoder layer
# =============================================================================
class DeepseekV4Layer(nn.Module):
    """Decoder layer: attention and MoE sub-blocks, each wrapped in mHC.

    The residual stream has shape [B,S,n,d] (n = ``config.hc_mult``). Each
    sub-block collapses it to [B,S,d] with ``pre`` weights, runs the sub-layer
    on the normalised result, and expands back with ``post`` / ``comb``.
    """

    def __init__(self, config: DeepseekV4Config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        hidden = config.hidden_size
        compress_ratio = config.compress_ratios[layer_idx]
        # The first `num_hash_layers` layers use hash routing.
        uses_hash = layer_idx < config.num_hash_layers

        self.attn_norm = RMSNorm(hidden, eps=config.rms_norm_eps)
        self.attn = DeepseekV4Attention(config, compress_ratio)
        self.ffn_norm = RMSNorm(hidden, eps=config.rms_norm_eps)
        self.ffn = DeepseekV4MoE(config, hash_routing=uses_hash)

        # mHC parameters per sub-block: `_fn` [(2+n)*n, n*d] drawn from
        # N(0, initializer_range), `_base` [(2+n)*n] zeros, `_scale` [3]
        # filled with 1e-2.
        n_hc = config.hc_mult
        mix_hc = (2 + n_hc) * n_hc
        flat = n_hc * hidden
        self.hc_attn_fn = nn.Parameter(torch.zeros(mix_hc, flat))
        self.hc_ffn_fn = nn.Parameter(torch.zeros(mix_hc, flat))
        for generator in (self.hc_attn_fn, self.hc_ffn_fn):
            nn.init.normal_(generator, mean=0.0, std=config.initializer_range)
        self.hc_attn_base = nn.Parameter(torch.zeros(mix_hc))
        self.hc_ffn_base = nn.Parameter(torch.zeros(mix_hc))
        self.hc_attn_scale = nn.Parameter(torch.full((3,), 1e-2))
        self.hc_ffn_scale = nn.Parameter(torch.full((3,), 1e-2))

    @staticmethod
    def _hc_sublayer(X, mhc, base, fn, scale, norm, sublayer):
        """Run one sub-layer through mHC: collapse, normalise, apply, expand."""
        pre, post, comb = mhc.gen_params(X, base, fn, scale)
        collapsed = norm(MHC.hc_pre(X, pre))
        return MHC.hc_post(sublayer(collapsed), X, post, comb)

    def forward(self, X: torch.Tensor, mhc: MHC, token_ids: torch.Tensor,
                positions: torch.Tensor,
                rope_cos, rope_sin, rope_cos_c, rope_sin_c,
                pad_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """X: [B,S,n,d] residual stream. Returns the updated stream, same shape."""
        X = self._hc_sublayer(
            X, mhc, self.hc_attn_base, self.hc_attn_fn, self.hc_attn_scale,
            self.attn_norm,
            lambda h: self.attn(h, positions, rope_cos, rope_sin,
                                rope_cos_c, rope_sin_c, pad_mask),
        )
        X = self._hc_sublayer(
            X, mhc, self.hc_ffn_base, self.hc_ffn_fn, self.hc_ffn_scale,
            self.ffn_norm,
            lambda h: self.ffn(h, token_ids),
        )
        return X
# =============================================================================
# MTP module (V3-style single-step)
# =============================================================================
class DeepseekV4MTPModule(nn.Module):
    """One multi-token-prediction (MTP) step.

    Pre-block:
        e = enorm(embed(input_ids)); Xn = hnorm(X)
        X = e_proj(e).unsqueeze(-2) + h_proj(Xn)   # e broadcast over the hc copies
    Block:
        hc_pre / attn / hc_post, then hc_pre / ffn / hc_post
        (pure sliding-window attention, non-hash MoE)
    Post-block:
        pre-only mHC collapse -> norm -> the ``head`` module passed by the caller

    hc_attn_* and hc_ffn_* shapes: [(2+n)*n, n*d] / [(2+n)*n] / [3]
    hc_head_* shapes:              [n, n*d] / [n] / [1]
    """

    def __init__(self, config: DeepseekV4Config):
        super().__init__()
        d = config.hidden_size
        eps = config.rms_norm_eps
        self.enorm = RMSNorm(d, eps=eps)
        self.hnorm = RMSNorm(d, eps=eps)
        self.e_proj = nn.Linear(d, d, bias=False)
        self.h_proj = nn.Linear(d, d, bias=False)

        # One transformer block.
        self.attn_norm = RMSNorm(d, eps=eps)
        self.attn = DeepseekV4Attention(config, compress_ratio=0)
        self.ffn_norm = RMSNorm(d, eps=eps)
        self.ffn = DeepseekV4MoE(config, hash_routing=False)
        self.norm = RMSNorm(d, eps=eps)

        n_hc = config.hc_mult
        mix_hc = (2 + n_hc) * n_hc
        flat = n_hc * d

        def new_hc_params(rows: int, n_gates: int):
            """Build one (fn, base, scale) triple: fn ~ N(0, initializer_range),
            base zeros, scale filled with 1e-2."""
            fn = nn.Parameter(torch.zeros(rows, flat))
            nn.init.normal_(fn, mean=0.0, std=config.initializer_range)
            base = nn.Parameter(torch.zeros(rows))
            scale = nn.Parameter(torch.full((n_gates,), 1e-2))
            return fn, base, scale

        # Full mHC for the attention and FFN sub-blocks.
        self.hc_attn_fn, self.hc_attn_base, self.hc_attn_scale = new_hc_params(mix_hc, 3)
        self.hc_ffn_fn, self.hc_ffn_base, self.hc_ffn_scale = new_hc_params(mix_hc, 3)
        # Pre-only mHC for the head collapse.
        self.hc_head_fn, self.hc_head_base, self.hc_head_scale = new_hc_params(n_hc, 1)

    @staticmethod
    def _hc_sublayer(X, mhc, base, fn, scale, norm, sublayer):
        """Run one sub-layer through mHC: collapse, normalise, apply, expand."""
        pre, post, comb = mhc.gen_params(X, base, fn, scale)
        collapsed = norm(MHC.hc_pre(X, pre))
        return MHC.hc_post(sublayer(collapsed), X, post, comb)

    def forward(self, X: torch.Tensor, embed: nn.Embedding, head: nn.Linear,
                input_ids: torch.Tensor, mhc: MHC, positions: torch.Tensor,
                rope_cos, rope_sin, rope_cos_c, rope_sin_c,
                pad_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """X: [B,S,n,d] residual stream from the main model.

        Returns ``head(...)`` applied to the collapsed, normalised stream
        (logits [B,S,V] when ``head`` is the LM head).
        """
        token_emb = self.enorm(embed(input_ids))                  # [B,S,d]
        stream = self.hnorm(X)                                    # [B,S,n,d]
        # The token embedding is broadcast across the hc copies.
        X = self.e_proj(token_emb).unsqueeze(-2) + self.h_proj(stream)

        X = self._hc_sublayer(
            X, mhc, self.hc_attn_base, self.hc_attn_fn, self.hc_attn_scale,
            self.attn_norm,
            lambda h: self.attn(h, positions, rope_cos, rope_sin,
                                rope_cos_c, rope_sin_c, pad_mask),
        )
        X = self._hc_sublayer(
            X, mhc, self.hc_ffn_base, self.hc_ffn_fn, self.hc_ffn_scale,
            self.ffn_norm,
            lambda h: self.ffn(h, input_ids),
        )

        # Head: pre-only mHC collapse, final norm, then the shared head.
        head_pre = mhc.gen_head_pre(X, self.hc_head_fn, self.hc_head_base,
                                    self.hc_head_scale)
        collapsed = self.norm(MHC.hc_pre(X, head_pre))
        return head(collapsed)
# =============================================================================
# PreTrainedModel base + top-level classes
# =============================================================================
class DeepseekV4PreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` base for the DeepSeek-V4 classes.

    Holds the class-level configuration used by the model classes in this
    file and the default weight initialisation for ``nn.Linear`` and
    ``nn.Embedding`` sub-modules.
    """

    config_class = DeepseekV4Config
    base_model_prefix = ""  # flat layout, no `model.` prefix
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV4Layer", "DeepseekV4MTPModule"]

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        Linear and embedding weights are drawn from
        ``N(0, config.initializer_range)``; linear biases are zeroed. Any
        other module type is left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            # normal_ overwrites every row, including the padding row that
            # nn.Embedding zeroed at construction; restore it so the padding
            # token keeps embedding to the zero vector.
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()
class DeepseekV4Model(DeepseekV4PreTrainedModel):
    """Base DeepSeek-V4 model using the flat (un-prefixed) parameter layout.

    It exposes the same fields as ForCausalLM so that parameter names match
    the official safetensors. ForCausalLM instantiates those fields itself
    instead of wrapping this class.
    """

    def __init__(self, config: DeepseekV4Config):
        super().__init__(config)
        hidden = config.hidden_size
        n_hc = config.hc_mult

        self.embed = nn.Embedding(
            config.vocab_size, hidden, padding_idx=config.pad_token_id
        )
        self.layers = nn.ModuleList(
            [DeepseekV4Layer(config, idx) for idx in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(hidden, eps=config.rms_norm_eps)

        # Head-side mHC collapses the [B,S,n_hc,d] residual stream back to
        # [B,S,d]. It ONLY computes the `pre` (collapse hc -> 1) weights, so
        # shapes are [hc, hc*d] / [hc] / [1] (matching official ParallelHead).
        head_fn = torch.zeros(n_hc, n_hc * hidden)
        nn.init.normal_(head_fn, mean=0.0, std=config.initializer_range)
        self.hc_head_fn = nn.Parameter(head_fn)
        self.hc_head_base = nn.Parameter(torch.zeros(n_hc))
        self.hc_head_scale = nn.Parameter(torch.full((1,), 1e-2))
        self._mhc = MHC(
            hidden,
            n_hc,
            sinkhorn_iters=config.hc_sinkhorn_iters,
            eps=config.rms_norm_eps,
        )

        # MTP modules, one per next-token-prediction layer.
        self.mtp = nn.ModuleList(
            [DeepseekV4MTPModule(config) for _ in range(config.num_nextn_predict_layers)]
        )
        self.post_init()

    def _build_rope(self, max_len: int, device, dtype):
        """Build two RoPE tables of length ``max_len``.

        Returns ``(cos, sin, cos_c, sin_c)``: the first pair uses
        ``config.rope_theta``, the second ``config.compress_rope_theta``.
        """
        cfg = self.config
        dim = cfg.qk_rope_head_dim
        plain = build_rope_cache(max_len, dim, cfg.rope_theta, device, dtype)
        compressed = build_rope_cache(
            max_len, dim, cfg.compress_rope_theta, device, dtype
        )
        return (*plain, *compressed)

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                **kwargs) -> BaseModelOutputWithPast:
        """Embed ``input_ids``, run every layer, and collapse the mHC stream.

        Args:
            input_ids: token ids, [B,S].
            attention_mask: optional mask; converted to bool and passed to
                every layer as the padding mask.
            position_ids: optional positions; only the first row is used.
            **kwargs: accepted and ignored.

        Returns:
            ``BaseModelOutputWithPast`` whose ``last_hidden_state`` is the
            normalised, head-collapsed hidden state [B,S,d].
        """
        _, seq_len = input_ids.shape
        hidden = self.embed(input_ids)                                # [B,S,d]

        # Lift into the mHC residual stream [B,S,n_hc,d].
        stream = hidden[:, :, None, :].expand(-1, -1, self.config.hc_mult, -1)
        stream = stream.contiguous()

        if position_ids is not None:
            positions = position_ids[0]
        else:
            positions = torch.arange(seq_len, device=input_ids.device)

        # Cap RoPE table at S to keep memory bounded (model still supports
        # up to max_position_embeddings).
        rope = self._build_rope(seq_len, input_ids.device, hidden.dtype)
        pad_mask = None if attention_mask is None else attention_mask.bool()

        for block in self.layers:
            stream = block(stream, self._mhc, input_ids, positions, *rope, pad_mask)

        # Head mHC: pre-only collapse hc -> 1, then final norm.
        collapse_w = self._mhc.gen_head_pre(
            stream, self.hc_head_fn, self.hc_head_base, self.hc_head_scale
        )
        collapsed = self.norm(MHC.hc_pre(stream, collapse_w))
        return BaseModelOutputWithPast(last_hidden_state=collapsed)
class DeepseekV4ForCausalLM(DeepseekV4PreTrainedModel):
    """DeepSeek-V4 causal language model with optional MTP heads.

    Uses the flat parameter layout: the backbone fields (``embed``,
    ``layers``, ``norm``, ``hc_head_*``, ``mtp``) live directly on this
    module next to the output projection ``head``.
    """

    _tied_weights_keys: List[str] = []  # untied (matches V4 config)

    def __init__(self, config: DeepseekV4Config):
        super().__init__(config)
        # Flat layout — instantiate the base model's fields directly on self
        # so safetensors keys come out as `embed.weight`, `layers.0...`, etc.
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size,
                                  padding_idx=config.pad_token_id)
        self.layers = nn.ModuleList([
            DeepseekV4Layer(config, i) for i in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Head-side mHC: ONLY computes the `pre` (collapse hc -> 1) weights,
        # so shapes are [hc, hc*d] / [hc] / [1] (matching official ParallelHead).
        n_hc = config.hc_mult
        flat = n_hc * config.hidden_size
        self.hc_head_fn = nn.Parameter(torch.zeros(n_hc, flat))
        nn.init.normal_(self.hc_head_fn, mean=0.0, std=config.initializer_range)
        self.hc_head_base = nn.Parameter(torch.zeros(n_hc))
        self.hc_head_scale = nn.Parameter(torch.full((1,), 1e-2))
        self._mhc = MHC(config.hidden_size, config.hc_mult,
                        sinkhorn_iters=config.hc_sinkhorn_iters,
                        eps=config.rms_norm_eps)
        self.mtp = nn.ModuleList([
            DeepseekV4MTPModule(config) for _ in range(config.num_nextn_predict_layers)
        ])
        self.post_init()

    # HF auto methods
    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed = value

    def get_output_embeddings(self):
        """Return the output projection (lm head)."""
        return self.head

    def set_output_embeddings(self, new):
        """Replace the output projection (lm head)."""
        self.head = new

    def _backbone(self, input_ids, attention_mask, position_ids):
        """Run embed -> hc-expand -> N layers -> head collapse.

        Returns a tuple ``(X, h_out, positions, rope_cos, rope_sin,
        rope_cos_c, rope_sin_c, pad_mask)`` where ``X`` is the post-layer
        residual stream [B,S,n_hc,d] (needed by MTP) and ``h_out`` is the
        normalised head-collapsed hidden state [B,S,d] (needed by lm_head).
        The remaining items are returned so that the MTP modules can reuse
        them.
        """
        _, S = input_ids.shape
        device = input_ids.device
        h = self.embed(input_ids)
        n_hc = self.config.hc_mult
        X = h.unsqueeze(-2).expand(-1, -1, n_hc, -1).contiguous()
        if position_ids is None:
            positions = torch.arange(S, device=device)
        else:
            # Only the first row of position_ids is used.
            positions = position_ids[0]
        rope_dim = self.config.qk_rope_head_dim
        rope_cos, rope_sin = build_rope_cache(S, rope_dim, self.config.rope_theta,
                                              device, h.dtype)
        rope_cos_c, rope_sin_c = build_rope_cache(S, rope_dim,
                                                  self.config.compress_rope_theta,
                                                  device, h.dtype)
        pad_mask = attention_mask.bool() if attention_mask is not None else None
        for layer in self.layers:
            X = layer(X, self._mhc, input_ids, positions,
                      rope_cos, rope_sin, rope_cos_c, rope_sin_c, pad_mask)
        # Head mHC: pre-only collapse hc -> 1, then final norm
        head_pre = self._mhc.gen_head_pre(X, self.hc_head_fn, self.hc_head_base,
                                          self.hc_head_scale)
        h_out = MHC.hc_pre(X, head_pre)
        h_out = self.norm(h_out)
        return X, h_out, positions, rope_cos, rope_sin, rope_cos_c, rope_sin_c, pad_mask

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                return_dict: bool = True,
                use_mtp: bool = False,
                **kwargs) -> CausalLMOutputWithPast:
        """Compute logits and, when ``labels`` is given, the training loss.

        Args:
            input_ids: token ids, [B,S].
            attention_mask: optional mask, forwarded to the backbone.
            labels: optional targets, [B,S]; entries equal to -100 are
                ignored. When given, the loss is the next-token
                cross-entropy plus 0.3 times each MTP step's cross-entropy.
            position_ids: optional positions, forwarded to the backbone.
            return_dict: accepted for interface compatibility but not
                consulted; a ``CausalLMOutputWithPast`` is always returned.
            use_mtp: accepted for interface compatibility but not consulted;
                the MTP losses are added whenever ``labels`` is given.
            **kwargs: accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` [B,S,V] and ``loss``
            (``None`` when ``labels`` is ``None``).
        """
        X, hidden, positions, rc, rs, rcc, rsc, pad_mask = self._backbone(
            input_ids, attention_mask, position_ids
        )
        logits = self.head(hidden)
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
            seq_len = input_ids.size(1)
            # MTP: each step k predicts token at offset +(k+2). Feed embedding of
            # the next token shifted by (k+1) into the MTP module along with the
            # current residual stream.
            for k, mtp in enumerate(self.mtp):
                shift = k + 1
                # Slice back to seq_len: when the sequence is shorter than the
                # shift, padding an empty slice would otherwise produce a
                # tensor wider than the residual stream.
                next_ids = F.pad(input_ids[:, shift:], (0, shift), value=0)
                next_ids = next_ids[:, :seq_len].contiguous()
                mtp_logits = mtp(X, self.embed, self.head, next_ids, self._mhc,
                                 positions, rc, rs, rcc, rsc, pad_mask)
                mtp_target = F.pad(labels[:, shift + 1:], (0, shift + 1), value=-100)
                mtp_target = mtp_target[:, :seq_len]
                mtp_loss = F.cross_entropy(
                    mtp_logits.view(-1, mtp_logits.size(-1)),
                    mtp_target.reshape(-1),
                    ignore_index=-100,
                )
                # Mean-reduced cross-entropy is NaN when every target is
                # ignored (e.g. sequence no longer than the MTP offset).
                # Contribute 0 in that case instead of poisoning the total;
                # the value is unchanged whenever a valid target exists.
                has_target = (mtp_target != -100).any()
                mtp_loss = torch.where(has_target, mtp_loss,
                                       torch.zeros_like(mtp_loss))
                loss = loss + 0.3 * mtp_loss
        return CausalLMOutputWithPast(loss=loss, logits=logits)