# nanowhale-100m / modeling_deepseek_v4.py
# Last commit 8fecf92 by cmpatino (HF Staff): "Fix import: try relative first, fall back to absolute"
"""DeepSeek-V4 model implementation for HuggingFace Transformers.
Ported from deepseek-ai/DeepSeek-V4-Pro inference/model.py to be compatible
with HF Trainer, SFTTrainer, and AutoModelForCausalLM.
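Key V4 architecture features and their status in this port:
- Hyper-Connections (HC): multi-copy hidden states mixed through a
  Sinkhorn-normalised combination matrix
- Attention with low-rank queries, a single shared KV head and partial RoPE.
  The upstream Compressed Sparse Attention (CSA) / sliding window is NOT
  implemented here; attention is full causal.
- MoE with sqrtsoftplus (or softmax / sigmoid) scoring. Hash-based routing is
  NOT implemented: "hash" layers only drop the gate bias and still use top-k.
- Grouped low-rank output projection (o_groups + o_lora_rank)
- Multi-Token Prediction (MTP) layers are not defined in this file (upstream
  describes them as disabled for small models)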
Custom kernels (tilelang) are NOT required — all ops are pure PyTorch.
For training from scratch in bf16, this is sufficient and simpler.
"""
import math
from functools import lru_cache
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel

try:
    from .configuration_deepseek_v4 import DeepseekV4Config
except ImportError:
    from configuration_deepseek_v4 import DeepseekV4Config
# ---------------------------------------------------------------------------
# Utility functions
# ---------------------------------------------------------------------------
class DeepseekV4RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-channel gain.

    The statistics and the scaling are computed in float32; the result is
    cast back to the dtype the input arrived in.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        x32 = x.float()
        mean_sq = x32.square().mean(dim=-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_sq + self.eps)
        return (self.weight * normed).to(in_dtype)
def precompute_freqs_cis(dim, seqlen, base=10000.0):
    """Build the rotary-embedding cos/sin table using only real arithmetic.

    Angle for position ``p`` and channel pair ``i`` is
    ``p / base ** (2i / dim)``.

    Returns:
        float32 tensor ``[2, seqlen, dim // 2]`` (for even ``dim``); index 0
        holds the cosines, index 1 the sines.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seqlen, dtype=torch.float32)
    angles = positions[:, None] * inv_freq[None, :]  # [S, D//2]
    return torch.stack((angles.cos(), angles.sin()))
def apply_rotary_emb(x: torch.Tensor, cos_sin: torch.Tensor) -> torch.Tensor:
    """Rotate channel pairs of ``x`` by precomputed RoPE angles (no complex ops).

    Channel ``i`` of the first half of the last dimension is paired with
    channel ``i`` of the second half, and each pair is mapped as
    ``(x1, x2) -> (x1*cos + x2*sin, x2*cos - x1*sin)``.

    Args:
        x: ``[..., D]`` with ``D`` even.
        cos_sin: ``[2, S, D//2]`` table; index 0 cosines, index 1 sines.
            Leading singleton dims are prepended until it has as many dims as
            ``x`` so it broadcasts over e.g. ``[B, H, S, D//2]``.

    Returns:
        Tensor shaped like ``x`` and cast back to ``x.dtype``.
    """
    half = x.shape[-1] // 2
    lo, hi = x[..., :half], x[..., half:]
    cos, sin = cos_sin[0], cos_sin[1]
    missing = lo.ndim - cos.ndim
    if missing > 0:
        lead = (1,) * missing
        cos = cos.reshape(lead + tuple(cos.shape))
        sin = sin.reshape(lead + tuple(sin.shape))
    rotated_lo = lo * cos + hi * sin
    rotated_hi = hi * cos - lo * sin
    return torch.cat((rotated_lo, rotated_hi), dim=-1).to(x.dtype)
# ---------------------------------------------------------------------------
# Hyper-Connections (HC)
# ---------------------------------------------------------------------------
def hc_split_sinkhorn(mixes, hc_scale, hc_base, hc_mult=4, sinkhorn_iters=20, eps=1e-6):
    """Split raw HC logits into pre/post gates and a Sinkhorn-balanced mixing matrix.

    The last axis of ``mixes`` (and of ``hc_base``) is laid out as
    ``[pre (hc_mult) | post (hc_mult) | comb (hc_mult * hc_mult)]``.

    Args:
        mixes: ``[..., (2 + hc_mult) * hc_mult]`` projection outputs.
        hc_scale: 3 scalars scaling the pre, post and comb logits respectively.
        hc_base: ``[(2 + hc_mult) * hc_mult]`` biases in the same layout.
        hc_mult: number of HC copies.
        sinkhorn_iters: normalisation rounds. The first round is a single
            column (dim -2) normalisation after a row softmax. Each further
            round normalises rows and then columns.
        eps: added to ``pre``, to the softmax output and to every
            normalising denominator.

    Returns:
        pre: ``[..., hc_mult]``, ``sigmoid(.) + eps``.
        post: ``[..., hc_mult]``, ``2 * sigmoid(.)``.
        comb: ``[..., hc_mult, hc_mult]``, columns normalised last.
    """
    n = hc_mult
    lead_shape = mixes.shape[:-1]

    def _normalize(t, dim):
        # Divide by the (eps-guarded) sum along `dim`.
        return t / (t.sum(dim=dim, keepdim=True) + eps)

    pre = torch.sigmoid(mixes[..., :n] * hc_scale[0] + hc_base[:n]) + eps
    post = torch.sigmoid(mixes[..., n:2 * n] * hc_scale[1] + hc_base[n:2 * n]) * 2

    comb_logits = mixes[..., 2 * n:].reshape(*lead_shape, n, n) * hc_scale[2]
    comb_logits = comb_logits + hc_base[2 * n:].reshape(n, n)

    comb = _normalize(F.softmax(comb_logits, dim=-1) + eps, -2)
    for _ in range(sinkhorn_iters - 1):
        comb = _normalize(comb, -1)
        comb = _normalize(comb, -2)
    return pre, post, comb
# ---------------------------------------------------------------------------
# Attention
# ---------------------------------------------------------------------------
class DeepseekV4Attention(nn.Module):
    """Multi-head attention with low-rank queries and a single shared KV head.

    Components:
    - Low-rank Q projection (wq_a -> q_norm -> wq_b), followed by a
      parameter-free RMS normalisation of every query head.
    - Direct KV projection (wkv -> kv_norm) to ONE head of size ``head_dim``.
      That head is used as both key and value for all query heads (no
      kv_lora_rank).
    - RoPE on the last ``qk_rope_head_dim`` channels of q and kv. Because the
      values are the rotated keys, the same channels of the attention output
      are rotated back with the inverse rotation at the query positions.
    - Grouped low-rank O projection: ``wo_a`` is applied independently to each
      of ``o_groups`` contiguous head groups, then ``wo_b`` mixes the groups.

    Not implemented here:
    - ``attn_sink`` is created but never read in ``forward``, so it receives
      no gradient (DDP may need ``find_unused_parameters=True``).
    - Sliding-window / compressed sparse attention. Masking is exactly
      ``attention_mask`` when given, otherwise causal.
    """

    def __init__(self, config: DeepseekV4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.nope_head_dim = config.head_dim - config.qk_rope_head_dim
        self.q_lora_rank = config.q_lora_rank
        self.o_groups = config.o_groups
        self.o_lora_rank = config.o_lora_rank
        self.scaling = config.head_dim ** -0.5

        # Fail fast on shapes that would otherwise only break inside forward().
        if (self.num_heads * self.head_dim) % self.o_groups != 0:
            raise ValueError(
                f"num_attention_heads * head_dim ({self.num_heads * self.head_dim}) "
                f"must be divisible by o_groups ({self.o_groups})"
            )
        if self.qk_rope_head_dim % 2 != 0 or self.qk_rope_head_dim > self.head_dim:
            raise ValueError(
                f"qk_rope_head_dim ({self.qk_rope_head_dim}) must be even and "
                f"not larger than head_dim ({self.head_dim})"
            )

        # Q projection: low-rank
        self.wq_a = nn.Linear(self.hidden_size, self.q_lora_rank, bias=False)
        self.q_norm = DeepseekV4RMSNorm(self.q_lora_rank, config.rms_norm_eps)
        self.wq_b = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False)
        # KV projection: direct (no lora, single head)
        self.wkv = nn.Linear(self.hidden_size, self.head_dim, bias=False)
        self.kv_norm = DeepseekV4RMSNorm(self.head_dim, config.rms_norm_eps)
        # O projection: grouped low-rank.
        # wo_a maps each group's [num_heads * head_dim / o_groups] channels to
        # o_lora_rank; its weight holds all groups stacked along the output dim.
        group_head_dim = self.num_heads * self.head_dim // self.o_groups
        self.wo_a = nn.Linear(group_head_dim, self.o_groups * self.o_lora_rank, bias=False)
        self.wo_b = nn.Linear(self.o_groups * self.o_lora_rank, self.hidden_size, bias=False)
        # Learnable per-head attention sink (currently unused, see class docstring).
        self.attn_sink = nn.Parameter(torch.zeros(self.num_heads))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend over the current tokens plus an optional cached prefix.

        Args:
            hidden_states: ``[B, S, hidden_size]``.
            attention_mask: passed straight to SDPA (additive float or boolean,
                broadcastable to ``[B, H, S, T]``). When ``None``, a causal mask
                is used. With a cache, that mask is aligned so the last query
                sits at the last key.
            position_ids: accepted for interface compatibility; unused.
            freqs_cis: ``[2, S, qk_rope_head_dim // 2]`` RoPE table for the
                current positions, or ``None`` to skip RoPE.
            past_key_value: ``(k, v)`` from an earlier call. Only ``k`` is read,
                since keys and values are the same tensor here.
            use_cache: when true, return the concatenated kv as the new cache.

        Returns:
            ``(output [B, S, hidden_size], new_cache or None)``.
        """
        bsz, seqlen, _ = hidden_states.shape
        rd = self.qk_rope_head_dim

        # Q: low-rank projection, split into heads -> [B, H, S, D]
        q = self.wq_b(self.q_norm(self.wq_a(hidden_states)))
        q = q.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        # Parameter-free RMS normalisation of every query head (fp32 statistics).
        q = q * torch.rsqrt(q.float().pow(2).mean(-1, keepdim=True) + self.config.rms_norm_eps)
        q = q.to(hidden_states.dtype)

        # KV: single head shared by all query heads -> [B, 1, S, D]
        kv = self.kv_norm(self.wkv(hidden_states)).unsqueeze(1)

        # RoPE on the trailing rd channels of q and kv.
        if freqs_cis is not None:
            q = torch.cat([q[..., :-rd], apply_rotary_emb(q[..., -rd:], freqs_cis)], dim=-1)
            kv = torch.cat([kv[..., :-rd], apply_rotary_emb(kv[..., -rd:], freqs_cis)], dim=-1)

        # KV cache: keys == values, so only the first element is needed.
        if past_key_value is not None:
            past_k, _ = past_key_value
            kv = torch.cat([past_k, kv], dim=2)
        new_cache = (kv, kv) if use_cache else None

        kv_expanded = kv.expand(-1, self.num_heads, -1, -1)  # [B, H, T, D]

        attn_mask = attention_mask
        is_causal = False
        if attn_mask is None:
            kv_len = kv.shape[2]
            if kv_len == seqlen:
                is_causal = True
            elif seqlen > 1:
                # Query i sits at absolute position kv_len - seqlen + i. SDPA's
                # is_causal aligns the diagonal top-left, which would hide cached
                # keys, so build a bottom-right aligned boolean mask instead.
                attn_mask = torch.ones(
                    seqlen, kv_len, dtype=torch.bool, device=q.device
                ).tril(diagonal=kv_len - seqlen)
            # else: a single new query may attend to every cached key (no mask).

        # Note: attn_sink is not applied in this SDPA path.
        attn_output = F.scaled_dot_product_attention(
            q, kv_expanded, kv_expanded,
            attn_mask=attn_mask,
            is_causal=is_causal,
            scale=self.scaling,
        )

        # Undo the RoPE rotation on the output: same table with negated sines.
        if freqs_cis is not None:
            inverse = torch.stack((freqs_cis[0], -freqs_cis[1]))
            out_rope = apply_rotary_emb(attn_output[..., -rd:], inverse)
            attn_output = torch.cat([attn_output[..., :-rd], out_rope], dim=-1)

        # Grouped output projection: [B, H, S, D] -> [B, S, G, H*D/G]
        attn_output = attn_output.transpose(1, 2).reshape(bsz, seqlen, self.o_groups, -1)
        # wo_a applied per group: [B, S, G, H*D/G] -> [B, S, G, o_lora_rank]
        wo_a_w = self.wo_a.weight.view(self.o_groups, self.o_lora_rank, -1)
        attn_output = torch.einsum("bsgd,grd->bsgr", attn_output, wo_a_w)
        attn_output = self.wo_b(attn_output.flatten(2))  # [B, S, hidden]
        return attn_output, new_cache
# ---------------------------------------------------------------------------
# MoE
# ---------------------------------------------------------------------------
class DeepseekV4Expert(nn.Module):
    """SwiGLU feed-forward expert computing ``w2(silu(w1 x) * w3 x)``.

    Gate and up activations are evaluated in float32. When
    ``swiglu_limit > 0`` the up activation is clamped to
    ``[-swiglu_limit, swiglu_limit]`` and the gate pre-activation is clamped
    from above at ``swiglu_limit``.
    """

    def __init__(self, hidden_size: int, intermediate_size: int, swiglu_limit: float = 0.0):
        super().__init__()
        # Layers are created in w1, w2, w3 order (parameter naming/order and
        # init RNG consumption depend on it).
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)  # gate projection
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)  # down projection
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)  # up projection
        self.swiglu_limit = swiglu_limit

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        limit = self.swiglu_limit
        gate_act = self.w1(x).float()
        up_act = self.w3(x).float()
        if limit > 0:
            gate_act = gate_act.clamp(max=limit)
            up_act = up_act.clamp(-limit, limit)
        hidden = up_act * F.silu(gate_act)
        return self.w2(hidden.to(self.w2.weight.dtype))
class DeepseekV4Gate(nn.Module):
    """Top-k MoE router.

    Scores are ``score_fn(x @ W^T)``, computed in float32, where
    ``score_fn`` is softmax, sigmoid or sqrtsoftplus (``sqrt(softplus(.))``).

    Routing details:
    - On layers with a bias, the bias is added to the scores only for choosing
      experts. The returned weights always come from the un-biased scores.
    - For non-softmax scoring, the selected weights are renormalised to sum
      to 1 before being multiplied by ``routed_scaling_factor``.

    NOTE(review): layers with ``layer_idx < num_hash_layers`` only drop the
    bias here; they still route by top-k score, not by a token hash. The bias
    only affects the non-differentiable top-k choice, so it gets no gradient
    from the loss.
    """

    # Scoring functions understood by forward(); anything else is rejected.
    _SCORING_FUNCS = ("softmax", "sigmoid", "sqrtsoftplus")

    def __init__(self, config: DeepseekV4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.topk = config.num_experts_per_tok
        self.scoring_func = config.scoring_func
        self.route_scale = config.routed_scaling_factor
        self.is_hash_layer = layer_idx < config.num_hash_layers
        # Previously an unknown name silently routed on raw (possibly negative) logits.
        if self.scoring_func not in self._SCORING_FUNCS:
            raise ValueError(
                f"Unsupported scoring_func {self.scoring_func!r}; "
                f"expected one of {self._SCORING_FUNCS}"
            )
        self.weight = nn.Parameter(torch.empty(config.n_routed_experts, config.hidden_size))
        if not self.is_hash_layer:
            self.bias = nn.Parameter(torch.zeros(config.n_routed_experts))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route tokens.

        Args:
            x: ``[..., hidden_size]`` token features (2-D ``[N, hidden]`` in the MoE).

        Returns:
            ``(weights [..., topk] in x.dtype, indices [..., topk])``.
        """
        logits = F.linear(x.float(), self.weight.float())
        if self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1)
        elif self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        else:  # "sqrtsoftplus" (validated in __init__)
            scores = F.softplus(logits).sqrt()

        # The bias steers selection only; weights come from the raw scores.
        selection_scores = scores if self.bias is None else scores + self.bias
        indices = selection_scores.topk(self.topk, dim=-1).indices
        weights = scores.gather(-1, indices)
        if self.scoring_func != "softmax":
            weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20)
        weights = weights * self.route_scale
        return weights.to(x.dtype), indices
class DeepseekV4MoE(nn.Module):
    """Routed SwiGLU experts plus one always-on shared expert.

    Each token goes to ``num_experts_per_tok`` routed experts chosen by
    :class:`DeepseekV4Gate`. Expert outputs are weighted by the gate weights
    and summed in float32. The shared expert's output is then added, and the
    result is cast back to the input dtype.
    """

    def __init__(self, config: DeepseekV4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.n_routed_experts = config.n_routed_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.gate = DeepseekV4Gate(config, layer_idx)
        self.experts = nn.ModuleList([
            DeepseekV4Expert(config.hidden_size, config.moe_intermediate_size, config.swiglu_limit)
            for _ in range(config.n_routed_experts)
        ])
        # Shared expert is built without a swiglu_limit (clamping disabled).
        self.shared_expert = DeepseekV4Expert(config.hidden_size, config.moe_intermediate_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MoE to ``x`` of shape ``[..., hidden_size]``; returns the same shape."""
        shape = x.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(-1, self.hidden_size)
        weights, indices = self.gate(x_flat)
        y = torch.zeros_like(x_flat, dtype=torch.float32)

        # One host sync for all experts instead of one `counts[i] == 0` per expert.
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for expert_id, count in enumerate(counts):
            if count == 0:
                continue
            # token_idx is unique per expert: top-k never picks an expert twice per token.
            token_idx, slot = torch.where(indices == expert_id)
            expert_out = self.experts[expert_id](x_flat[token_idx])
            y.index_add_(0, token_idx, weights[token_idx, slot].unsqueeze(-1) * expert_out.float())

        y = y + self.shared_expert(x_flat).float()
        return y.to(x.dtype).view(shape)
# ---------------------------------------------------------------------------
# Transformer Block
# ---------------------------------------------------------------------------
class DeepseekV4Block(nn.Module):
    """One decoder layer (attention + MoE) wired with Hyper-Connections.

    The residual stream is carried as ``hc_mult`` parallel copies,
    ``[B, S, hc_mult, D]``. Around each sub-layer:

    * ``hc_pre`` collapses the copies into one ``[B, S, D]`` input using
      per-token weights ``pre``, and also returns ``post`` / ``comb``.
    * ``hc_post`` writes the sub-layer output into every copy (scaled by
      ``post``) and adds the old copies mixed through ``comb``.
    """

    def __init__(self, config: DeepseekV4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hc_mult = config.hc_mult
        self.norm_eps = config.rms_norm_eps
        self.hc_eps = config.hc_eps
        self.hc_sinkhorn_iters = config.hc_sinkhorn_iters

        # Sub-layers and their pre-norms. Construction order is kept stable
        # for parameter naming/ordering and init RNG consumption.
        self.attn = DeepseekV4Attention(config, layer_idx)
        self.ffn = DeepseekV4MoE(config, layer_idx)
        self.attn_norm = DeepseekV4RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.ffn_norm = DeepseekV4RMSNorm(config.hidden_size, config.rms_norm_eps)

        # HC parameters per sub-layer (uninitialised here; values are set by
        # DeepseekV4PreTrainedModel._init_weights).
        n_mix = (2 + config.hc_mult) * config.hc_mult
        flat_dim = config.hc_mult * config.hidden_size
        self.hc_attn_fn = nn.Parameter(torch.empty(n_mix, flat_dim))
        self.hc_ffn_fn = nn.Parameter(torch.empty(n_mix, flat_dim))
        self.hc_attn_base = nn.Parameter(torch.empty(n_mix))
        self.hc_ffn_base = nn.Parameter(torch.empty(n_mix))
        self.hc_attn_scale = nn.Parameter(torch.empty(3))
        self.hc_ffn_scale = nn.Parameter(torch.empty(3))

    def hc_pre(self, x, hc_fn, hc_scale, hc_base):
        """Collapse the HC copies into one sub-layer input.

        Args:
            x: ``[B, S, hc_mult, D]`` HC state.
            hc_fn, hc_scale, hc_base: this sub-layer's HC parameters.

        Returns:
            ``(y [B, S, D] in x.dtype, post [B, S, hc_mult], comb [B, S, hc_mult, hc_mult])``.
        """
        out_dtype = x.dtype
        x32 = x.float()
        flat = x32.flatten(2)  # [B, S, hc_mult*D]
        # The projection of the flattened state is scaled by its inverse RMS.
        inv_rms = torch.rsqrt(flat.pow(2).mean(-1, keepdim=True) + self.norm_eps)
        mixes = F.linear(flat, hc_fn.float()) * inv_rms  # [B, S, mix_hc]
        pre, post, comb = hc_split_sinkhorn(
            mixes, hc_scale, hc_base,
            self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps,
        )
        collapsed = (pre.unsqueeze(-1) * x32).sum(dim=2)
        return collapsed.to(out_dtype), post, comb

    def hc_post(self, x, residual, post, comb):
        """Expand a sub-layer output back to ``hc_mult`` copies.

        Args:
            x: ``[B, S, D]`` sub-layer output.
            residual: ``[B, S, hc_mult, D]`` HC state that entered the sub-layer.
            post: ``[B, S, hc_mult]``.
            comb: ``[B, S, hc_mult, hc_mult]``.

        Returns:
            ``post * x + comb @ residual`` as ``[B, S, hc_mult, D]`` in ``x.dtype``.
        """
        mixed_residual = torch.einsum("bsij,bsjd->bsid", comb.float(), residual.float())
        injected = post.unsqueeze(-1) * x.unsqueeze(2).float()
        return (injected + mixed_residual).to(x.dtype)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run attention then MoE on the HC state ``x`` (``[B, S, hc_mult, D]``).

        Returns the updated HC state and the attention layer's cache entry.
        """
        # Attention sub-layer
        h, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
        h, new_cache = self.attn(
            self.attn_norm(h),
            attention_mask=attention_mask,
            position_ids=position_ids,
            freqs_cis=freqs_cis,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        x = self.hc_post(h, x, post, comb)

        # FFN (MoE) sub-layer
        h, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
        x = self.hc_post(self.ffn(self.ffn_norm(h)), x, post, comb)
        return x, new_cache
# ---------------------------------------------------------------------------
# Full Model
# ---------------------------------------------------------------------------
class DeepseekV4PreTrainedModel(PreTrainedModel):
    """Shared HF plumbing (config class, checkpointing flags) and weight init."""

    config_class = DeepseekV4Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV4Block"]
    _skip_keys_device_placement = ["past_key_values"]

    def _init_weights(self, module):
        """Initialise the parameters owned directly by ``module``.

        - Linear, embedding and gate weights: N(0, ``initializer_range``);
          biases zero.
        - RMSNorm gains: one.
        - HC projections: N(0, 0.01), with zero bases and unit scales.
        - Attention sinks: zero.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
        elif isinstance(module, DeepseekV4RMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, DeepseekV4Gate):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, DeepseekV4Block):
            # Initialize HC parameters (allocated with torch.empty in the block)
            nn.init.normal_(module.hc_attn_fn, std=0.01)
            nn.init.normal_(module.hc_ffn_fn, std=0.01)
            nn.init.zeros_(module.hc_attn_base)
            nn.init.zeros_(module.hc_ffn_base)
            nn.init.ones_(module.hc_attn_scale)
            nn.init.ones_(module.hc_ffn_scale)
        elif isinstance(module, DeepseekV4Attention):
            nn.init.zeros_(module.attn_sink)
        elif isinstance(module, DeepseekV4Model):
            # The HC head is also allocated with torch.empty. DeepseekV4Model
            # initialises it in its own _init_weights override, but wrappers
            # such as DeepseekV4ForCausalLM use this base method. Depending on
            # the transformers version, that may be the one applied to the
            # inner model. Handling it here keeps it initialised either way.
            nn.init.normal_(module.hc_head_fn, std=0.01)
            nn.init.zeros_(module.hc_head_base)
            nn.init.ones_(module.hc_head_scale)
class DeepseekV4Model(DeepseekV4PreTrainedModel):
    """Decoder backbone: embeddings -> HC-expanded blocks -> HC head -> final RMSNorm.

    Known limitations, visible in ``forward``:
    - KV caching is force-disabled.
    - A caller-supplied ``attention_mask`` is ignored in favour of a plain
      causal mask.
    - ``output_hidden_states`` is accepted but unused.
    """

    def __init__(self, config: DeepseekV4Config):
        super().__init__(config)
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            DeepseekV4Block(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = DeepseekV4RMSNorm(config.hidden_size, config.rms_norm_eps)
        # HC head parameters (for contracting hc_mult -> 1 at output)
        hc_dim = config.hc_mult * config.hidden_size
        self.hc_head_fn = nn.Parameter(torch.empty(config.hc_mult, hc_dim))
        self.hc_head_base = nn.Parameter(torch.empty(config.hc_mult))
        self.hc_head_scale = nn.Parameter(torch.empty(1))
        # Precomputed RoPE table [2, max_position_embeddings, qk_rope_head_dim // 2]
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(config.qk_rope_head_dim, config.max_position_embeddings, config.rope_theta),
            persistent=False,
        )
        self.gradient_checkpointing = False
        self.post_init()

    def _init_weights(self, module):
        super()._init_weights(module)
        # HC head initialization
        if module is self:
            nn.init.normal_(self.hc_head_fn, std=0.01)
            nn.init.zeros_(self.hc_head_base)
            nn.init.ones_(self.hc_head_scale)

    def hc_head(self, x):
        """Contract the ``hc_mult`` copies to one hidden state.

        The weights are sigmoid gates computed from an RMS-scaled projection
        of the flattened copies. Maps ``[B, S, hc_mult, D] -> [B, S, D]`` and
        returns the result in ``x.dtype``.
        """
        dtype = x.dtype
        x_flat = x.flatten(2).float()  # [B, S, hc_mult*D]
        rsqrt = torch.rsqrt(x_flat.pow(2).mean(-1, keepdim=True) + self.config.rms_norm_eps)
        mixes = F.linear(x_flat, self.hc_head_fn.float()) * rsqrt  # [B, S, hc_mult]
        pre = torch.sigmoid(mixes * self.hc_head_scale.float() + self.hc_head_base.float()) + self.config.hc_eps
        y = (pre.unsqueeze(-1) * x.float()).sum(dim=2)
        return y.to(dtype)

    @staticmethod
    def _shared_rope_positions(position_ids: torch.Tensor) -> torch.Tensor:
        """Return the 1-D positions shared by every sequence in the batch.

        The RoPE table handed to the layers is ``[2, S, D//2]``, i.e. one set
        of positions for the whole batch. Accepts ``[S]`` or ``[B, S]``
        position ids. Raises when the rows of a ``[B, S]`` tensor differ,
        because per-sequence positions cannot be represented.
        """
        if position_ids.dim() == 1:
            return position_ids
        if position_ids.dim() != 2:
            raise ValueError(
                f"position_ids must be 1-D or 2-D, got shape {tuple(position_ids.shape)}"
            )
        pos = position_ids[0]
        if position_ids.shape[0] > 1 and not bool((position_ids == pos).all()):
            raise ValueError(
                "Per-sequence position_ids are not supported: all rows of position_ids must be equal"
            )
        return pos

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BaseModelOutputWithPast:
        """Encode a batch and return the final normalised hidden states.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``attention_mask``, ``past_key_values``, ``use_cache`` and
        ``output_hidden_states`` are accepted for HF compatibility, but the
        current implementation ignores them: attention is always causal and
        no cache is returned.

        Returns:
            ``BaseModelOutputWithPast``, or the tuple
            ``(hidden_states, None)`` when ``return_dict`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("Cannot specify both input_ids and inputs_embeds")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        seqlen = inputs_embeds.shape[1]
        device = inputs_embeds.device

        # Disable cache for now (DynamicCache compatibility TBD)
        use_cache = False
        past_key_values = None

        if position_ids is None:
            position_ids = torch.arange(seqlen, device=device).unsqueeze(0)
        # freqs_cis is [2, max_seq, D//2]; gather the rows for these positions.
        pos = self._shared_rope_positions(position_ids)
        freqs_cis = self.freqs_cis[:, pos].to(device)  # [2, seqlen, D//2]

        # Own additive causal mask (caller's attention_mask is not used).
        causal_mask = torch.full((seqlen, seqlen), float("-inf"), device=device, dtype=inputs_embeds.dtype)
        causal_mask = torch.triu(causal_mask, diagonal=1).unsqueeze(0).unsqueeze(0)

        # Expand to hc_mult copies: [B, S, D] -> [B, S, hc_mult, D]
        hidden_states = inputs_embeds.unsqueeze(2).expand(-1, -1, self.config.hc_mult, -1).contiguous()

        new_past_key_values = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values is not None and i < len(past_key_values) else None
            if self.gradient_checkpointing and self.training:
                # Positional args follow DeepseekV4Block.forward's parameter order.
                hidden_states, new_cache = torch.utils.checkpoint.checkpoint(
                    layer, hidden_states, causal_mask, position_ids, freqs_cis, past_kv, use_cache,
                    use_reentrant=False,
                )
            else:
                hidden_states, new_cache = layer(
                    hidden_states, attention_mask=causal_mask, position_ids=position_ids,
                    freqs_cis=freqs_cis, past_key_value=past_kv, use_cache=use_cache,
                )
            if use_cache:
                new_past_key_values.append(new_cache)

        # Contract HC copies -> single hidden state
        hidden_states = self.norm(self.hc_head(hidden_states))
        if not return_dict:
            return (hidden_states, new_past_key_values)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=new_past_key_values,
        )
class DeepseekV4ForCausalLM(DeepseekV4PreTrainedModel, GenerationMixin):
    """DeepSeek-V4 backbone with a linear LM head and a next-token loss."""

    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: DeepseekV4Config):
        super().__init__(config)
        self.model = DeepseekV4Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the backbone and LM head, and optionally compute the shifted LM loss.

        ``labels`` use ``-100`` for ignored positions. If ``kwargs`` carries
        ``num_items_in_batch``, the loss becomes ``sum / num_items_in_batch``
        instead of a per-call mean. The HF Trainer may pass this keyword to
        models whose forward accepts ``**kwargs``, and then skips its own
        gradient-accumulation division. Other extra kwargs are ignored.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=False,  # always tuple for compile compatibility
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self._causal_lm_loss(logits, labels, kwargs.get("num_items_in_batch"))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        past_kv = outputs[1] if len(outputs) > 1 else None
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_kv,
        )

    def _causal_lm_loss(self, logits, labels, num_items_in_batch=None):
        """Cross-entropy of ``logits[:, t]`` against ``labels[:, t + 1]``, ignoring -100.

        Returns the mean over non-ignored tokens by default. When
        ``num_items_in_batch`` is provided, returns the summed loss divided
        by it.
        """
        vocab = self.config.vocab_size
        shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        if num_items_in_batch is None:
            return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
        loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        return loss / num_items_in_batch

    @staticmethod
    def _cache_has_tokens(past_key_values) -> bool:
        """True when ``past_key_values`` actually holds cached positions."""
        if past_key_values is None:
            return False
        get_seq_length = getattr(past_key_values, "get_seq_length", None)
        if callable(get_seq_length):
            return get_seq_length() > 0
        try:
            return len(past_key_values) > 0
        except TypeError:
            # Unknown cache object: keep the previous behaviour (treat as populated).
            return True

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """Build forward kwargs for one generation step.

        The backbone currently discards caches, so ``input_ids`` is only cut
        to the last token when the cache really holds tokens. An empty cache
        object keeps the full sequence; otherwise the model would see only
        one token per step.
        """
        if self._cache_has_tokens(past_key_values):
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": True,
        }