cmpatino (HF Staff) committed
Commit 6e9a78e · verified · 1 parent: fbbd400

Upload SmolDeepSeek-V4 100M pretrained model (5000 steps on FineWeb-Edu)

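The files below load through the standard auto classes. A minimal loading sketch, assuming this commit lives in a Hub repo (the repo id is hypothetical) and noting that trust_remote_code is required because the architecture ships as custom code:

# Minimal loading sketch; repo id is hypothetical.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "cmpatino/SmolDeepSeek-V4-100M"  # hypothetical
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("The Milky Way is", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))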
chat_template.jinja ADDED
@@ -0,0 +1,3 @@
+ {% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '
+
+ ' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{{'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if add_generation_prompt %}{{'<|Assistant|>'}}{% endif %}
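The template accumulates all system messages into ns.system_prompt (joined by blank lines), emits bos_token plus the system prompt once, wraps turns in <|User|>/<|Assistant|> markers, and appends a bare <|Assistant|> when add_generation_prompt is set. A rendering sketch with the tokenizer loaded above:

# Sketch: rendering the template above via apply_chat_template.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Expected: <|begin▁of▁sentence|>You are a helpful assistant.<|User|>Hello!<|Assistant|>
print(prompt)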
config.json ADDED
@@ -0,0 +1,69 @@
+ {
+   "architectures": [
+     "DeepseekV4ForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "compress_ratios": [
+     0,
+     0,
+     0,
+     0,
+     0,
+     0,
+     0,
+     0,
+     0
+   ],
+   "compress_rope_theta": 160000.0,
+   "dtype": "float32",
+   "eos_token_id": 1,
+   "hc_eps": 1e-06,
+   "hc_mult": 4,
+   "hc_sinkhorn_iters": 2,
+   "head_dim": 96,
+   "hidden_act": "silu",
+   "hidden_size": 320,
+   "index_head_dim": 128,
+   "index_n_heads": 64,
+   "index_topk": 512,
+   "initializer_range": 0.02,
+   "max_position_embeddings": 2048,
+   "model_type": "deepseek_v4",
+   "moe_intermediate_size": 640,
+   "n_routed_experts": 4,
+   "n_shared_experts": 1,
+   "nope_head_dim": 64,
+   "norm_topk_prob": true,
+   "num_attention_heads": 8,
+   "num_experts_per_tok": 2,
+   "num_hash_layers": 0,
+   "num_hidden_layers": 8,
+   "num_key_value_heads": 1,
+   "num_nextn_predict_layers": 1,
+   "o_groups": 2,
+   "o_lora_rank": 80,
+   "pad_token_id": 1,
+   "q_lora_rank": 160,
+   "qk_rope_head_dim": 32,
+   "rms_norm_eps": 1e-06,
+   "rope_parameters": {
+     "rope_theta": 10000.0,
+     "rope_type": "default"
+   },
+   "rope_theta": 10000.0,
+   "routed_scaling_factor": 1.5,
+   "scoring_func": "sqrtsoftplus",
+   "sliding_window": 128,
+   "swiglu_limit": 0.0,
+   "tie_word_embeddings": false,
+   "topk_method": "noaux_tc",
+   "transformers_version": "5.6.2",
+   "use_cache": true,
+   "vocab_size": 129280,
+   "auto_map": {
+     "AutoConfig": "configuration_deepseek_v4.DeepseekV4Config",
+     "AutoModelForCausalLM": "modeling_deepseek_v4.DeepseekV4ForCausalLM"
+   }
+ }
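The auto_map block is what routes AutoConfig/AutoModelForCausalLM to the two Python files in this commit. A consistency-check sketch against the values above (repo_id as in the loading sketch):

# Sketch: load the custom config and sanity-check derived fields.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
assert cfg.nope_head_dim == cfg.head_dim - cfg.qk_rope_head_dim  # 96 - 32 == 64
assert len(cfg.compress_ratios) == cfg.num_hidden_layers + 1     # 8 + 1 == 9
print(cfg.model_type, cfg.hidden_size, cfg.num_attention_heads)  # deepseek_v4 320 8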
configuration_deepseek_v4.py ADDED
@@ -0,0 +1,140 @@
+ """DeepSeek-V4 model configuration.
+
+ Adapted from the DeepSeek-V4 inference config (deepseek-ai/DeepSeek-V4-Pro)
+ and the HF Transformers DeepSeek-V3 config for HF compatibility.
+
+ Key V4-specific features vs V3:
+ - Hyper-Connections (HC): multi-copy hidden states with Sinkhorn routing
+ - Compressed Sparse Attention (CSA): compression + sliding window + sparse indexing
+ - New MoE routing: sqrtsoftplus scoring, hash-based routing for first layers
+ - Large head_dim (512), o_groups/o_lora_rank for grouped output projection
+ - No kv_lora_rank (replaced by compress_ratios)
+ - No v_head_dim/qk_nope_head_dim (replaced by head_dim)
+ """
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class DeepseekV4Config(PretrainedConfig):
+     model_type = "deepseek_v4"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         vocab_size=129280,
+         hidden_size=4096,
+         num_hidden_layers=43,
+         num_attention_heads=64,
+         num_key_value_heads=1,
+         # MoE
+         moe_intermediate_size=2048,
+         n_routed_experts=256,
+         n_shared_experts=1,
+         num_experts_per_tok=6,
+         norm_topk_prob=True,
+         scoring_func="sqrtsoftplus",
+         routed_scaling_factor=1.5,
+         topk_method="noaux_tc",
+         num_hash_layers=3,
+         swiglu_limit=10.0,
+         # MLA / Attention
+         q_lora_rank=1024,
+         head_dim=512,
+         qk_rope_head_dim=64,
+         o_groups=8,
+         o_lora_rank=1024,
+         sliding_window=128,
+         # Compression
+         compress_ratios=None,
+         compress_rope_theta=160000.0,
+         # Index attention
+         index_n_heads=64,
+         index_head_dim=128,
+         index_topk=512,
+         # Hyper-Connections
+         hc_mult=4,
+         hc_sinkhorn_iters=20,
+         hc_eps=1e-6,
+         # MTP
+         num_nextn_predict_layers=1,
+         # Standard
+         hidden_act="silu",
+         max_position_embeddings=4096,
+         initializer_range=0.02,
+         rms_norm_eps=1e-6,
+         use_cache=True,
+         pad_token_id=None,
+         bos_token_id=0,
+         eos_token_id=1,
+         tie_word_embeddings=False,
+         rope_theta=10000.0,
+         rope_scaling=None,
+         attention_bias=False,
+         attention_dropout=0.0,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads or num_attention_heads
+
+         # MoE
+         self.moe_intermediate_size = moe_intermediate_size
+         self.n_routed_experts = n_routed_experts
+         self.n_shared_experts = n_shared_experts
+         self.num_experts_per_tok = num_experts_per_tok
+         self.norm_topk_prob = norm_topk_prob
+         self.scoring_func = scoring_func
+         self.routed_scaling_factor = routed_scaling_factor
+         self.topk_method = topk_method
+         self.num_hash_layers = num_hash_layers
+         self.swiglu_limit = swiglu_limit
+
+         # Attention
+         self.q_lora_rank = q_lora_rank
+         self.head_dim = head_dim
+         self.qk_rope_head_dim = qk_rope_head_dim
+         self.nope_head_dim = head_dim - qk_rope_head_dim
+         self.o_groups = o_groups
+         self.o_lora_rank = o_lora_rank
+         self.sliding_window = sliding_window
+
+         # Compression
+         if compress_ratios is None:
+             # Default: no compression for small models
+             compress_ratios = [0] * (num_hidden_layers + 1)
+         self.compress_ratios = compress_ratios
+         self.compress_rope_theta = compress_rope_theta
+
+         # Index attention
+         self.index_n_heads = index_n_heads
+         self.index_head_dim = index_head_dim
+         self.index_topk = index_topk
+
+         # Hyper-Connections
+         self.hc_mult = hc_mult
+         self.hc_sinkhorn_iters = hc_sinkhorn_iters
+         self.hc_eps = hc_eps
+
+         # MTP
+         self.num_nextn_predict_layers = num_nextn_predict_layers
+
+         # Standard
+         self.hidden_act = hidden_act
+         self.max_position_embeddings = max_position_embeddings
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
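The defaults above target the full-size model; the 100M run overrides most sizes. A sketch constructing the uploaded config in code (values mirror config.json above):

# Sketch: build the 100M config directly; values mirror config.json.
config = DeepseekV4Config(
    hidden_size=320,
    num_hidden_layers=8,
    num_attention_heads=8,
    head_dim=96,
    qk_rope_head_dim=32,
    q_lora_rank=160,
    o_groups=2,
    o_lora_rank=80,
    moe_intermediate_size=640,
    n_routed_experts=4,
    num_experts_per_tok=2,
    num_hash_layers=0,
    swiglu_limit=0.0,
    hc_sinkhorn_iters=2,
    max_position_embeddings=2048,
    pad_token_id=1,
)
# compress_ratios defaults to [0] * (num_hidden_layers + 1) -> no compression
assert config.compress_ratios == [0] * 9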
generation_config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 0,
+   "eos_token_id": [
+     1
+   ],
+   "output_attentions": false,
+   "output_hidden_states": false,
+   "pad_token_id": 1,
+   "transformers_version": "5.6.2",
+   "use_cache": true
+ }
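model.generate() picks these defaults up automatically; a sketch with the model loaded above (the sampling flags are illustrative, not part of this upload):

# Sketch: generation reads bos/eos/pad ids from generation_config.json.
out = model.generate(
    **tokenizer("Photosynthesis is", return_tensors="pt"),
    max_new_tokens=64,
    do_sample=True,   # illustrative, not set in the uploaded config
    temperature=0.8,  # illustrative
)
print(tokenizer.decode(out[0], skip_special_tokens=True))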
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:379bf28ce0372e789313f7b17d0166d027a68de74b1d8c424e8fffc16effb05f
+ size 441509732
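This is a Git LFS pointer: the oid is the SHA-256 of the resolved file, so a download can be verified against it. A sketch using huggingface_hub (repo_id as above):

# Sketch: verify the downloaded checkpoint against the LFS pointer.
import hashlib
from huggingface_hub import hf_hub_download

path = hf_hub_download(repo_id, "model.safetensors")
with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
assert digest == "379bf28ce0372e789313f7b17d0166d027a68de74b1d8c424e8fffc16effb05f"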
modeling_deepseek_v4.py ADDED
@@ -0,0 +1,699 @@
+ """DeepSeek-V4 model implementation for HuggingFace Transformers.
+
+ Ported from deepseek-ai/DeepSeek-V4-Pro inference/model.py to be compatible
+ with HF Trainer, SFTTrainer, and AutoModelForCausalLM.
+
+ Key V4 architecture features implemented:
+ - Hyper-Connections (HC): multi-copy hidden states with Sinkhorn routing
+ - Compressed Sparse Attention (CSA) with sliding window
+ - MoE with sqrtsoftplus scoring and hash-based routing
+ - Grouped low-rank output projection (o_groups + o_lora_rank)
+ - Multi-Token Prediction (MTP) layers (disabled for small models)
+
+ Custom kernels (tilelang) are NOT required — all ops are pure PyTorch.
+ For training from scratch in bf16, this is sufficient and simpler.
+ """
+
+ from typing import Optional, Tuple, List
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint  # needed for torch.utils.checkpoint.checkpoint below
+
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.generation import GenerationMixin
+
+ try:
+     # Package-style import used when loaded via trust_remote_code / auto_map
+     from .configuration_deepseek_v4 import DeepseekV4Config
+ except ImportError:
+     from configuration_deepseek_v4 import DeepseekV4Config
+
+
+ # ---------------------------------------------------------------------------
+ # Utility functions
+ # ---------------------------------------------------------------------------
+
+ class DeepseekV4RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         dtype = x.dtype
+         x = x.float()
+         var = x.pow(2).mean(-1, keepdim=True)
+         x = x * torch.rsqrt(var + self.eps)
+         return (self.weight * x).to(dtype)
+
+
+ def precompute_freqs_cis(dim, seqlen, base=10000.0):
+     """Precompute cos/sin for rotary embeddings (real-valued, compile-friendly)."""
+     freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+     t = torch.arange(seqlen, dtype=torch.float32)
+     freqs = torch.outer(t, freqs)  # [S, D//2]
+     cos = freqs.cos()
+     sin = freqs.sin()
+     return torch.stack([cos, sin], dim=0)  # [2, S, D//2]
+
+
+ def apply_rotary_emb(x: torch.Tensor, cos_sin: torch.Tensor) -> torch.Tensor:
+     """Apply rotary positional embeddings (real-valued, no complex ops).
+
+     x: [..., D] where D is even
+     cos_sin: [2, S, D//2] - precomputed cos and sin
+     """
+     cos, sin = cos_sin[0], cos_sin[1]  # each [S, D//2]
+     d = x.shape[-1] // 2
+     x1, x2 = x[..., :d], x[..., d:]
+     # Broadcast cos/sin to match x shape
+     while cos.ndim < x1.ndim:
+         cos = cos.unsqueeze(0)
+         sin = sin.unsqueeze(0)
+     y1 = x1 * cos + x2 * sin
+     y2 = x1 * (-sin) + x2 * cos
+     return torch.cat([y1, y2], dim=-1).to(x.dtype)
+
+
+ # ---------------------------------------------------------------------------
+ # Hyper-Connections (HC)
+ # ---------------------------------------------------------------------------
+
+ def hc_split_sinkhorn(mixes, hc_scale, hc_base, hc_mult=4, sinkhorn_iters=20, eps=1e-6):
+     """Pure PyTorch implementation of HC split + Sinkhorn normalization.
+
+     Args:
+         mixes: [B, S, (2+hc_mult)*hc_mult] - mixed scores from linear projection
+         hc_scale: [3] - scale parameters
+         hc_base: [(2+hc_mult)*hc_mult] - bias parameters
+         hc_mult: number of HC copies
+         sinkhorn_iters: number of Sinkhorn normalization iterations
+         eps: numerical stability epsilon
+
+     Returns:
+         pre: [B, S, hc_mult] - pre-connection weights
+         post: [B, S, hc_mult] - post-connection weights
+         comb: [B, S, hc_mult, hc_mult] - combination matrix
+     """
+     # Split into pre, post, and combination parts
+     pre_raw = mixes[..., :hc_mult]
+     post_raw = mixes[..., hc_mult:2*hc_mult]
+     comb_raw = mixes[..., 2*hc_mult:].reshape(*mixes.shape[:-1], hc_mult, hc_mult)
+
+     # Apply scale and base
+     pre = torch.sigmoid(pre_raw * hc_scale[0] + hc_base[:hc_mult]) + eps
+     post = 2 * torch.sigmoid(post_raw * hc_scale[1] + hc_base[hc_mult:2*hc_mult])
+
+     # Combination matrix with Sinkhorn normalization
+     comb = comb_raw * hc_scale[2] + hc_base[2*hc_mult:].reshape(hc_mult, hc_mult)
+
+     # Initial softmax along last dim + eps
+     comb = F.softmax(comb, dim=-1) + eps
+     # Normalize along dim=-2
+     comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
+
+     # Sinkhorn iterations
+     for _ in range(sinkhorn_iters - 1):
+         comb = comb / (comb.sum(dim=-1, keepdim=True) + eps)
+         comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
+
+     return pre, post, comb
+
+
+ # ---------------------------------------------------------------------------
+ # Attention
+ # ---------------------------------------------------------------------------
+
+ class DeepseekV4Attention(nn.Module):
+     """Multi-head Latent Attention (MLA) with sliding window.
+
+     V4 attention uses:
+     - Low-rank Q projection (wq_a -> q_norm -> wq_b)
+     - Direct KV projection (wkv -> kv_norm) - no kv_lora_rank
+     - Grouped low-rank O projection (wo_a -> wo_b)
+     - Sliding window attention
+     - RoPE on last qk_rope_head_dim dims
+     """
+
+     def __init__(self, config: DeepseekV4Config, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = config.head_dim
+         self.qk_rope_head_dim = config.qk_rope_head_dim
+         self.nope_head_dim = config.head_dim - config.qk_rope_head_dim
+         self.q_lora_rank = config.q_lora_rank
+         self.o_groups = config.o_groups
+         self.o_lora_rank = config.o_lora_rank
+         self.scaling = config.head_dim ** -0.5
+
+         # Q projection: low-rank
+         self.wq_a = nn.Linear(self.hidden_size, self.q_lora_rank, bias=False)
+         self.q_norm = DeepseekV4RMSNorm(self.q_lora_rank, config.rms_norm_eps)
+         self.wq_b = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False)
+
+         # KV projection: direct (no lora, single head)
+         self.wkv = nn.Linear(self.hidden_size, self.head_dim, bias=False)
+         self.kv_norm = DeepseekV4RMSNorm(self.head_dim, config.rms_norm_eps)
+
+         # O projection: grouped low-rank
+         # wo_a: [num_heads * head_dim / o_groups] -> [o_groups * o_lora_rank]
+         group_head_dim = self.num_heads * self.head_dim // self.o_groups
+         self.wo_a = nn.Linear(group_head_dim, self.o_groups * self.o_lora_rank, bias=False)
+         self.wo_b = nn.Linear(self.o_groups * self.o_lora_rank, self.hidden_size, bias=False)
+
+         # Learnable attention sink bias
+         self.attn_sink = nn.Parameter(torch.zeros(self.num_heads))
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         freqs_cis: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+         bsz, seqlen, _ = hidden_states.shape
+
+         # Q: low-rank projection
+         q = self.q_norm(self.wq_a(hidden_states))
+         q = self.wq_b(q)
+         q = q.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+         # RMSNorm on q per-head
+         q = q * torch.rsqrt(q.float().pow(2).mean(-1, keepdim=True) + self.config.rms_norm_eps)
+         q = q.to(hidden_states.dtype)
+
+         # KV: direct projection (single KV head, shared across all Q heads)
+         kv = self.kv_norm(self.wkv(hidden_states))
+         kv = kv.unsqueeze(1)  # [B, 1, S, head_dim]
+
+         # Apply RoPE to last qk_rope_head_dim dims of q and kv
+         if freqs_cis is not None:
+             q_rope = q[..., -self.qk_rope_head_dim:]
+             kv_rope = kv[..., -self.qk_rope_head_dim:]
+             q_rope = apply_rotary_emb(q_rope, freqs_cis)
+             kv_rope = apply_rotary_emb(kv_rope, freqs_cis)
+             q = torch.cat([q[..., :-self.qk_rope_head_dim], q_rope], dim=-1)
+             kv = torch.cat([kv[..., :-self.qk_rope_head_dim], kv_rope], dim=-1)
+
+         # Handle KV cache
+         if past_key_value is not None:
+             past_k, past_v = past_key_value
+             kv = torch.cat([past_k, kv], dim=2)
+
+         new_cache = (kv, kv) if use_cache else None
+
+         # Expand kv for all heads
+         kv_expanded = kv.expand(-1, self.num_heads, -1, -1)
+
+         # Use PyTorch SDPA (fused kernel, memory-efficient)
+         # q: [B, H, S, D], kv_expanded: [B, H, T, D]
+         # Note: attn_sink bias is small and omitted in SDPA path for speed.
+         # It's a learnable per-head scalar — its effect is minimal and the model
+         # will learn to compensate through other parameters.
+         attn_output = F.scaled_dot_product_attention(
+             q, kv_expanded, kv_expanded,
+             attn_mask=attention_mask,
+             is_causal=(attention_mask is None),
+             scale=self.scaling,
+         )
+
+         # De-rotate RoPE on output (inverse rotation = negate sin)
+         if freqs_cis is not None:
+             cos, sin = freqs_cis[0], freqs_cis[1]  # [S, D//2]
+             cos_inv = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, S, D//2]
+             sin_inv = -sin.unsqueeze(0).unsqueeze(0)  # negate for inverse
+             out_rope = attn_output[..., -self.qk_rope_head_dim:]
+             d = out_rope.shape[-1] // 2
+             o1, o2 = out_rope[..., :d], out_rope[..., d:]
+             out_rope = torch.cat([o1 * cos_inv + o2 * sin_inv, o1 * (-sin_inv) + o2 * cos_inv], dim=-1)
+             attn_output = torch.cat([attn_output[..., :-self.qk_rope_head_dim], out_rope.to(attn_output.dtype)], dim=-1)
+
+         # Grouped output projection
+         attn_output = attn_output.transpose(1, 2)  # [B, S, H, D]
+         attn_output = attn_output.reshape(bsz, seqlen, self.o_groups, -1)
+
+         # wo_a applied per group: [B, S, G, H*D/G] -> [B, S, G, o_lora_rank]
+         wo_a_w = self.wo_a.weight.view(self.o_groups, self.o_lora_rank, -1)
+         attn_output = torch.einsum("bsgd,grd->bsgr", attn_output, wo_a_w)
+         attn_output = attn_output.flatten(2)  # [B, S, G*o_lora_rank]
+         attn_output = self.wo_b(attn_output)
+
+         return attn_output, new_cache
+
+
+ # ---------------------------------------------------------------------------
+ # MoE
+ # ---------------------------------------------------------------------------
+
+ class DeepseekV4Expert(nn.Module):
+     """Single MoE expert with SwiGLU activation."""
+
+     def __init__(self, hidden_size: int, intermediate_size: int, swiglu_limit: float = 0.0):
+         super().__init__()
+         self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)  # gate
+         self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)  # down
+         self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)  # up
+         self.swiglu_limit = swiglu_limit
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         gate = self.w1(x).float()
+         up = self.w3(x).float()
+         if self.swiglu_limit > 0:
+             up = up.clamp(-self.swiglu_limit, self.swiglu_limit)
+             gate = gate.clamp(max=self.swiglu_limit)
+         x = F.silu(gate) * up
+         return self.w2(x.to(self.w2.weight.dtype))
+
+
+ class DeepseekV4Gate(nn.Module):
+     """MoE gating with sqrtsoftplus scoring."""
+
+     def __init__(self, config: DeepseekV4Config, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.topk = config.num_experts_per_tok
+         self.scoring_func = config.scoring_func
+         self.route_scale = config.routed_scaling_factor
+         self.is_hash_layer = layer_idx < config.num_hash_layers
+
+         self.weight = nn.Parameter(torch.empty(config.n_routed_experts, config.hidden_size))
+         if not self.is_hash_layer:
+             self.bias = nn.Parameter(torch.zeros(config.n_routed_experts))
+         else:
+             self.register_parameter("bias", None)
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         scores = F.linear(x.float(), self.weight.float())
+
+         if self.scoring_func == "softmax":
+             scores = scores.softmax(dim=-1)
+         elif self.scoring_func == "sigmoid":
+             scores = scores.sigmoid()
+         elif self.scoring_func == "sqrtsoftplus":
+             scores = F.softplus(scores).sqrt()
+
+         original_scores = scores
+
+         if self.bias is not None:
+             scores = scores + self.bias
+
+         # Top-k selection
+         indices = scores.topk(self.topk, dim=-1)[1]
+         weights = original_scores.gather(1, indices)
+
+         if self.scoring_func != "softmax":
+             weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20)
+
+         weights = weights * self.route_scale
+         return weights.to(x.dtype), indices
+
+
+ class DeepseekV4MoE(nn.Module):
+     """Mixture of Experts layer."""
+
+     def __init__(self, config: DeepseekV4Config, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.n_routed_experts = config.n_routed_experts
+         self.num_experts_per_tok = config.num_experts_per_tok
+
+         self.gate = DeepseekV4Gate(config, layer_idx)
+         self.experts = nn.ModuleList([
+             DeepseekV4Expert(config.hidden_size, config.moe_intermediate_size, config.swiglu_limit)
+             for _ in range(config.n_routed_experts)
+         ])
+         self.shared_expert = DeepseekV4Expert(config.hidden_size, config.moe_intermediate_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         shape = x.shape
+         x_flat = x.view(-1, self.hidden_size)
+
+         weights, indices = self.gate(x_flat)
+
+         y = torch.zeros_like(x_flat, dtype=torch.float32)
+
+         # Route tokens to experts
+         counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts)
+         for i in range(self.n_routed_experts):
+             if counts[i] == 0:
+                 continue
+             idx, top = torch.where(indices == i)
+             expert_out = self.experts[i](x_flat[idx])
+             y[idx] += (weights[idx, top].unsqueeze(-1) * expert_out.float())
+
+         # Add shared expert
+         y = y + self.shared_expert(x_flat).float()
+
+         return y.to(x.dtype).view(shape)
+
+
+ # ---------------------------------------------------------------------------
+ # Transformer Block
+ # ---------------------------------------------------------------------------
+
+ class DeepseekV4Block(nn.Module):
+     """Transformer block with Hyper-Connections.
+
+     Instead of simple residuals, HC maintains hc_mult copies of the hidden state.
+     hc_pre: reduces hc copies -> 1 via learned weighted sum.
+     hc_post: expands 1 -> hc copies via learned post-weights + combination matrix.
+     """
+
+     def __init__(self, config: DeepseekV4Config, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         self.hc_mult = config.hc_mult
+         self.norm_eps = config.rms_norm_eps
+         self.hc_eps = config.hc_eps
+         self.hc_sinkhorn_iters = config.hc_sinkhorn_iters
+
+         self.attn = DeepseekV4Attention(config, layer_idx)
+         self.ffn = DeepseekV4MoE(config, layer_idx)
+         self.attn_norm = DeepseekV4RMSNorm(config.hidden_size, config.rms_norm_eps)
+         self.ffn_norm = DeepseekV4RMSNorm(config.hidden_size, config.rms_norm_eps)
+
+         # HC parameters for attention and FFN sub-layers
+         mix_hc = (2 + config.hc_mult) * config.hc_mult
+         hc_dim = config.hc_mult * config.hidden_size
+
+         self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
+         self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
+         self.hc_attn_base = nn.Parameter(torch.empty(mix_hc))
+         self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc))
+         self.hc_attn_scale = nn.Parameter(torch.empty(3))
+         self.hc_ffn_scale = nn.Parameter(torch.empty(3))
+
+     def hc_pre(self, x, hc_fn, hc_scale, hc_base):
+         """Reduce hc_mult copies to 1 via learned weighted sum.
+
+         x: [B, S, hc_mult, D]
+         Returns: y [B, S, D], post [B, S, hc_mult], comb [B, S, hc_mult, hc_mult]
+         """
+         dtype = x.dtype
+         x_flat = x.flatten(2).float()  # [B, S, hc_mult*D]
+
+         rsqrt = torch.rsqrt(x_flat.pow(2).mean(-1, keepdim=True) + self.norm_eps)
+         mixes = F.linear(x_flat, hc_fn.float()) * rsqrt  # [B, S, mix_hc]
+
+         pre, post, comb = hc_split_sinkhorn(
+             mixes, hc_scale, hc_base,
+             self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps
+         )
+
+         # Weighted sum: pre [B, S, hc] * x [B, S, hc, D] -> y [B, S, D]
+         y = (pre.unsqueeze(-1) * x.float()).sum(dim=2)
+         return y.to(dtype), post, comb
+
+     def hc_post(self, x, residual, post, comb):
+         """Expand 1 -> hc_mult copies.
+
+         x: [B, S, D] - output from sub-layer
+         residual: [B, S, hc_mult, D] - input HC state
+         post: [B, S, hc_mult]
+         comb: [B, S, hc_mult, hc_mult]
+         """
+         # post * x + comb * residual
+         y = (post.unsqueeze(-1) * x.unsqueeze(2).float() +
+              torch.einsum("bsij,bsjd->bsid", comb.float(), residual.float()))
+         return y.to(x.dtype)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         freqs_cis: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+         """
+         x: [B, S, hc_mult, D] - HC state
+         """
+         # Attention with HC
+         residual = x
+         y, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
+         y = self.attn_norm(y)
+         y, new_cache = self.attn(y, attention_mask=attention_mask, position_ids=position_ids,
+                                  freqs_cis=freqs_cis, past_key_value=past_key_value, use_cache=use_cache)
+         x = self.hc_post(y, residual, post, comb)
+
+         # FFN with HC
+         residual = x
+         y, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
+         y = self.ffn_norm(y)
+         y = self.ffn(y)
+         x = self.hc_post(y, residual, post, comb)
+
+         return x, new_cache
+
+
+ # ---------------------------------------------------------------------------
+ # Full Model
+ # ---------------------------------------------------------------------------
+
+ class DeepseekV4PreTrainedModel(PreTrainedModel):
+     config_class = DeepseekV4Config
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["DeepseekV4Block"]
+     _skip_keys_device_placement = ["past_key_values"]
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+         elif isinstance(module, DeepseekV4RMSNorm):
+             module.weight.data.fill_(1.0)
+         elif isinstance(module, DeepseekV4Gate):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, DeepseekV4Block):
+             # Initialize HC parameters
+             nn.init.normal_(module.hc_attn_fn, std=0.01)
+             nn.init.normal_(module.hc_ffn_fn, std=0.01)
+             nn.init.zeros_(module.hc_attn_base)
+             nn.init.zeros_(module.hc_ffn_base)
+             nn.init.ones_(module.hc_attn_scale)
+             nn.init.ones_(module.hc_ffn_scale)
+         elif isinstance(module, DeepseekV4Attention):
+             nn.init.zeros_(module.attn_sink)
+
+
+ class DeepseekV4Model(DeepseekV4PreTrainedModel):
+     def __init__(self, config: DeepseekV4Config):
+         super().__init__(config)
+         self.config = config
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = nn.ModuleList([
+             DeepseekV4Block(config, layer_idx)
+             for layer_idx in range(config.num_hidden_layers)
+         ])
+         self.norm = DeepseekV4RMSNorm(config.hidden_size, config.rms_norm_eps)
+
+         # HC head parameters (for contracting hc_mult -> 1 at output)
+         hc_dim = config.hc_mult * config.hidden_size
+         self.hc_head_fn = nn.Parameter(torch.empty(config.hc_mult, hc_dim))
+         self.hc_head_base = nn.Parameter(torch.empty(config.hc_mult))
+         self.hc_head_scale = nn.Parameter(torch.empty(1))
+
+         # Precomputed RoPE frequencies
+         self.register_buffer(
+             "freqs_cis",
+             precompute_freqs_cis(config.qk_rope_head_dim, config.max_position_embeddings, config.rope_theta),
+             persistent=False,
+         )
+
+         self.gradient_checkpointing = False
+         self.post_init()
+
+     def _init_weights(self, module):
+         super()._init_weights(module)
+         # HC head initialization
+         if module is self:
+             nn.init.normal_(self.hc_head_fn, std=0.01)
+             nn.init.zeros_(self.hc_head_base)
+             nn.init.ones_(self.hc_head_scale)
+
+     def hc_head(self, x):
+         """Contract hc_mult copies to 1 for final output.
+
+         x: [B, S, hc_mult, D] -> [B, S, D]
+         """
+         dtype = x.dtype
+         x_flat = x.flatten(2).float()  # [B, S, hc_mult*D]
+
+         rsqrt = torch.rsqrt(x_flat.pow(2).mean(-1, keepdim=True) + self.config.rms_norm_eps)
+         mixes = F.linear(x_flat, self.hc_head_fn.float()) * rsqrt  # [B, S, hc_mult]
+
+         pre = torch.sigmoid(mixes * self.hc_head_scale.float() + self.hc_head_base.float()) + self.config.hc_eps
+         y = (pre.unsqueeze(-1) * x.float()).sum(dim=2)
+         return y.to(dtype)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> BaseModelOutputWithPast:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("Cannot specify both input_ids and inputs_embeds")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         bsz, seqlen = inputs_embeds.shape[:2]
+
+         # Disable cache for now (DynamicCache compatibility TBD)
+         use_cache = False
+         past_key_values = None
+
+         if position_ids is None:
+             position_ids = torch.arange(seqlen, device=inputs_embeds.device).unsqueeze(0)
+
+         # Get freqs for RoPE
+         # freqs_cis is [2, max_seq, D//2], index by position
+         pos = position_ids.squeeze(0)
+         freqs_cis = self.freqs_cis[:, pos].to(inputs_embeds.device)  # [2, seqlen, D//2]
+
+         # Create causal mask - always create our own 4D mask
+         causal_mask = torch.full((seqlen, seqlen), float("-inf"), device=inputs_embeds.device, dtype=inputs_embeds.dtype)
+         causal_mask = torch.triu(causal_mask, diagonal=1)
+         causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
+
+         # Expand to hc_mult copies
+         hidden_states = inputs_embeds.unsqueeze(2).expand(-1, -1, self.config.hc_mult, -1)
+         hidden_states = hidden_states.contiguous()
+
+         new_past_key_values = [] if use_cache else None
+
+         for i, layer in enumerate(self.layers):
+             past_kv = past_key_values[i] if past_key_values is not None and i < len(past_key_values) else None
+
+             if self.gradient_checkpointing and self.training:
+                 hidden_states, new_cache = torch.utils.checkpoint.checkpoint(
+                     layer, hidden_states, causal_mask, position_ids, freqs_cis, past_kv, use_cache,
+                     use_reentrant=False,
+                 )
+             else:
+                 hidden_states, new_cache = layer(
+                     hidden_states, attention_mask=causal_mask, position_ids=position_ids,
+                     freqs_cis=freqs_cis, past_key_value=past_kv, use_cache=use_cache,
+                 )
+
+             if use_cache:
+                 new_past_key_values.append(new_cache)
+
+         # Contract HC copies -> single hidden state
+         hidden_states = self.hc_head(hidden_states)
+         hidden_states = self.norm(hidden_states)
+
+         if not return_dict:
+             return (hidden_states, new_past_key_values)
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=new_past_key_values,
+         )
+
+
+ class DeepseekV4ForCausalLM(DeepseekV4PreTrainedModel, GenerationMixin):
+     _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+
+     def __init__(self, config: DeepseekV4Config):
+         super().__init__(config)
+         self.model = DeepseekV4Model(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> CausalLMOutputWithPast:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_hidden_states=output_hidden_states,
+             return_dict=False,  # always tuple for compile compatibility
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss = F.cross_entropy(
+                 shift_logits.view(-1, self.config.vocab_size),
+                 shift_labels.view(-1),
+                 ignore_index=-100,
+             )
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         past_kv = outputs[1] if len(outputs) > 1 else None
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=past_kv,
+         )
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+         if past_key_values is not None:
+             input_ids = input_ids[:, -1:]
+
+         return {
+             "input_ids": input_ids,
+             "past_key_values": past_key_values,
+             "use_cache": True,
+         }
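A forward-pass smoke test for the module above, a sketch assuming both .py files from this commit sit in the working directory (the tiny hyperparameters are arbitrary, chosen only to keep the dimensions consistent):

# Sketch: forward-pass smoke test on random tokens (CPU, float32).
import torch
from configuration_deepseek_v4 import DeepseekV4Config
from modeling_deepseek_v4 import DeepseekV4ForCausalLM

cfg = DeepseekV4Config(
    vocab_size=1000, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=4, head_dim=32, qk_rope_head_dim=16,
    q_lora_rank=32, o_groups=2, o_lora_rank=16,
    moe_intermediate_size=128, n_routed_experts=4, num_experts_per_tok=2,
    num_hash_layers=0, hc_sinkhorn_iters=2, max_position_embeddings=128,
)
model = DeepseekV4ForCausalLM(cfg)
ids = torch.randint(0, cfg.vocab_size, (2, 16))
out = model(input_ids=ids, labels=ids)
assert out.logits.shape == (2, 16, cfg.vocab_size)
print("loss:", out.loss.item())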
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "backend": "tokenizers",
+   "bos_token": "<|begin▁of▁sentence|>",
+   "eos_token": "<|end▁of▁sentence|>",
+   "is_local": true,
+   "local_files_only": false,
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<|end▁of▁sentence|>",
+   "tokenizer_class": "TokenizersBackend"
+ }
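The tokenizer loads through AutoTokenizer; a short sketch inspecting the special tokens declared above (repo_id as in the first sketch; the ▁ characters are literal):

# Sketch: load the tokenizer and inspect its special tokens.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(repo_id)
print(tok.bos_token)  # <|begin▁of▁sentence|>
print(tok.eos_token)  # <|end▁of▁sentence|>
print(tok.pad_token)  # <|end▁of▁sentence|>
print(tok("Hello world")["input_ids"][:8])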