Text Generation
Transformers
Safetensors
deepseek_v4
deepseek-v4
mixture-of-experts
Mixture of Experts
mhc
csa
hca
scaffold
random-init
conversational
Instructions to use kshitijthakkar/deepseek-v4-mini-3B-init with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use kshitijthakkar/deepseek-v4-mini-3B-init with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="kshitijthakkar/deepseek-v4-mini-3B-init")
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("kshitijthakkar/deepseek-v4-mini-3B-init")
model = AutoModelForCausalLM.from_pretrained("kshitijthakkar/deepseek-v4-mini-3B-init")
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:]))
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use kshitijthakkar/deepseek-v4-mini-3B-init with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "kshitijthakkar/deepseek-v4-mini-3B-init"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "kshitijthakkar/deepseek-v4-mini-3B-init",
    "messages": [
      { "role": "user", "content": "What is the capital of France?" }
    ]
  }'
Use Docker
docker model run hf.co/kshitijthakkar/deepseek-v4-mini-3B-init
- SGLang
How to use kshitijthakkar/deepseek-v4-mini-3B-init with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "kshitijthakkar/deepseek-v4-mini-3B-init" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "kshitijthakkar/deepseek-v4-mini-3B-init",
    "messages": [
      { "role": "user", "content": "What is the capital of France?" }
    ]
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "kshitijthakkar/deepseek-v4-mini-3B-init" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "kshitijthakkar/deepseek-v4-mini-3B-init",
    "messages": [
      { "role": "user", "content": "What is the capital of France?" }
    ]
  }'
- Docker Model Runner
How to use kshitijthakkar/deepseek-v4-mini-3B-init with Docker Model Runner:
docker model run hf.co/kshitijthakkar/deepseek-v4-mini-3B-init
| """DeepSeek-V4 modeling code (faithful small-scale replica). | |
| Parameter naming mirrors the official DeepSeek-V4 safetensors index so that | |
| weights can later be transferred / sliced from real V4-Pro / V4-Flash | |
| checkpoints. Top-level layout (flat, no ``model.`` prefix): | |
| embed.weight | |
| layers.{i}.attn_norm.weight | |
| layers.{i}.ffn_norm.weight | |
| layers.{i}.hc_attn_{base,fn,scale} | |
| layers.{i}.hc_ffn_{base,fn,scale} | |
| layers.{i}.attn.{wq_a, wq_b, wkv, wo_a, wo_b, q_norm, kv_norm, attn_sink} | |
| layers.{i}.attn.compressor.{wkv, wgate, ape, norm} # CSA / HCA only | |
| layers.{i}.attn.indexer.{wq_b, weights_proj, compressor.*}# CSA only | |
| layers.{i}.ffn.gate.{weight, bias} # routed MoE | |
| layers.{i}.ffn.gate.tid2eid # hash MoE | |
| layers.{i}.ffn.experts.{j}.{w1, w2, w3}.weight | |
| layers.{i}.ffn.shared_experts.{w1, w2, w3}.weight | |
| norm.weight | |
| head.weight | |
| hc_head_{base, fn, scale} | |
| mtp.{k}.{...} # one per MTP step | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from typing import Optional, Tuple, List | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast | |
| from transformers.modeling_utils import PreTrainedModel | |
| from .configuration_deepseek_v4 import DeepseekV4Config | |
| # ============================================================================= | |
| # Norms, RoPE, utilities | |
| # ============================================================================= | |
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel gain.

    The statistics are computed in fp32 over the last dimension and the
    result is cast back to the dtype of the incoming tensor.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        upcast = x.float()
        # Mean of squares over the feature axis (no mean subtraction).
        mean_sq = upcast.pow(2).mean(-1, keepdim=True)
        normed = upcast * torch.rsqrt(mean_sq + self.eps)
        return (self.weight * normed).to(orig_dtype)
def fixed_rmsnorm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """RMS-normalize the last dimension of ``x`` without any learnable gain.

    Computation happens in fp32; the output is cast back to ``x.dtype``.
    """
    upcast = x.float()
    inv_rms = torch.rsqrt(upcast.pow(2).mean(-1, keepdim=True) + eps)
    return (upcast * inv_rms).to(x.dtype)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim into two chunks ``(a, b)`` and return ``(-b, a)``."""
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([torch.neg(back), front], dim=-1)
def build_rope_cache(seq_len: int, dim: int, base: float, device, dtype):
    """Precompute rotary cos/sin tables of shape ``[seq_len, dim]``.

    The per-pair frequencies are duplicated along the last axis
    (``[f, f]`` layout). For ``dim <= 0`` two empty ``[seq_len, 0]`` tensors
    are returned.
    """
    if dim <= 0:
        empty_cos = torch.zeros(seq_len, 0, device=device, dtype=dtype)
        empty_sin = torch.zeros(seq_len, 0, device=device, dtype=dtype)
        return empty_cos, empty_sin

    # Angles are built in fp32 and only cast to `dtype` at the very end.
    exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
    inv_freq = 1.0 / (base ** exponents)
    steps = torch.arange(seq_len, device=device, dtype=torch.float32)
    angles = torch.outer(steps, inv_freq)
    table = torch.cat([angles, angles], dim=-1)
    return table.cos().to(dtype), table.sin().to(dtype)
def apply_partial_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                       rope_dim: int, positions: torch.Tensor) -> torch.Tensor:
    """Rotate only the trailing ``rope_dim`` features of ``x``.

    x: [..., L, D]; cos/sin: [P, rope_dim]; positions: [L] long.
    The leading ``D - rope_dim`` features pass through untouched. When
    ``rope_dim <= 0`` the input is returned as-is.
    """
    if rope_dim <= 0:
        return x

    keep = x[..., :-rope_dim]
    rot = x[..., -rope_dim:]

    cos_sel = cos[positions]
    sin_sel = sin[positions]
    # Prepend singleton axes until the tables broadcast against `rot`.
    missing = rot.dim() - cos_sel.dim()
    for _ in range(max(missing, 0)):
        cos_sel = cos_sel.unsqueeze(0)
        sin_sel = sin_sel.unsqueeze(0)

    rotated = (rot * cos_sel) + (_rotate_half(rot) * sin_sel)
    return torch.cat([keep, rotated], dim=-1)
| # ============================================================================= | |
| # Manifold-Constrained Hyper-Connections (mHC) | |
| # ============================================================================= | |
class MHC(nn.Module):
    """Manifold-Constrained Hyper-Connections.

    This module owns no parameters itself; the ``fn`` / ``base`` / ``scale``
    tensors are passed in by the caller on every call.

    Parameter layout matches official safetensors / kernel.py exactly:
      - {prefix}_fn    [mix_hc, n*d]  dynamic generator (single combined matmul)
      - {prefix}_base  [mix_hc]       static biases
      - {prefix}_scale [3]            three scalar gates (one per pre/post/comb part)
    with mix_hc = (2 + n) * n.

    Math (matches inference/kernel.py:hc_split_sinkhorn_kernel):
        flat     = X.flatten(-2)                       # [B,S,n*d]
        rsqrt    = rsqrt(mean(flat^2) + eps)           # row-wise
        mixes    = (flat @ fn.T) * rsqrt               # [B,S, mix_hc]
        pre[i]   = sigmoid(mixes[:, i] * scale[0] + base[i]) + eps      for i in [0,n)
        post[i]  = 2 * sigmoid(mixes[:, n+i] * scale[1] + base[n+i])    for i in [0,n)
        comb_raw = mixes[:, 2n + j*n + k] * scale[2] + base[2n+j*n+k]   [n,n]
        comb     = softmax(comb_raw, dim=-1) + eps     # row softmax then +eps
        comb     = comb / (comb.sum(-2, keepdim=True) + eps)   # column normalize
        repeat (sinkhorn_iters - 1) times:
            comb = comb / (comb.sum(-1, keepdim=True) + eps)
            comb = comb / (comb.sum(-2, keepdim=True) + eps)

    Apply (matches Block.hc_pre / hc_post):
        sublayer_in = sum_i pre[i] * X[i]                       # [B,S,d]
        new_X[i]    = post[i] * F_out + sum_j comb[i,j] * X[j]  # [B,S,n,d]

    ``hc_pre`` and ``hc_post`` are static methods: they can be called either
    on the class (``MHC.hc_pre(X, pre)``) or on an instance.
    """

    def __init__(self, hidden_size: int, n_hc: int, sinkhorn_iters: int = 20,
                 eps: float = 1e-6):
        super().__init__()
        self.d = hidden_size
        self.n = n_hc
        self.iters = sinkhorn_iters
        self.eps = eps
        self.flat = n_hc * hidden_size
        self.mix_hc = (2 + n_hc) * n_hc  # = 24 for n=4

    def split_and_construct(self, mixes: torch.Tensor, base: torch.Tensor,
                            scale: torch.Tensor):
        """mixes: [..., mix_hc]; base: [mix_hc]; scale: [3].

        Returns (pre [...,n], post [...,n], comb [...,n,n]).

        All math is in fp32 (matches official ``with set_dtype(torch.float32)``
        block around hc_*_fn / base / scale params); base/scale may be stored
        in any dtype but are promoted to mixes.dtype for arithmetic.
        """
        n = self.n
        base = base.to(mixes.dtype)
        scale = scale.to(mixes.dtype)

        # Indexing: pre = first n, post = next n, comb = last n*n flattened row-major.
        pre_raw = mixes[..., :n]
        post_raw = mixes[..., n:2 * n]
        comb_raw = mixes[..., 2 * n:].reshape(*mixes.shape[:-1], n, n)
        base_pre = base[:n]
        base_post = base[n:2 * n]
        base_comb = base[2 * n:].view(n, n)

        pre = torch.sigmoid(scale[0] * pre_raw + base_pre) + self.eps
        post = 2.0 * torch.sigmoid(scale[1] * post_raw + base_post)
        comb_pre = scale[2] * comb_raw + base_comb

        # Row-softmax then +eps, then column normalize, then alternating row/col norms.
        comb = F.softmax(comb_pre, dim=-1) + self.eps
        comb = comb / (comb.sum(dim=-2, keepdim=True) + self.eps)
        for _ in range(self.iters - 1):
            comb = comb / (comb.sum(dim=-1, keepdim=True) + self.eps)
            comb = comb / (comb.sum(dim=-2, keepdim=True) + self.eps)
        return pre, post, comb

    def gen_params(self, X: torch.Tensor, base: torch.Tensor, fn: torch.Tensor,
                   scale: torch.Tensor):
        """X: [B,S,n,d]. Returns (pre [B,S,n], post [B,S,n], comb [B,S,n,n]).

        Always computed in fp32 (matches official `with set_dtype(fp32)` for mHC).
        """
        Bsz, S, n, d = X.shape
        flat = X.reshape(Bsz, S, n * d).float()
        rsqrt = torch.rsqrt(flat.square().mean(-1, keepdim=True) + self.eps)
        mixes = F.linear(flat, fn.float()) * rsqrt  # [B,S, mix_hc]
        return self.split_and_construct(mixes, base, scale)

    # The two helpers below take no `self`; without @staticmethod an instance
    # call would bind the module itself as the first tensor argument.
    @staticmethod
    def hc_pre(X: torch.Tensor, pre: torch.Tensor) -> torch.Tensor:
        """X: [B,S,n,d], pre: [B,S,n]. Returns [B,S,d]."""
        return torch.sum(pre.unsqueeze(-1).to(X.dtype) * X, dim=-2)

    @staticmethod
    def hc_post(new_x: torch.Tensor, residual: torch.Tensor,
                post: torch.Tensor, comb: torch.Tensor) -> torch.Tensor:
        """new_x: [B,S,d], residual: [B,S,n,d], post: [B,S,n], comb: [B,S,n,n].

        out[i] = post[i] * new_x + sum_j comb[i,j] * residual[j]
        """
        post_e = post.unsqueeze(-1).to(new_x.dtype)  # [B,S,n,1]
        comb_e = comb.to(residual.dtype)             # [B,S,n,n]
        return post_e * new_x.unsqueeze(-2) + torch.matmul(comb_e, residual)

    # --- Head-side mHC: only computes `pre`, no Sinkhorn. ---
    def gen_head_pre(self, X: torch.Tensor, fn: torch.Tensor,
                     base: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """fn: [n, n*d]; base: [n]; scale: [1] or scalar. Returns pre: [B,S,n]."""
        Bsz, S, n, d = X.shape
        flat = X.reshape(Bsz, S, n * d).float()
        rsqrt = torch.rsqrt(flat.square().mean(-1, keepdim=True) + self.eps)
        mixes = F.linear(flat, fn.float()) * rsqrt  # [B,S,n]
        scale = scale.float()
        base = base.float()
        # An empty `scale` tensor falls back to a unit gate.
        s = scale.view(-1)[0] if scale.numel() else 1.0
        return torch.sigmoid(s * mixes + base) + self.eps
| # ============================================================================= | |
| # Attention sink helper | |
| # ============================================================================= | |
def sink_softmax(logits: torch.Tensor, sink: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Softmax over ``dim`` whose denominator carries one extra "sink" logit.

    logits: [..., H, ..., K]; sink: [H] (broadcast).
    The caller is responsible for shaping ``sink`` so it broadcasts against
    ``logits``. The sink only contributes to the normalizer, so the returned
    probabilities may sum to less than one.
    """
    # Shift by the larger of the row max and the sink for numerical stability.
    peak = torch.maximum(logits.amax(dim=dim, keepdim=True), sink)
    numer = torch.exp(logits - peak)
    denom = numer.sum(dim=dim, keepdim=True) + torch.exp(sink - peak)
    return numer / denom
| # ============================================================================= | |
| # Compressor (token-level pooling); shared for HCA, CSA, and indexer keys | |
| # ============================================================================= | |
class Compressor(nn.Module):
    """Compresses every `m` hidden states into one entry via softmax-weighted pool.

    Matches inference/model.py:Compressor exactly:
      - When overlap (compress_ratio == 4): wkv and wgate output 2*head_dim
        (first half = overlap stream, second half = current); ape: [m, 2*head_dim].
        ``overlap_transform`` rearranges [b,nb,m,2c] -> [b,nb,2m,c] before softmax.
      - When non-overlap: outputs head_dim, ape: [m, head_dim], plain softmax pool.

    Tensor names: ``norm.weight``, ``wkv.weight``, ``wgate.weight``, ``ape``.
    """

    def __init__(self, hidden_size: int, head_dim: int, m: int, overlap: bool):
        super().__init__()
        self.m = m
        self.overlap = overlap
        self.head_dim = head_dim
        coff = 2 if overlap else 1
        self.coff = coff
        self.norm = RMSNorm(head_dim)
        self.wkv = nn.Linear(hidden_size, coff * head_dim, bias=False)
        self.wgate = nn.Linear(hidden_size, coff * head_dim, bias=False)
        self.ape = nn.Parameter(torch.zeros(m, coff * head_dim))

    # Takes no `self`, yet forward() calls it through the instance; it must be
    # a staticmethod or the instance would be bound as `t`.
    @staticmethod
    def _overlap_transform(t: torch.Tensor, head_dim: int, fill_value) -> torch.Tensor:
        """t: [b, nb, m, 2*head_dim]; returns [b, nb, 2*m, head_dim].

        First m positions of each block come from the previous block's overlap-half;
        next m positions come from the current block's current-half. Block 0 has
        no predecessor, so its first m positions keep ``fill_value``.
        """
        b, nb, m, _ = t.shape
        d = head_dim
        out = t.new_full((b, nb, 2 * m, d), fill_value)
        out[:, :, m:] = t[:, :, :, d:]        # current half
        out[:, 1:, :m] = t[:, :-1, :, :d]     # prev block's overlap half, shift +1
        return out

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """h: [B, n, D]. Returns compressed [B, ceil(n/m), head_dim]."""
        Bsz, n, _ = h.shape
        m, d = self.m, self.head_dim

        # Matmul in whatever dtype the wkv/wgate weights live in (bf16 / fp32).
        # We then upcast to fp32 for the softmax-weighted pool (numerical stability).
        param_dtype = self.wkv.weight.dtype
        xx = h.to(param_dtype)
        kv = self.wkv(xx).float()        # [B, n, coff*d]
        score = self.wgate(xx).float()   # [B, n, coff*d]

        # Pad to multiple of m
        pad = (m - n % m) % m
        if pad:
            kv = F.pad(kv, (0, 0, 0, pad))
            score = F.pad(score, (0, 0, 0, pad))
        nb = kv.size(1) // m
        kv = kv.view(Bsz, nb, m, -1)                            # [B,nb,m, coff*d]
        score = score.view(Bsz, nb, m, -1) + self.ape.float()   # bias by ape (fp32)

        if self.overlap:
            # Empty slots get kv=0 and score=-inf so they receive zero weight.
            kv = self._overlap_transform(kv, d, 0.0)                  # [B,nb, 2m, d]
            score = self._overlap_transform(score, d, float("-inf"))  # [B,nb, 2m, d]

        # Softmax over the m (or 2m) positions, weighted sum to one entry per block
        kv = (kv * score.softmax(dim=2)).sum(dim=2)             # [B,nb, d]
        return self.norm(kv.to(h.dtype))
| # ============================================================================= | |
| # Lightning Indexer | |
| # ============================================================================= | |
class LightningIndexer(nn.Module):
    """Scores compressed blocks per query and picks the top-k of them.

    Names:
        indexer.compressor.*           (separate Compressor for indexer keys)
        indexer.wq_b.weight            (q-up from shared cQ -> H_I * head_dim)
        indexer.weights_proj.weight    (per-head weight w_t,h)
    """

    def __init__(self, hidden_size: int, q_lora_rank: int,
                 index_n_heads: int, index_head_dim: int,
                 m: int, overlap: bool):
        super().__init__()
        self.n_heads = index_n_heads
        self.head_dim = index_head_dim
        self.compressor = Compressor(hidden_size, index_head_dim, m, overlap=overlap)
        self.wq_b = nn.Linear(q_lora_rank, index_n_heads * index_head_dim, bias=False)
        self.weights_proj = nn.Linear(hidden_size, index_n_heads, bias=False)
        # Score scaling: softmax_scale * 1/sqrt(n_heads), as in inference/model.py
        self.score_scale = (index_head_dim ** -0.5) * (index_n_heads ** -0.5)

    def keys(self, h: torch.Tensor) -> torch.Tensor:
        """Compress ``h`` into indexer keys of shape [B, nb, head_dim]."""
        return self.compressor(h)

    def select(self, h: torch.Tensor, cQ: torch.Tensor, K: torch.Tensor,
               positions: torch.Tensor, m: int, top_k: int):
        """Returns (idx [B,Lq,k], mask [B,Lq,k] bool).

        ``k = min(top_k, nb)``; ``mask`` is False wherever the chosen entry was
        causally masked (its score is -inf).
        """
        batch, n_query, _ = h.shape
        n_blocks = K.size(1)

        queries = self.wq_b(cQ).view(batch, n_query, self.n_heads, self.head_dim)
        head_w = self.weights_proj(h) * self.score_scale           # [B,Lq,H_I]
        per_head = F.relu(torch.einsum("blhd,bsd->blhs", queries, K))
        scores = (head_w.unsqueeze(-1) * per_head).sum(dim=2)      # [B,Lq,nb]

        # Causal: query at pos t may attend to block s if (s+1)*m - 1 < t ⇔ s < t/m
        block_ids = torch.arange(n_blocks, device=h.device)
        visible = block_ids.unsqueeze(0) < (positions.unsqueeze(-1) // m)  # [Lq, nb]
        scores = scores.masked_fill(~visible.unsqueeze(0), float("-inf"))

        k = min(top_k, n_blocks)
        if k <= 0:
            no_idx = torch.zeros(batch, n_query, 0, dtype=torch.long, device=h.device)
            return no_idx, no_idx.bool()

        best = scores.topk(k, dim=-1)
        return best.indices, torch.isfinite(best.values)
| # ============================================================================= | |
| # Attention layer (CSA / HCA / pure sliding-window) | |
| # ============================================================================= | |
class DeepseekV4Attention(nn.Module):
    """One attention layer.

    compress_ratio: 0 -> pure SW; small (>0, <16) -> CSA; large (>=16) -> HCA.

    Every mode attends over a causal sliding window of a single shared KV head.
    HCA additionally attends over all causally visible compressed blocks; CSA
    attends over the compressed blocks picked per query by the indexer.
    """

    def __init__(self, config: DeepseekV4Config, compress_ratio: int):
        super().__init__()
        self.config = config
        self.compress_ratio = compress_ratio
        d = config.hidden_size
        H = config.num_attention_heads
        c = config.head_dim
        self.H = H
        self.c = c
        self.q_lora_rank = config.q_lora_rank
        self.o_groups = config.o_groups
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if H % self.o_groups != 0:
            raise ValueError(
                f"num_attention_heads ({H}) must be divisible by "
                f"o_groups ({self.o_groups})"
            )
        self.heads_per_group = H // self.o_groups
        self.d_g = config.o_lora_rank
        self.rope_dim = config.qk_rope_head_dim
        self.window = config.sliding_window

        if compress_ratio == 0:
            self.mode = "sw"
        elif compress_ratio < 16:
            self.mode = "csa"
        else:
            self.mode = "hca"

        # Query path: low-rank with norm at q_lora; per-head rsqrt-norm applied at use time
        self.wq_a = nn.Linear(d, config.q_lora_rank, bias=False)
        self.q_norm = RMSNorm(config.q_lora_rank, eps=config.rms_norm_eps)
        self.wq_b = nn.Linear(config.q_lora_rank, H * c, bias=False)

        # Sliding-window KV (always present, single shared head — MQA)
        self.wkv = nn.Linear(d, c, bias=False)
        self.kv_norm = RMSNorm(c, eps=config.rms_norm_eps)

        # Output projection: per-group wo_a (n_groups separate sub-matrices stored in
        # one Linear; reshape weight to [n_groups, o_lora_rank, heads_per_group*c]
        # at use time and apply via einsum).
        self.wo_a = nn.Linear(self.heads_per_group * c,
                              self.o_groups * self.d_g, bias=False)
        self.wo_b = nn.Linear(self.o_groups * self.d_g, d, bias=False)

        self.attn_sink = nn.Parameter(torch.zeros(H))

        if self.mode in ("csa", "hca"):
            self.compressor = Compressor(d, c, compress_ratio, overlap=(self.mode == "csa"))
        if self.mode == "csa":
            self.indexer = LightningIndexer(d, config.q_lora_rank,
                                            config.index_n_heads, config.index_head_dim,
                                            m=compress_ratio, overlap=True)

    def _output_proj(self, attn_out: torch.Tensor) -> torch.Tensor:
        """attn_out: [B, S, H, c]. Returns [B, S, d].

        Uses per-group wo_a: weight is [n_groups*o_lora, heads_per_group*c]; we
        reshape to [n_groups, o_lora, heads_per_group*c] and apply via einsum
        so each group has its own projection (matching official inference).
        """
        B, S, H, c = attn_out.shape
        out_g = attn_out.reshape(B, S, self.o_groups, self.heads_per_group * c)
        wo_a = self.wo_a.weight.view(self.o_groups, self.d_g, self.heads_per_group * c)
        out = torch.einsum("bsgd,grd->bsgr", out_g, wo_a)  # [B,S,g,d_g]
        out = out.reshape(B, S, self.o_groups * self.d_g)
        return self.wo_b(out)

    def _apply_output_rope(self, out: torch.Tensor, rope_cos, rope_sin,
                           positions) -> torch.Tensor:
        """V4 trick: rotate output by -position so contributions carry relative pos."""
        if self.rope_dim <= 0:
            return out
        cos = rope_cos[positions]    # [S, rope_dim]
        sin = -rope_sin[positions]   # negate -> rotate by -i
        # Broadcast over batch (dim 0) and heads (dim 2) of out: [B,S,H,c].
        cos = cos.unsqueeze(0).unsqueeze(2)
        sin = sin.unsqueeze(0).unsqueeze(2)
        out_pass, out_rot = out[..., :-self.rope_dim], out[..., -self.rope_dim:]
        out_rot = (out_rot * cos) + (_rotate_half(out_rot) * sin)
        return torch.cat([out_pass, out_rot], dim=-1)

    def forward(self, x: torch.Tensor, positions: torch.Tensor,
                rope_cos: torch.Tensor, rope_sin: torch.Tensor,
                rope_cos_c: torch.Tensor, rope_sin_c: torch.Tensor,
                pad_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """x: [B,S,d]; positions: [S] long; pad_mask: [B,S] bool (True=valid) or None.

        rope_cos/rope_sin are used for queries, sliding-window keys and the
        output rotation; rope_cos_c/rope_sin_c for the compressed keys.
        Returns the projected attention output [B,S,d].
        """
        Bsz, S, _ = x.shape
        H, c, m = self.H, self.c, self.compress_ratio

        # Queries: low-rank, latent norm, then per-head no-weight RMSNorm (paper),
        # then partial RoPE on the last `rope_dim` dims.
        cQ = self.q_norm(self.wq_a(x))  # [B,S,q_lora]
        q = self.wq_b(cQ).view(Bsz, S, H, c)
        # Per-head fixed RMSNorm (no learnable weight) — see inference/model.py
        q = q * torch.rsqrt(q.float().square().mean(-1, keepdim=True) +
                            self.config.rms_norm_eps).to(q.dtype)
        q = apply_partial_rope(q.transpose(1, 2), rope_cos, rope_sin,
                               self.rope_dim, positions).transpose(1, 2)
        # q now [B,S,H,c]

        # Sliding-window KV
        kv_sw = self.kv_norm(self.wkv(x))  # [B,S,c]
        kv_sw = apply_partial_rope(kv_sw, rope_cos, rope_sin, self.rope_dim, positions)

        # Build SW causal+window mask: [S, S] then expand to [B,S,S]
        i = positions.unsqueeze(-1)
        j = positions.unsqueeze(0)
        sw_mask = (j <= i) & (j > i - self.window)  # [S,S]
        sw_mask = sw_mask.unsqueeze(0).expand(Bsz, -1, -1)
        if pad_mask is not None:
            sw_mask = sw_mask & pad_mask.unsqueeze(1)  # mask padded keys

        # ---------------- compressed branch ----------------
        if self.mode in ("csa", "hca"):
            # The compressor has its OWN internal RMSNorm on output — do not
            # re-apply self.kv_norm (that one is for the sliding-window path).
            kv_comp = self.compressor(x)  # [B,nb,c]
            nb = kv_comp.size(1)
            # Each block is positioned at its last token, clamped to the table size.
            comp_pos = (torch.arange(nb, device=x.device) * m + (m - 1)).clamp(
                max=rope_cos_c.size(0) - 1
            )
            kv_comp = apply_partial_rope(kv_comp, rope_cos_c, rope_sin_c,
                                         self.rope_dim, comp_pos)
            # Per-query causal mask over compressed blocks
            block_end = torch.arange(nb, device=x.device) * m + (m - 1)
            comp_mask = (block_end.unsqueeze(0) < positions.unsqueeze(-1))  # [S,nb]
            comp_mask = comp_mask.unsqueeze(0).expand(Bsz, -1, -1)          # [B,S,nb]
        else:
            kv_comp = None
            comp_mask = None
            nb = 0

        if self.mode == "csa":
            K_idx = self.indexer.keys(x)  # [B,nb,idx_head_dim]
            idx, sel_mask = self.indexer.select(x, cQ, K_idx, positions, m,
                                                self.config.index_topk)
            # Gather selected compressed entries for each query
            kk = idx.size(-1)
            if kk == 0:
                kv_sel = kv_comp.new_zeros(Bsz, S, 0, c)
                sel_mask = sel_mask.new_zeros(Bsz, S, 0, dtype=torch.bool)
            else:
                idx_safe = idx.clamp(min=0)
                kv_comp_exp = kv_comp.unsqueeze(1).expand(-1, S, -1, -1)  # [B,S,nb,c]
                kv_sel = torch.gather(
                    kv_comp_exp, 2,
                    idx_safe.unsqueeze(-1).expand(-1, -1, -1, c)
                )  # [B,S,kk,c]
        else:
            kv_sel = None
            sel_mask = None
            kk = 0

        # ---------------- core attention ----------------
        scale = 1.0 / math.sqrt(c)

        # Sliding-window logits: einsum over the shared-KV (single head broadcast)
        # q: [B,S,H,c], kv_sw: [B,S,c]
        sw_logits = torch.einsum("bthd,bjd->bthj", q, kv_sw) * scale  # [B,S,H,S]

        if self.mode == "sw":
            mask = sw_mask.unsqueeze(2)  # [B,S,1,S]
            logits = sw_logits.masked_fill(~mask, float("-inf"))
            sink = self.attn_sink.view(1, 1, -1, 1)
            probs = sink_softmax(logits, sink, dim=-1)
            # Contract against the shared KV directly (same as the hca/csa
            # branches) instead of expanding it to [B,S,S,c] first.
            out = torch.einsum("bthj,bjd->bthd", probs, kv_sw)  # [B,S,H,c]

        elif self.mode == "hca":
            # logits over [compressed blocks (nb)] + [SW window (S)]
            comp_logits = torch.einsum("bthd,bjd->bthj", q, kv_comp) * scale  # [B,S,H,nb]
            comp_logits = comp_logits.masked_fill(~comp_mask.unsqueeze(2), float("-inf"))
            sw_logits = sw_logits.masked_fill(~sw_mask.unsqueeze(2), float("-inf"))
            logits = torch.cat([comp_logits, sw_logits], dim=-1)  # [B,S,H,nb+S]
            sink = self.attn_sink.view(1, 1, -1, 1)
            probs = sink_softmax(logits, sink, dim=-1)
            p_comp, p_sw = probs.split([nb, S], dim=-1)
            out = (
                torch.einsum("bthj,bjd->bthd", p_comp, kv_comp) +
                torch.einsum("bthj,bjd->bthd", p_sw, kv_sw)
            )  # [B,S,H,c]

        else:  # csa
            # logits over [selected (kk)] + [SW window (S)]
            sel_logits = torch.einsum("bthd,btjd->bthj", q, kv_sel) * scale  # [B,S,H,kk]
            if kk > 0:
                sel_logits = sel_logits.masked_fill(~sel_mask.unsqueeze(2), float("-inf"))
            sw_logits = sw_logits.masked_fill(~sw_mask.unsqueeze(2), float("-inf"))
            logits = torch.cat([sel_logits, sw_logits], dim=-1)  # [B,S,H,kk+S]
            sink = self.attn_sink.view(1, 1, -1, 1)
            probs = sink_softmax(logits, sink, dim=-1)
            p_sel, p_sw = probs.split([kk, S], dim=-1)
            out = (
                (torch.einsum("bthj,btjd->bthd", p_sel, kv_sel) if kk > 0 else 0) +
                torch.einsum("bthj,bjd->bthd", p_sw, kv_sw)
            )

        # Output RoPE-by-(-i) trick
        out = self._apply_output_rope(out, rope_cos, rope_sin, positions)
        return self._output_proj(out)
| # ============================================================================= | |
| # Clamped SwiGLU expert | |
| # ============================================================================= | |
class SwiGLUExpert(nn.Module):
    """Clamped SwiGLU feed-forward block.

    w1 = gate, w3 = up, w2 = down (matches official naming).
    """

    def __init__(self, hidden_size: int, intermediate_size: int, limit: float):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.limit = limit

    def forward(self, x):
        gate = self.w1(x)
        up = self.w3(x)
        # V4 SwiGLU clamping: the linear (up) branch is clipped to
        # [-limit, limit]; the gate branch is only capped from above.
        up = torch.clamp(up, min=-self.limit, max=self.limit)
        cap = torch.full_like(gate, self.limit)
        gate = torch.minimum(gate, cap)
        hidden = F.silu(gate) * up
        return self.w2(hidden)
| # ============================================================================= | |
| # DeepseekMoE (sqrt-softplus routing, aux-loss-free) + Hash variant | |
| # ============================================================================= | |
class MoEGate(nn.Module):
    """Parameter container for the MoE router (mirrors ``inference/model.py:Gate``).

    Always present:
      - ``weight`` [n_routed_experts, hidden_size]: produces routing scores
        (sqrt(softplus) by default in V4) for BOTH hash and non-hash layers.
        For hash layers the score still defines per-token expert weights;
        only the *index selection* uses the hash table.

    Conditional:
      - ``bias`` [n_routed_experts] (non-hash only): aux-loss-free routing
        bias added to scores at top-k selection time. Stored as a learnable
        float32 parameter to match the official layout. It is ``None`` on
        hash-routed gates.
      - ``tid2eid`` [vocab_size, top_k] (hash only): non-trainable lookup
        table mapping token-id -> expert indices.
    """

    def __init__(self, hidden_size: int, num_experts: int, vocab_size: int,
                 hash_routing: bool, top_k: int):
        super().__init__()
        self.hash_routing = hash_routing
        # The score matrix exists for every gate, hash-routed or not; hash
        # layers only swap out how the expert indices are chosen.
        self.weight = nn.Parameter(torch.zeros(num_experts, hidden_size))

        if not hash_routing:
            self.bias = nn.Parameter(torch.zeros(num_experts, dtype=torch.float32))
            return

        # Integer lookup table; frozen (requires_grad=False) as in the official code.
        lookup = torch.zeros(vocab_size, top_k, dtype=torch.long)
        self.tid2eid = nn.Parameter(lookup, requires_grad=False)
        self.bias = None
class DeepseekV4MoE(nn.Module):
    """Mixture-of-experts feed-forward layer with an optional shared expert.

    Each token is dispatched to ``top_k`` routed experts, chosen either by
    biased top-k over the gate scores or by a token-id hash table. The
    weighted expert outputs are summed, and the shared expert's output (when
    configured) is added on top.
    """

    def __init__(self, config: DeepseekV4Config, hash_routing: bool):
        super().__init__()
        self.config = config
        self.hash_routing = hash_routing
        self.num_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling = config.routed_scaling_factor

        hidden = config.hidden_size
        expert_width = config.moe_intermediate_size
        clamp_limit = config.swiglu_limit

        self.gate = MoEGate(hidden, self.num_experts, config.vocab_size,
                            hash_routing=hash_routing, top_k=self.top_k)
        self.experts = nn.ModuleList(
            [SwiGLUExpert(hidden, expert_width, clamp_limit)
             for _ in range(self.num_experts)]
        )
        if config.n_shared_experts > 0:
            # A single wider expert stands in for n_shared_experts shared ones.
            self.shared_experts = SwiGLUExpert(
                hidden, expert_width * config.n_shared_experts, clamp_limit
            )
        else:
            self.shared_experts = None

    def _routed_indices(self, x_flat: torch.Tensor, token_ids_flat: torch.Tensor):
        """Pick experts and mixing weights for every token.

        Matches inference/model.py:Gate exactly. Hash layers still derive
        weights from the learned gate (only the index selection differs).
        Returns (idx [N,K], weights [N,K] in ``x_flat.dtype``).
        """
        # Score in fp32 for stability, matches official.
        logits = F.linear(x_flat.float(), self.gate.weight.float())  # [N, E]

        scoring = self.config.scoring_func
        if scoring == "softmax":
            scores = logits.softmax(dim=-1)
        elif scoring == "sigmoid":
            scores = torch.sigmoid(logits)
        else:  # sqrtsoftplus (V4 default)
            scores = F.softplus(logits).sqrt()

        if self.hash_routing:
            idx = self.gate.tid2eid[token_ids_flat].long()  # [N, K]
        else:
            # The bias steers selection only; the mixing weights below are
            # gathered from the unbiased scores.
            idx = (scores + self.gate.bias.float()).topk(self.top_k, dim=-1).indices

        weights = scores.gather(-1, idx)
        if scoring != "softmax":
            weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-9)
        weights = weights * self.routed_scaling
        return idx, weights.to(x_flat.dtype)

    def forward(self, x: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
        """x: [B,S,D]; token_ids: B*S ids in total. Returns [B,S,D]."""
        batch, seq, dim = x.shape
        n_tok = batch * seq
        tokens = x.reshape(n_tok, dim)

        expert_ids, expert_w = self._routed_indices(tokens, token_ids.reshape(n_tok))  # [N,K], [N,K]

        # Flatten the (token, slot) grid so each entry is one token->expert assignment.
        slot_expert = expert_ids.reshape(-1)
        slot_weight = expert_w.reshape(-1)
        slot_token = (
            torch.arange(n_tok, device=x.device)
            .unsqueeze(-1)
            .expand(-1, self.top_k)
            .reshape(-1)
        )

        routed = torch.zeros_like(tokens)
        for expert_id, expert in enumerate(self.experts):
            hit = slot_expert == expert_id
            if hit.any():
                rows = slot_token[hit]
                contrib = expert(tokens[rows]) * slot_weight[hit].unsqueeze(-1)
                routed.index_add_(0, rows, contrib)

        if self.shared_experts is not None:
            routed = routed + self.shared_experts(tokens)
        return routed.reshape(batch, seq, dim)
| # ============================================================================= | |
| # Decoder layer | |
| # ============================================================================= | |
class DeepseekV4Layer(nn.Module):
    """One decoder layer working on the mHC residual stream.

    The layer holds an attention sub-block and an MoE feed-forward sub-block.
    Each sub-block owns three mHC parameters (``hc_*_fn``, ``hc_*_base``,
    ``hc_*_scale``) which are passed to the shared ``MHC`` helper at call
    time to produce the ``pre`` / ``post`` / ``comb`` mixing weights.
    """

    def __init__(self, config: DeepseekV4Config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        hidden = config.hidden_size
        norm_eps = config.rms_norm_eps
        # Layers with index below `num_hash_layers` use hash routing in the MoE.
        uses_hash_routing = layer_idx < config.num_hash_layers

        self.attn_norm = RMSNorm(hidden, eps=norm_eps)
        self.attn = DeepseekV4Attention(config, config.compress_ratios[layer_idx])
        self.ffn_norm = RMSNorm(hidden, eps=norm_eps)
        self.ffn = DeepseekV4MoE(config, hash_routing=uses_hash_routing)

        # mHC parameters. `_fn` has (2 + n) * n output rows over the flattened
        # n * hidden residual; `_base` has one entry per output row; `_scale`
        # holds 3 scalars. `_fn` gets small Gaussian noise, `_base` starts at
        # zero and `_scale` starts at 1e-2.
        copies = config.hc_mult
        n_mix = (2 + copies) * copies
        width = copies * hidden
        self.hc_attn_fn = nn.Parameter(torch.zeros(n_mix, width))
        self.hc_ffn_fn = nn.Parameter(torch.zeros(n_mix, width))
        for proj in (self.hc_attn_fn, self.hc_ffn_fn):
            nn.init.normal_(proj, mean=0.0, std=config.initializer_range)
        self.hc_attn_base = nn.Parameter(torch.zeros(n_mix))
        self.hc_ffn_base = nn.Parameter(torch.zeros(n_mix))
        self.hc_attn_scale = nn.Parameter(torch.full((3,), 1e-2))
        self.hc_ffn_scale = nn.Parameter(torch.full((3,), 1e-2))

    @staticmethod
    def _hc_sub_block(X, mhc, base, fn, scale, norm, inner):
        """Run one sub-block: hc_pre -> norm -> ``inner`` -> hc_post.

        ``hc_pre`` collapses [B,S,n,d] -> [B,S,d] using ``pre``; ``hc_post``
        rebuilds [B,S,n,d] from the sub-block output and the incoming
        residual ``X`` using ``post`` and ``comb``.
        """
        pre, post, comb = mhc.gen_params(X, base, fn, scale)
        collapsed = norm(MHC.hc_pre(X, pre))
        return MHC.hc_post(inner(collapsed), X, post, comb)

    def forward(self, X: torch.Tensor, mhc: MHC, token_ids: torch.Tensor,
                positions: torch.Tensor,
                rope_cos, rope_sin, rope_cos_c, rope_sin_c,
                pad_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Apply the attention sub-block, then the FFN sub-block, to ``X``.

        Args:
            X: mHC residual stream, ``[B, S, n, d]``.
            mhc: Shared helper that generates the mixing weights.
            token_ids: Token ids, forwarded to the MoE router.
            positions: Position indices forwarded to attention.
            rope_cos, rope_sin, rope_cos_c, rope_sin_c: RoPE tables forwarded
                to attention unchanged.
            pad_mask: Optional padding mask forwarded to attention.

        Returns:
            The updated residual stream.
        """
        # Attention sub-block
        X = self._hc_sub_block(
            X, mhc, self.hc_attn_base, self.hc_attn_fn, self.hc_attn_scale,
            self.attn_norm,
            lambda h: self.attn(h, positions, rope_cos, rope_sin,
                                rope_cos_c, rope_sin_c, pad_mask),
        )
        # FFN sub-block
        X = self._hc_sub_block(
            X, mhc, self.hc_ffn_base, self.hc_ffn_fn, self.hc_ffn_scale,
            self.ffn_norm,
            lambda h: self.ffn(h, token_ids),
        )
        return X
| # ============================================================================= | |
| # MTP module (V3-style single-step) | |
| # ============================================================================= | |
class DeepseekV4MTPModule(nn.Module):
    """A single multi-token-prediction (MTP) step.

    Pipeline:
      1. Pre-block: embed and normalise the shifted ``input_ids``, normalise
         the incoming residual stream, then merge them as
         ``e_proj(e).unsqueeze(-2) + h_proj(X)`` so the embedding is broadcast
         over the hc copies.
      2. Block: hc_pre / attn / hc_post followed by hc_pre / ffn / hc_post.
      3. Post-block: pre-only mHC head collapse, final norm, then the
         ``head`` projection supplied by the caller.

    Parameter shapes (n = ``hc_mult``, d = ``hidden_size``):
      hc_attn_* and hc_ffn_*: [(2+n)*n, n*d] / [(2+n)*n] / [3]
      hc_head_*:              [n, n*d]       / [n]       / [1]
    """

    def __init__(self, config: DeepseekV4Config):
        super().__init__()
        d = config.hidden_size
        norm_eps = config.rms_norm_eps

        # Pre-block projections
        self.enorm = RMSNorm(d, eps=norm_eps)
        self.hnorm = RMSNorm(d, eps=norm_eps)
        self.e_proj = nn.Linear(d, d, bias=False)
        self.h_proj = nn.Linear(d, d, bias=False)

        # One transformer block: attention with compress_ratio=0 and a
        # non-hash-routed MoE.
        self.attn_norm = RMSNorm(d, eps=norm_eps)
        self.attn = DeepseekV4Attention(config, compress_ratio=0)
        self.ffn_norm = RMSNorm(d, eps=norm_eps)
        self.ffn = DeepseekV4MoE(config, hash_routing=False)
        self.norm = RMSNorm(d, eps=norm_eps)

        n_hc = config.hc_mult
        mix_hc = (2 + n_hc) * n_hc
        flat = n_hc * d

        def noisy_projection(rows: int) -> nn.Parameter:
            # [rows, n*d] parameter filled with N(0, initializer_range) noise.
            weight = nn.Parameter(torch.zeros(rows, flat))
            nn.init.normal_(weight, mean=0.0, std=config.initializer_range)
            return weight

        # Full mHC for the attention sub-block
        self.hc_attn_fn = noisy_projection(mix_hc)
        self.hc_attn_base = nn.Parameter(torch.zeros(mix_hc))
        self.hc_attn_scale = nn.Parameter(torch.full((3,), 1e-2))
        # Full mHC for the FFN sub-block
        self.hc_ffn_fn = noisy_projection(mix_hc)
        self.hc_ffn_base = nn.Parameter(torch.zeros(mix_hc))
        self.hc_ffn_scale = nn.Parameter(torch.full((3,), 1e-2))
        # Pre-only mHC for the head collapse
        self.hc_head_fn = noisy_projection(n_hc)
        self.hc_head_base = nn.Parameter(torch.zeros(n_hc))
        self.hc_head_scale = nn.Parameter(torch.full((1,), 1e-2))

    def forward(self, X: torch.Tensor, embed: nn.Embedding, head: nn.Linear,
                input_ids: torch.Tensor, mhc: MHC, positions: torch.Tensor,
                rope_cos, rope_sin, rope_cos_c, rope_sin_c,
                pad_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Run one MTP step.

        Args:
            X: Residual stream from the main model, ``[B, S, n, d]``.
            embed: Embedding module used to look up ``input_ids``.
            head: Output projection applied to the final hidden state.
            input_ids: Token ids to embed; also passed to the MoE.
            mhc: Shared helper that generates the mixing weights.
            positions, rope_cos, rope_sin, rope_cos_c, rope_sin_c, pad_mask:
                Forwarded to attention unchanged.

        Returns:
            Logits of shape ``[B, S, V]``.
        """
        token_emb = self.enorm(embed(input_ids))   # [B,S,d]
        stream = self.hnorm(X)                     # [B,S,n,d]
        # Broadcast the token embedding across the hc copies and merge.
        X = self.e_proj(token_emb).unsqueeze(-2) + self.h_proj(stream)  # [B,S,n,d]

        # Attention sub-block via full mHC
        skip = X
        pre, post, comb = mhc.gen_params(X, self.hc_attn_base, self.hc_attn_fn,
                                         self.hc_attn_scale)
        attn_in = self.attn_norm(MHC.hc_pre(X, pre))
        attn_out = self.attn(attn_in, positions, rope_cos, rope_sin,
                             rope_cos_c, rope_sin_c, pad_mask)
        X = MHC.hc_post(attn_out, skip, post, comb)

        # FFN sub-block via full mHC
        skip = X
        pre, post, comb = mhc.gen_params(X, self.hc_ffn_base, self.hc_ffn_fn,
                                         self.hc_ffn_scale)
        ffn_in = self.ffn_norm(MHC.hc_pre(X, pre))
        X = MHC.hc_post(self.ffn(ffn_in, input_ids), skip, post, comb)

        # Head: pre-only mHC collapse, then norm, then the supplied head
        head_pre = mhc.gen_head_pre(X, self.hc_head_fn, self.hc_head_base,
                                    self.hc_head_scale)
        collapsed = self.norm(MHC.hc_pre(X, head_pre))
        return head(collapsed)
| # ============================================================================= | |
| # PreTrainedModel base + top-level classes | |
| # ============================================================================= | |
class DeepseekV4PreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` base for the DeepSeek-V4 model classes.

    Declares the config class and the checkpoint layout, and provides the
    per-module weight initialiser.
    """

    config_class = DeepseekV4Config
    base_model_prefix = ""  # flat layout, no `model.` prefix
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV4Layer", "DeepseekV4MTPModule"]

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
        ``N(0, config.initializer_range)``; linear biases are zeroed. Any
        other module type is left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            # nn.Embedding zeroes its padding row at construction time; the
            # normal_ call above overwrites it, so restore the zeros here.
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()
class DeepseekV4Model(DeepseekV4PreTrainedModel):
    """Base model without an LM head.

    It uses the same flat attribute layout as ``DeepseekV4ForCausalLM``
    (``embed``, ``layers``, ``norm``, ``hc_head_*``, ``mtp``) so that
    parameter names line up with the checkpoint keys. The causal-LM class
    builds these fields itself rather than wrapping this class.
    """

    def __init__(self, config: DeepseekV4Config):
        super().__init__(config)
        hidden = config.hidden_size
        copies = config.hc_mult

        self.embed = nn.Embedding(config.vocab_size, hidden,
                                  padding_idx=config.pad_token_id)
        self.layers = nn.ModuleList(
            [DeepseekV4Layer(config, idx) for idx in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(hidden, eps=config.rms_norm_eps)

        # Head-side mHC only needs the `pre` weights that collapse the
        # [B,S,n_hc,d] residual stream back to [B,S,d], hence the shapes
        # [hc, hc*d] / [hc] / [1].
        self.hc_head_fn = nn.Parameter(torch.zeros(copies, copies * hidden))
        nn.init.normal_(self.hc_head_fn, mean=0.0, std=config.initializer_range)
        self.hc_head_base = nn.Parameter(torch.zeros(copies))
        self.hc_head_scale = nn.Parameter(torch.full((1,), 1e-2))

        self._mhc = MHC(hidden, copies,
                        sinkhorn_iters=config.hc_sinkhorn_iters,
                        eps=config.rms_norm_eps)

        # MTP modules
        self.mtp = nn.ModuleList(
            [DeepseekV4MTPModule(config) for _ in range(config.num_nextn_predict_layers)]
        )
        self.post_init()

    def _build_rope(self, max_len: int, device, dtype):
        """Return ``(cos, sin, cos_c, sin_c)`` RoPE tables of length ``max_len``.

        The first pair uses ``config.rope_theta``; the second pair uses
        ``config.compress_rope_theta``.
        """
        cfg = self.config
        dim = cfg.qk_rope_head_dim
        cos, sin = build_rope_cache(max_len, dim, cfg.rope_theta, device, dtype)
        cos_c, sin_c = build_rope_cache(max_len, dim, cfg.compress_rope_theta,
                                        device, dtype)
        return cos, sin, cos_c, sin_c

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                **kwargs) -> BaseModelOutputWithPast:
        """Embed ``input_ids``, run every layer and collapse the residual stream.

        Args:
            input_ids: Token ids of shape ``[B, S]``.
            attention_mask: Optional mask; converted to bool and passed to
                every layer as the padding mask.
            position_ids: Optional positions. Only the first row is used;
                defaults to ``arange(S)``.
            **kwargs: Accepted and ignored.

        Returns:
            ``BaseModelOutputWithPast`` whose ``last_hidden_state`` is the
            normalised, head-collapsed hidden state ``[B, S, d]``.
        """
        _, seq_len = input_ids.shape
        device = input_ids.device

        token_emb = self.embed(input_ids)  # [B,S,d]
        # Lift into the mHC residual stream [B,S,n_hc,d]
        X = token_emb.unsqueeze(-2).expand(-1, -1, self.config.hc_mult, -1).contiguous()

        positions = (torch.arange(seq_len, device=device)
                     if position_ids is None else position_ids[0])

        # RoPE tables are only built up to the current sequence length.
        rope = self._build_rope(seq_len, device, token_emb.dtype)
        pad_mask = None if attention_mask is None else attention_mask.bool()

        for layer in self.layers:
            X = layer(X, self._mhc, input_ids, positions, *rope, pad_mask)

        # Head mHC: pre-only collapse hc -> 1, then final norm
        head_pre = self._mhc.gen_head_pre(X, self.hc_head_fn, self.hc_head_base,
                                          self.hc_head_scale)
        collapsed = self.norm(MHC.hc_pre(X, head_pre))
        return BaseModelOutputWithPast(last_hidden_state=collapsed)
class DeepseekV4ForCausalLM(DeepseekV4PreTrainedModel):
    """Causal language model: decoder stack, untied LM head and MTP modules.

    All sub-modules live directly on this object (flat layout) so that the
    state-dict keys come out as ``embed.weight``, ``layers.0...``, etc.
    """

    _tied_weights_keys: List[str] = []  # untied (matches V4 config)

    # Weight of each MTP auxiliary loss relative to the main LM loss.
    _MTP_LOSS_WEIGHT = 0.3
    # Label value excluded from every loss term.
    _IGNORE_INDEX = -100

    def __init__(self, config: DeepseekV4Config):
        super().__init__(config)
        # Flat layout — instantiate the base model's fields directly on self
        # so safetensors keys come out as `embed.weight`, `layers.0...`, etc.
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size,
                                  padding_idx=config.pad_token_id)
        self.layers = nn.ModuleList([
            DeepseekV4Layer(config, i) for i in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Head-side mHC: ONLY computes the `pre` (collapse hc -> 1) weights,
        # so shapes are [hc, hc*d] / [hc] / [1].
        n_hc = config.hc_mult
        flat = n_hc * config.hidden_size
        self.hc_head_fn = nn.Parameter(torch.zeros(n_hc, flat))
        nn.init.normal_(self.hc_head_fn, mean=0.0, std=config.initializer_range)
        self.hc_head_base = nn.Parameter(torch.zeros(n_hc))
        self.hc_head_scale = nn.Parameter(torch.full((1,), 1e-2))
        self._mhc = MHC(config.hidden_size, config.hc_mult,
                        sinkhorn_iters=config.hc_sinkhorn_iters,
                        eps=config.rms_norm_eps)
        self.mtp = nn.ModuleList([
            DeepseekV4MTPModule(config) for _ in range(config.num_nextn_predict_layers)
        ])
        self.post_init()

    # HF auto methods
    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.head

    def set_output_embeddings(self, new):
        """Replace the LM head."""
        self.head = new

    @staticmethod
    def _shift_left(tokens: torch.Tensor, offset: int, fill_value: int) -> torch.Tensor:
        """Shift ``tokens`` left by ``offset`` along dim 1, padding on the right.

        The result always has the same shape as ``tokens``, even when
        ``offset`` is at least the sequence length (then it is all
        ``fill_value``).
        """
        shifted = torch.full_like(tokens, fill_value)
        keep = tokens.size(1) - offset
        if keep > 0:
            shifted[:, :keep] = tokens[:, offset:]
        return shifted

    @classmethod
    def _token_loss(cls, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Mean cross-entropy over the targets that are not ``_IGNORE_INDEX``.

        Computed as sum / max(count, 1) so that a batch with no valid target
        yields 0 instead of the NaN a plain mean reduction would give.
        """
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_targets = targets.reshape(-1).to(flat_logits.device)
        total = F.cross_entropy(flat_logits, flat_targets,
                                ignore_index=cls._IGNORE_INDEX, reduction="sum")
        n_valid = (flat_targets != cls._IGNORE_INDEX).sum().clamp(min=1)
        return total / n_valid

    def _backbone(self, input_ids, attention_mask, position_ids):
        """Run embed -> hc-expand -> all layers -> head collapse.

        Returns:
            ``(X, h_out, positions, rope_cos, rope_sin, rope_cos_c,
            rope_sin_c, pad_mask)`` where ``X`` is the post-layer residual
            stream ``[B,S,n_hc,d]`` (needed by MTP) and ``h_out`` is the
            normalised, head-collapsed hidden state ``[B,S,d]`` (needed by
            the LM head). The remaining items are returned so the MTP
            modules can reuse them.
        """
        _, S = input_ids.shape
        device = input_ids.device
        h = self.embed(input_ids)
        n_hc = self.config.hc_mult
        X = h.unsqueeze(-2).expand(-1, -1, n_hc, -1).contiguous()
        if position_ids is None:
            positions = torch.arange(S, device=device)
        else:
            # Only the first row of position_ids is used.
            positions = position_ids[0]
        rope_dim = self.config.qk_rope_head_dim
        rope_cos, rope_sin = build_rope_cache(S, rope_dim, self.config.rope_theta,
                                              device, h.dtype)
        rope_cos_c, rope_sin_c = build_rope_cache(S, rope_dim,
                                                  self.config.compress_rope_theta,
                                                  device, h.dtype)
        pad_mask = attention_mask.bool() if attention_mask is not None else None
        for layer in self.layers:
            X = layer(X, self._mhc, input_ids, positions,
                      rope_cos, rope_sin, rope_cos_c, rope_sin_c, pad_mask)
        # Head mHC: pre-only collapse hc -> 1, then final norm
        head_pre = self._mhc.gen_head_pre(X, self.hc_head_fn, self.hc_head_base,
                                          self.hc_head_scale)
        h_out = MHC.hc_pre(X, head_pre)
        h_out = self.norm(h_out)
        return X, h_out, positions, rope_cos, rope_sin, rope_cos_c, rope_sin_c, pad_mask

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                return_dict: bool = True,
                use_mtp: bool = False,
                **kwargs) -> CausalLMOutputWithPast:
        """Compute next-token logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids of shape ``[B, S]``.
            attention_mask: Optional padding mask.
            labels: Optional targets aligned with ``input_ids``; entries
                equal to -100 are ignored.
            position_ids: Optional positions (only the first row is used).
            return_dict: Accepted for API compatibility; not consulted. A
                ``CausalLMOutputWithPast`` is always returned.
            use_mtp: Accepted for API compatibility; not consulted. The MTP
                auxiliary losses are always added when ``labels`` is given.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` ``[B, S, V]`` and
            ``loss`` (``None`` when no labels were supplied).
        """
        X, hidden, positions, rc, rs, rcc, rsc, pad_mask = self._backbone(
            input_ids, attention_mask, position_ids
        )
        logits = self.head(hidden)

        loss = None
        if labels is not None:
            # Main LM loss: position t predicts token t + 1.
            loss = self._token_loss(logits[:, :-1, :], labels[:, 1:])

            # MTP: step k predicts the token at offset +(k+2). It is fed the
            # embedding of the token at offset +(k+1) together with the
            # main model's residual stream.
            for k, mtp in enumerate(self.mtp):
                shift = k + 1
                next_ids = self._shift_left(input_ids, shift, 0)
                mtp_logits = mtp(X, self.embed, self.head, next_ids, self._mhc,
                                 positions, rc, rs, rcc, rsc, pad_mask)
                mtp_target = self._shift_left(labels, shift + 1, self._IGNORE_INDEX)
                loss = loss + self._MTP_LOSS_WEIGHT * self._token_loss(mtp_logits, mtp_target)

        return CausalLMOutputWithPast(loss=loss, logits=logits)