Raincleared commited on
Commit
8b38b65
·
verified ·
1 Parent(s): 724f467

Upload folder using huggingface_hub

Browse files
c4_validation.json ADDED
The diff for this file is too large to render. See raw diff
 
config.json ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BlockFFNForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_blockffn.BlockFFNConfig",
7
+ "AutoModel": "modeling_blockffn.BlockFFNModel",
8
+ "AutoModelForCausalLM": "modeling_blockffn.BlockFFNForCausalLM"
9
+ },
10
+ "bos_token_id": 1,
11
+ "eos_token_id": [
12
+ 2,
13
+ 73440
14
+ ],
15
+ "pad_token_id": 2,
16
+ "hidden_act": "silu",
17
+ "hidden_size": 768,
18
+ "initializer_range": 0.1,
19
+ "intermediate_size": 10240,
20
+ "head_dim": 128,
21
+ "max_position_embeddings": 4096,
22
+ "model_type": "blockffn",
23
+ "num_attention_heads": 6,
24
+ "num_hidden_layers": 32,
25
+ "num_key_value_heads": 2,
26
+ "rms_norm_eps": 1e-05,
27
+ "rope_scaling": null,
28
+ "rope_theta": 10000.0,
29
+ "torch_dtype": "bfloat16",
30
+ "transformers_version": "4.36.0",
31
+ "use_cache": true,
32
+ "vocab_size": 73448,
33
+ "use_mup": false,
34
+ "num_experts": 42,
35
+ "moe_ffn_hidden_size": 64,
36
+ "moe_shared_expert_intermediate_size": 128,
37
+ "moe_layer_freq": [
38
+ 0,
39
+ 1,
40
+ 1,
41
+ 1,
42
+ 1,
43
+ 1,
44
+ 1,
45
+ 1,
46
+ 1,
47
+ 1,
48
+ 1,
49
+ 1,
50
+ 1,
51
+ 1,
52
+ 1,
53
+ 1
54
+ ],
55
+ "moe_router_dtype": "fp32",
56
+ "router_act_func": "relu",
57
+ "router_norm_type": "simple",
58
+ "expert_act_func": "norm_silu",
59
+ "expert_act_norm_type": "normal",
60
+ "num_layers": 16,
61
+ "ffn_hidden_size": 1920,
62
+ "num_query_groups": 6,
63
+ "norm_epsilon": 1e-05,
64
+ "use_blockffn": true,
65
+ "router_type": "topk",
66
+ "moe_router_enable_expert_bias": false,
67
+ "expert_not_gated": true,
68
+ "moe_router_pre_softmax": false,
69
+ "moe_router_topk": 2,
70
+ "moe_router_topp": 0.5,
71
+ "moe_router_score_function": "softmax",
72
+ "moe_router_topk_scaling_factor": null
73
+ }
configuration_blockffn.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """BlockFFN model configuration"""
21
+
22
+ from transformers import PretrainedConfig
23
+ from transformers.modeling_rope_utils import rope_config_validation
24
+
25
+
26
+ class BlockFFNConfig(PretrainedConfig):
27
+
28
+ model_type = "blockffn"
29
+ keys_to_ignore_at_inference = ["past_key_values"]
30
+ # Default tensor parallel plan for base model `BlockFFNModel`
31
+ base_model_tp_plan = {
32
+ "layers.*.self_attn.q_proj": "colwise",
33
+ "layers.*.self_attn.k_proj": "colwise",
34
+ "layers.*.self_attn.v_proj": "colwise",
35
+ "layers.*.self_attn.o_proj": "rowwise",
36
+ "layers.*.mlp.gate_proj": "colwise",
37
+ "layers.*.mlp.up_proj": "colwise",
38
+ "layers.*.mlp.down_proj": "rowwise",
39
+ }
40
+ base_model_pp_plan = {
41
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
42
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
43
+ "norm": (["hidden_states"], ["hidden_states"]),
44
+ }
45
+
46
+ def __init__(
47
+ self,
48
+ vocab_size=32000,
49
+ hidden_size=4096,
50
+ ffn_hidden_size=11008,
51
+ num_layers=32,
52
+ num_attention_heads=32,
53
+ num_query_groups=None,
54
+ hidden_act="silu",
55
+ max_position_embeddings=2048,
56
+ initializer_range=0.02,
57
+ norm_epsilon=1e-6,
58
+ use_cache=True,
59
+ pad_token_id=None,
60
+ bos_token_id=1,
61
+ eos_token_id=2,
62
+ pretraining_tp=1,
63
+ tie_word_embeddings=False,
64
+ rope_theta=10000.0,
65
+ rope_scaling=None,
66
+ attention_bias=False,
67
+ attention_dropout=0.0,
68
+ mlp_bias=False,
69
+ head_dim=None,
70
+ use_mup=True,
71
+ mup_emb_scale=12,
72
+ mup_depth_scale=1.4,
73
+ mup_base_hidden_size=256,
74
+ num_experts=180,
75
+ moe_ffn_hidden_size=128,
76
+ moe_shared_expert_intermediate_size=128,
77
+ moe_layer_freq="([0]*3+[1]*29)",
78
+ moe_router_dtype="fp32",
79
+ router_act_func="relu",
80
+ router_norm_type="simple",
81
+ expert_act_func="norm_silu",
82
+ expert_act_norm_type="normal",
83
+ use_blockffn=False,
84
+ router_type="topk",
85
+ moe_router_topk=0,
86
+ moe_router_topp=0,
87
+ moe_router_enable_expert_bias=False,
88
+ moe_router_score_function="sigmoid",
89
+ moe_router_topk_scaling_factor=2.5,
90
+ expert_not_gated=False,
91
+ moe_router_pre_softmax=False,
92
+ **kwargs,
93
+ ):
94
+ self.vocab_size = vocab_size
95
+ self.max_position_embeddings = max_position_embeddings
96
+ self.hidden_size = hidden_size
97
+ self.ffn_hidden_size = ffn_hidden_size
98
+ self.num_layers = num_layers
99
+ self.num_attention_heads = num_attention_heads
100
+
101
+ # for backward compatibility
102
+ if num_query_groups is None:
103
+ num_query_groups = num_attention_heads
104
+
105
+ self.num_query_groups = num_query_groups
106
+ self.hidden_act = hidden_act
107
+ self.initializer_range = initializer_range
108
+ self.norm_epsilon = norm_epsilon
109
+ self.pretraining_tp = pretraining_tp
110
+ self.use_cache = use_cache
111
+ self.rope_theta = rope_theta
112
+ self.rope_scaling = rope_scaling
113
+ self.attention_bias = attention_bias
114
+ self.attention_dropout = attention_dropout
115
+ self.mlp_bias = mlp_bias
116
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
117
+ self.use_mup = use_mup
118
+ self.mup_emb_scale = mup_emb_scale
119
+ self.mup_depth_scale = mup_depth_scale
120
+ self.mup_base_hidden_size = mup_base_hidden_size
121
+
122
+ self.num_experts = num_experts
123
+ self.moe_ffn_hidden_size = moe_ffn_hidden_size
124
+ self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size
125
+ self.moe_layer_freq = moe_layer_freq if isinstance(moe_layer_freq, (str, list)) else ([0] * num_layers)
126
+ self.moe_router_dtype = moe_router_dtype
127
+ self.router_act_func = router_act_func
128
+ self.router_norm_type = router_norm_type
129
+ self.expert_act_func = expert_act_func
130
+ self.expert_act_norm_type = expert_act_norm_type
131
+
132
+ self.use_blockffn = use_blockffn
133
+ self.router_type = router_type
134
+ self.moe_router_topk = moe_router_topk
135
+ self.moe_router_topp = moe_router_topp
136
+ self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
137
+ self.moe_router_score_function = moe_router_score_function
138
+ self.moe_router_topk_scaling_factor = moe_router_topk_scaling_factor
139
+ self.expert_not_gated = expert_not_gated
140
+ self.moe_router_pre_softmax = moe_router_pre_softmax
141
+
142
+ # Validate the correctness of rotary position embeddings parameters
143
+ # BC: if there is a 'type' field, copy it it to 'rope_type'.
144
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
145
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
146
+ rope_config_validation(self)
147
+
148
+ super().__init__(
149
+ pad_token_id=pad_token_id,
150
+ bos_token_id=bos_token_id,
151
+ eos_token_id=eos_token_id,
152
+ tie_word_embeddings=tie_word_embeddings,
153
+ **kwargs,
154
+ )
155
+
156
+ @property
157
+ def mup_width_scale(self):
158
+ return (self.hidden_size / self.mup_base_hidden_size) if (self.use_mup and self.mup_base_hidden_size > 0) else 1
159
+
160
+
161
+ __all__ = ["BlockFFNConfig"]
evaluation.log ADDED
The diff for this file is too large to render. See raw diff
 
evaluation/results__hf_ckpts__blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64__/results_2026-01-22T22-09-07.950131.json ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "results": {
3
+ "arc_challenge": {
4
+ "alias": "arc_challenge",
5
+ "acc,none": 0.2022184300341297,
6
+ "acc_stderr,none": 0.011737454431872105,
7
+ "acc_norm,none": 0.23464163822525597,
8
+ "acc_norm_stderr,none": 0.012383873560768687
9
+ },
10
+ "arc_easy": {
11
+ "alias": "arc_easy",
12
+ "acc,none": 0.44402356902356904,
13
+ "acc_stderr,none": 0.010195285580783957,
14
+ "acc_norm,none": 0.40404040404040403,
15
+ "acc_norm_stderr,none": 0.010069061649549549
16
+ },
17
+ "boolq": {
18
+ "alias": "boolq",
19
+ "acc,none": 0.5715596330275229,
20
+ "acc_stderr,none": 0.00865502856151977
21
+ },
22
+ "hellaswag": {
23
+ "alias": "hellaswag",
24
+ "acc,none": 0.27962557259510057,
25
+ "acc_stderr,none": 0.004478979795506762,
26
+ "acc_norm,none": 0.29177454690300736,
27
+ "acc_norm_stderr,none": 0.0045365007141480096
28
+ },
29
+ "lambada_standard": {
30
+ "alias": "lambada_standard",
31
+ "perplexity,none": 536.9090379599169,
32
+ "perplexity_stderr,none": 25.490598104897224,
33
+ "acc,none": 0.15796623326217737,
34
+ "acc_stderr,none": 0.005081114222421346
35
+ },
36
+ "piqa": {
37
+ "alias": "piqa",
38
+ "acc,none": 0.6093579978237215,
39
+ "acc_stderr,none": 0.01138337776005358,
40
+ "acc_norm,none": 0.6131664853101197,
41
+ "acc_norm_stderr,none": 0.011363095931902854
42
+ },
43
+ "social_iqa": {
44
+ "alias": "social_iqa",
45
+ "acc,none": 0.368474923234391,
46
+ "acc_stderr,none": 0.010915613431252628
47
+ },
48
+ "wikitext": {
49
+ "alias": "wikitext",
50
+ "word_perplexity,none": 54.32450126248072,
51
+ "word_perplexity_stderr,none": "N/A",
52
+ "byte_perplexity,none": 2.1108281516522545,
53
+ "byte_perplexity_stderr,none": "N/A",
54
+ "bits_per_byte,none": 1.077809129679421,
55
+ "bits_per_byte_stderr,none": "N/A"
56
+ },
57
+ "winogrande": {
58
+ "alias": "winogrande",
59
+ "acc,none": 0.5067087608524072,
60
+ "acc_stderr,none": 0.014051220692330349
61
+ }
62
+ },
63
+ "group_subtasks": {
64
+ "arc_challenge": [],
65
+ "arc_easy": [],
66
+ "boolq": [],
67
+ "hellaswag": [],
68
+ "lambada_standard": [],
69
+ "piqa": [],
70
+ "social_iqa": [],
71
+ "wikitext": [],
72
+ "winogrande": []
73
+ },
74
+ "configs": {
75
+ "arc_challenge": {
76
+ "task": "arc_challenge",
77
+ "tag": [
78
+ "ai2_arc"
79
+ ],
80
+ "dataset_path": "allenai/ai2_arc",
81
+ "dataset_name": "ARC-Challenge",
82
+ "training_split": "train",
83
+ "validation_split": "validation",
84
+ "test_split": "test",
85
+ "doc_to_text": "Question: {{question}}\nAnswer:",
86
+ "doc_to_target": "{{choices.label.index(answerKey)}}",
87
+ "unsafe_code": false,
88
+ "doc_to_choice": "{{choices.text}}",
89
+ "description": "",
90
+ "target_delimiter": " ",
91
+ "fewshot_delimiter": "\n\n",
92
+ "num_fewshot": 0,
93
+ "metric_list": [
94
+ {
95
+ "metric": "acc",
96
+ "aggregation": "mean",
97
+ "higher_is_better": true
98
+ },
99
+ {
100
+ "metric": "acc_norm",
101
+ "aggregation": "mean",
102
+ "higher_is_better": true
103
+ }
104
+ ],
105
+ "output_type": "multiple_choice",
106
+ "repeats": 1,
107
+ "should_decontaminate": true,
108
+ "doc_to_decontamination_query": "Question: {{question}}\nAnswer:",
109
+ "metadata": {
110
+ "version": 1.0,
111
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
112
+ "dtype": "bfloat16",
113
+ "trust_remote_code": true
114
+ }
115
+ },
116
+ "arc_easy": {
117
+ "task": "arc_easy",
118
+ "tag": [
119
+ "ai2_arc"
120
+ ],
121
+ "dataset_path": "allenai/ai2_arc",
122
+ "dataset_name": "ARC-Easy",
123
+ "training_split": "train",
124
+ "validation_split": "validation",
125
+ "test_split": "test",
126
+ "doc_to_text": "Question: {{question}}\nAnswer:",
127
+ "doc_to_target": "{{choices.label.index(answerKey)}}",
128
+ "unsafe_code": false,
129
+ "doc_to_choice": "{{choices.text}}",
130
+ "description": "",
131
+ "target_delimiter": " ",
132
+ "fewshot_delimiter": "\n\n",
133
+ "num_fewshot": 0,
134
+ "metric_list": [
135
+ {
136
+ "metric": "acc",
137
+ "aggregation": "mean",
138
+ "higher_is_better": true
139
+ },
140
+ {
141
+ "metric": "acc_norm",
142
+ "aggregation": "mean",
143
+ "higher_is_better": true
144
+ }
145
+ ],
146
+ "output_type": "multiple_choice",
147
+ "repeats": 1,
148
+ "should_decontaminate": true,
149
+ "doc_to_decontamination_query": "Question: {{question}}\nAnswer:",
150
+ "metadata": {
151
+ "version": 1.0,
152
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
153
+ "dtype": "bfloat16",
154
+ "trust_remote_code": true
155
+ }
156
+ },
157
+ "boolq": {
158
+ "task": "boolq",
159
+ "tag": [
160
+ "super-glue-lm-eval-v1"
161
+ ],
162
+ "dataset_path": "super_glue",
163
+ "dataset_name": "boolq",
164
+ "training_split": "train",
165
+ "validation_split": "validation",
166
+ "doc_to_text": "{{passage}}\nQuestion: {{question}}?\nAnswer:",
167
+ "doc_to_target": "label",
168
+ "unsafe_code": false,
169
+ "doc_to_choice": [
170
+ "no",
171
+ "yes"
172
+ ],
173
+ "description": "",
174
+ "target_delimiter": " ",
175
+ "fewshot_delimiter": "\n\n",
176
+ "num_fewshot": 0,
177
+ "metric_list": [
178
+ {
179
+ "metric": "acc"
180
+ }
181
+ ],
182
+ "output_type": "multiple_choice",
183
+ "repeats": 1,
184
+ "should_decontaminate": true,
185
+ "doc_to_decontamination_query": "passage",
186
+ "metadata": {
187
+ "version": 2.0,
188
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
189
+ "dtype": "bfloat16",
190
+ "trust_remote_code": true
191
+ }
192
+ },
193
+ "hellaswag": {
194
+ "task": "hellaswag",
195
+ "tag": [
196
+ "multiple_choice"
197
+ ],
198
+ "dataset_path": "Rowan/hellaswag",
199
+ "training_split": "train",
200
+ "validation_split": "validation",
201
+ "process_docs": "def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:\n def _process_doc(doc):\n ctx = doc[\"ctx_a\"] + \" \" + doc[\"ctx_b\"].capitalize()\n out_doc = {\n \"query\": preprocess(doc[\"activity_label\"] + \": \" + ctx),\n \"choices\": [preprocess(ending) for ending in doc[\"endings\"]],\n \"gold\": int(doc[\"label\"]),\n }\n return out_doc\n\n return dataset.map(_process_doc)\n",
202
+ "doc_to_text": "{{query}}",
203
+ "doc_to_target": "{{label}}",
204
+ "unsafe_code": false,
205
+ "doc_to_choice": "choices",
206
+ "description": "",
207
+ "target_delimiter": " ",
208
+ "fewshot_delimiter": "\n\n",
209
+ "num_fewshot": 0,
210
+ "metric_list": [
211
+ {
212
+ "metric": "acc",
213
+ "aggregation": "mean",
214
+ "higher_is_better": true
215
+ },
216
+ {
217
+ "metric": "acc_norm",
218
+ "aggregation": "mean",
219
+ "higher_is_better": true
220
+ }
221
+ ],
222
+ "output_type": "multiple_choice",
223
+ "repeats": 1,
224
+ "should_decontaminate": false,
225
+ "metadata": {
226
+ "version": 1.0,
227
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
228
+ "dtype": "bfloat16",
229
+ "trust_remote_code": true
230
+ }
231
+ },
232
+ "lambada_standard": {
233
+ "task": "lambada_standard",
234
+ "tag": [
235
+ "lambada"
236
+ ],
237
+ "dataset_path": "lambada",
238
+ "validation_split": "validation",
239
+ "test_split": "test",
240
+ "doc_to_text": "{{text.split(' ')[:-1]|join(' ')}}",
241
+ "doc_to_target": "{{' '+text.split(' ')[-1]}}",
242
+ "unsafe_code": false,
243
+ "description": "",
244
+ "target_delimiter": " ",
245
+ "fewshot_delimiter": "\n\n",
246
+ "num_fewshot": 0,
247
+ "metric_list": [
248
+ {
249
+ "metric": "perplexity",
250
+ "aggregation": "perplexity",
251
+ "higher_is_better": false
252
+ },
253
+ {
254
+ "metric": "acc",
255
+ "aggregation": "mean",
256
+ "higher_is_better": true
257
+ }
258
+ ],
259
+ "output_type": "loglikelihood",
260
+ "repeats": 1,
261
+ "should_decontaminate": true,
262
+ "doc_to_decontamination_query": "{{text}}",
263
+ "metadata": {
264
+ "version": 1.0,
265
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
266
+ "dtype": "bfloat16",
267
+ "trust_remote_code": true
268
+ }
269
+ },
270
+ "piqa": {
271
+ "task": "piqa",
272
+ "dataset_path": "baber/piqa",
273
+ "training_split": "train",
274
+ "validation_split": "validation",
275
+ "doc_to_text": "Question: {{goal}}\nAnswer:",
276
+ "doc_to_target": "label",
277
+ "unsafe_code": false,
278
+ "doc_to_choice": "{{[sol1, sol2]}}",
279
+ "description": "",
280
+ "target_delimiter": " ",
281
+ "fewshot_delimiter": "\n\n",
282
+ "num_fewshot": 0,
283
+ "metric_list": [
284
+ {
285
+ "metric": "acc",
286
+ "aggregation": "mean",
287
+ "higher_is_better": true
288
+ },
289
+ {
290
+ "metric": "acc_norm",
291
+ "aggregation": "mean",
292
+ "higher_is_better": true
293
+ }
294
+ ],
295
+ "output_type": "multiple_choice",
296
+ "repeats": 1,
297
+ "should_decontaminate": true,
298
+ "doc_to_decontamination_query": "goal",
299
+ "metadata": {
300
+ "version": 1.0,
301
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
302
+ "dtype": "bfloat16",
303
+ "trust_remote_code": true
304
+ }
305
+ },
306
+ "social_iqa": {
307
+ "task": "social_iqa",
308
+ "dataset_path": "social_i_qa",
309
+ "training_split": "train",
310
+ "validation_split": "validation",
311
+ "doc_to_text": "Q: {{context}} {{question}}\nA:",
312
+ "doc_to_target": "{{ (label|int) - 1 }}",
313
+ "unsafe_code": false,
314
+ "doc_to_choice": "{{[answerA, answerB, answerC]}}",
315
+ "description": "",
316
+ "target_delimiter": " ",
317
+ "fewshot_delimiter": "\n\n",
318
+ "num_fewshot": 0,
319
+ "metric_list": [
320
+ {
321
+ "metric": "acc",
322
+ "aggregation": "mean",
323
+ "higher_is_better": true
324
+ }
325
+ ],
326
+ "output_type": "multiple_choice",
327
+ "repeats": 1,
328
+ "should_decontaminate": false,
329
+ "metadata": {
330
+ "version": 0.0,
331
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
332
+ "dtype": "bfloat16",
333
+ "trust_remote_code": true
334
+ }
335
+ },
336
+ "wikitext": {
337
+ "task": "wikitext",
338
+ "dataset_path": "EleutherAI/wikitext_document_level",
339
+ "dataset_name": "wikitext-2-raw-v1",
340
+ "training_split": "train",
341
+ "validation_split": "validation",
342
+ "test_split": "test",
343
+ "doc_to_text": "",
344
+ "doc_to_target": "def wikitext_detokenizer(doc):\n string = doc[\"page\"]\n # contractions\n string = string.replace(\"s '\", \"s'\")\n string = re.sub(r\"/' [0-9]/\", r\"/'[0-9]/\", string)\n # number separators\n string = string.replace(\" @-@ \", \"-\")\n string = string.replace(\" @,@ \", \",\")\n string = string.replace(\" @.@ \", \".\")\n # punctuation\n string = string.replace(\" : \", \": \")\n string = string.replace(\" ; \", \"; \")\n string = string.replace(\" . \", \". \")\n string = string.replace(\" ! \", \"! \")\n string = string.replace(\" ? \", \"? \")\n string = string.replace(\" , \", \", \")\n # double brackets\n string = re.sub(r\"\\(\\s*([^\\)]*?)\\s*\\)\", r\"(\\1)\", string)\n string = re.sub(r\"\\[\\s*([^\\]]*?)\\s*\\]\", r\"[\\1]\", string)\n string = re.sub(r\"{\\s*([^}]*?)\\s*}\", r\"{\\1}\", string)\n string = re.sub(r\"\\\"\\s*([^\\\"]*?)\\s*\\\"\", r'\"\\1\"', string)\n string = re.sub(r\"'\\s*([^']*?)\\s*'\", r\"'\\1'\", string)\n # miscellaneous\n string = string.replace(\"= = = =\", \"====\")\n string = string.replace(\"= = =\", \"===\")\n string = string.replace(\"= =\", \"==\")\n string = string.replace(\" \" + chr(176) + \" \", chr(176))\n string = string.replace(\" \\n\", \"\\n\")\n string = string.replace(\"\\n \", \"\\n\")\n string = string.replace(\" N \", \" 1 \")\n string = string.replace(\" 's\", \"'s\")\n\n return string\n",
345
+ "unsafe_code": false,
346
+ "process_results": "def process_results(doc, results):\n (loglikelihood,) = results\n # IMPORTANT: wikitext counts number of words in *original doc before detokenization*\n _words = len(re.split(r\"\\s+\", doc[\"page\"]))\n _bytes = len(doc[\"page\"].encode(\"utf-8\"))\n return {\n \"word_perplexity\": (loglikelihood, _words),\n \"byte_perplexity\": (loglikelihood, _bytes),\n \"bits_per_byte\": (loglikelihood, _bytes),\n }\n",
347
+ "description": "",
348
+ "target_delimiter": " ",
349
+ "fewshot_delimiter": "\n\n",
350
+ "num_fewshot": 0,
351
+ "metric_list": [
352
+ {
353
+ "metric": "word_perplexity"
354
+ },
355
+ {
356
+ "metric": "byte_perplexity"
357
+ },
358
+ {
359
+ "metric": "bits_per_byte"
360
+ }
361
+ ],
362
+ "output_type": "loglikelihood_rolling",
363
+ "repeats": 1,
364
+ "should_decontaminate": true,
365
+ "doc_to_decontamination_query": "{{page}}",
366
+ "metadata": {
367
+ "version": 2.0,
368
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
369
+ "dtype": "bfloat16",
370
+ "trust_remote_code": true
371
+ }
372
+ },
373
+ "winogrande": {
374
+ "task": "winogrande",
375
+ "dataset_path": "winogrande",
376
+ "dataset_name": "winogrande_xl",
377
+ "training_split": "train",
378
+ "validation_split": "validation",
379
+ "doc_to_text": "def doc_to_text(doc):\n answer_to_num = {\"1\": 0, \"2\": 1}\n return answer_to_num[doc[\"answer\"]]\n",
380
+ "doc_to_target": "def doc_to_target(doc):\n idx = doc[\"sentence\"].index(\"_\") + 1\n return doc[\"sentence\"][idx:].strip()\n",
381
+ "unsafe_code": false,
382
+ "doc_to_choice": "def doc_to_choice(doc):\n idx = doc[\"sentence\"].index(\"_\")\n options = [doc[\"option1\"], doc[\"option2\"]]\n return [doc[\"sentence\"][:idx] + opt for opt in options]\n",
383
+ "description": "",
384
+ "target_delimiter": " ",
385
+ "fewshot_delimiter": "\n\n",
386
+ "num_fewshot": 0,
387
+ "metric_list": [
388
+ {
389
+ "metric": "acc",
390
+ "aggregation": "mean",
391
+ "higher_is_better": true
392
+ }
393
+ ],
394
+ "output_type": "multiple_choice",
395
+ "repeats": 1,
396
+ "should_decontaminate": true,
397
+ "doc_to_decontamination_query": "sentence",
398
+ "metadata": {
399
+ "version": 1.0,
400
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
401
+ "dtype": "bfloat16",
402
+ "trust_remote_code": true
403
+ }
404
+ }
405
+ },
406
+ "versions": {
407
+ "arc_challenge": 1.0,
408
+ "arc_easy": 1.0,
409
+ "boolq": 2.0,
410
+ "hellaswag": 1.0,
411
+ "lambada_standard": 1.0,
412
+ "piqa": 1.0,
413
+ "social_iqa": 0.0,
414
+ "wikitext": 2.0,
415
+ "winogrande": 1.0
416
+ },
417
+ "n-shot": {
418
+ "arc_challenge": 0,
419
+ "arc_easy": 0,
420
+ "boolq": 0,
421
+ "hellaswag": 0,
422
+ "lambada_standard": 0,
423
+ "piqa": 0,
424
+ "social_iqa": 0,
425
+ "wikitext": 0,
426
+ "winogrande": 0
427
+ },
428
+ "higher_is_better": {
429
+ "arc_challenge": {
430
+ "acc": true,
431
+ "acc_norm": true
432
+ },
433
+ "arc_easy": {
434
+ "acc": true,
435
+ "acc_norm": true
436
+ },
437
+ "boolq": {
438
+ "acc": true
439
+ },
440
+ "hellaswag": {
441
+ "acc": true,
442
+ "acc_norm": true
443
+ },
444
+ "lambada_standard": {
445
+ "perplexity": false,
446
+ "acc": true
447
+ },
448
+ "piqa": {
449
+ "acc": true,
450
+ "acc_norm": true
451
+ },
452
+ "social_iqa": {
453
+ "acc": true
454
+ },
455
+ "wikitext": {
456
+ "word_perplexity": false,
457
+ "byte_perplexity": false,
458
+ "bits_per_byte": false
459
+ },
460
+ "winogrande": {
461
+ "acc": true
462
+ }
463
+ },
464
+ "n-samples": {
465
+ "winogrande": {
466
+ "original": 1267,
467
+ "effective": 1267
468
+ },
469
+ "wikitext": {
470
+ "original": 62,
471
+ "effective": 62
472
+ },
473
+ "social_iqa": {
474
+ "original": 1954,
475
+ "effective": 1954
476
+ },
477
+ "piqa": {
478
+ "original": 1838,
479
+ "effective": 1838
480
+ },
481
+ "lambada_standard": {
482
+ "original": 5153,
483
+ "effective": 5153
484
+ },
485
+ "hellaswag": {
486
+ "original": 10042,
487
+ "effective": 10042
488
+ },
489
+ "boolq": {
490
+ "original": 3270,
491
+ "effective": 3270
492
+ },
493
+ "arc_easy": {
494
+ "original": 2376,
495
+ "effective": 2376
496
+ },
497
+ "arc_challenge": {
498
+ "original": 1172,
499
+ "effective": 1172
500
+ }
501
+ },
502
+ "config": {
503
+ "model": "hf",
504
+ "model_args": "pretrained=results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/,dtype=bfloat16,trust_remote_code=True,trust_remote_code=True",
505
+ "model_num_parameters": 221854518,
506
+ "model_dtype": "torch.bfloat16",
507
+ "model_revision": "main",
508
+ "model_sha": "",
509
+ "batch_size": "8",
510
+ "batch_sizes": [],
511
+ "device": "cuda:0",
512
+ "use_cache": null,
513
+ "limit": null,
514
+ "bootstrap_iters": 100000,
515
+ "gen_kwargs": null,
516
+ "random_seed": 0,
517
+ "numpy_seed": 1234,
518
+ "torch_seed": 1234,
519
+ "fewshot_seed": 1234
520
+ },
521
+ "git_hash": "core_v0.12.0-110-g747c34dc4",
522
+ "date": 1769090663.001165,
523
+ "pretty_env_info": "PyTorch version: 2.6.0+cu124\nIs debug build: False\nCUDA used to build PyTorch: 12.4\nROCM used to build PyTorch: N/A\n\nOS: Ubuntu 22.04.4 LTS (x86_64)\nGCC version: (conda-forge gcc 9.5.0-19) 9.5.0\nClang version: Could not collect\nCMake version: version 3.30.1\nLibc version: glibc-2.35\n\nPython version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)\nPython platform: Linux-6.5.0-18-generic-x86_64-with-glibc2.35\nIs CUDA available: True\nCUDA runtime version: 12.4.131\nCUDA_MODULE_LOADING set to: LAZY\nGPU models and configuration: \nGPU 0: NVIDIA A800-SXM4-80GB\nGPU 1: NVIDIA A800-SXM4-80GB\nGPU 2: NVIDIA A800-SXM4-80GB\nGPU 3: NVIDIA A800-SXM4-80GB\nGPU 4: NVIDIA A800-SXM4-80GB\nGPU 5: NVIDIA A800-SXM4-80GB\nGPU 6: NVIDIA A800-SXM4-80GB\nGPU 7: NVIDIA A800-SXM4-80GB\n\nNvidia driver version: 550.54.15\ncuDNN version: Could not collect\nHIP runtime version: N/A\nMIOpen runtime version: N/A\nIs XNNPACK available: True\n\nCPU:\nArchitecture: x86_64\nCPU op-mode(s): 32-bit, 64-bit\nAddress sizes: 52 bits physical, 57 bits virtual\nByte Order: Little Endian\nCPU(s): 104\nOn-line CPU(s) list: 0-103\nVendor ID: GenuineIntel\nModel name: Intel(R) Xeon(R) Platinum 8470\nCPU family: 6\nModel: 143\nThread(s) per core: 1\nCore(s) per socket: 52\nSocket(s): 2\nStepping: 8\nCPU max MHz: 3800.0000\nCPU min MHz: 800.0000\nBogoMIPS: 4000.00\nFlags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities\nVirtualization: VT-x\nL1d cache: 4.9 MiB (104 instances)\nL1i cache: 3.3 MiB (104 instances)\nL2 cache: 208 MiB (104 instances)\nL3 cache: 210 MiB (2 instances)\nNUMA node(s): 2\nNUMA node0 CPU(s): 0-51\nNUMA node1 CPU(s): 52-103\nVulnerability Gather data sampling: Not affected\nVulnerability Itlb multihit: Not affected\nVulnerability L1tf: Not affected\nVulnerability Mds: Not affected\nVulnerability Meltdown: Not affected\nVulnerability Mmio stale data: Not affected\nVulnerability Retbleed: Not affected\nVulnerability Spec rstack overflow: Not affected\nVulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl\nVulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization\nVulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence\nVulnerability Srbds: Not affected\nVulnerability Tsx async abort: Not affected\n\nVersions of relevant libraries:\n[pip3] numpy==1.26.4\n[pip3] nvidia-cublas-cu12==12.4.5.8\n[pip3] nvidia-cuda-cupti-cu12==12.4.127\n[pip3] nvidia-cuda-nvrtc-cu12==12.4.127\n[pip3] nvidia-cuda-runtime-cu12==12.4.127\n[pip3] nvidia-cudnn-cu12==9.1.0.70\n[pip3] nvidia-cufft-cu12==11.2.1.3\n[pip3] nvidia-curand-cu12==10.3.5.147\n[pip3] nvidia-cusolver-cu12==11.6.1.9\n[pip3] nvidia-cusparse-cu12==12.3.1.170\n[pip3] nvidia-cusparselt-cu12==0.6.2\n[pip3] nvidia-nccl-cu11==2.21.5\n[pip3] nvidia-nccl-cu12==2.21.5\n[pip3] nvidia-nvjitlink-cu12==12.4.127\n[pip3] nvidia-nvtx-cu12==12.4.127\n[pip3] torch==2.6.0\n[pip3] torchaudio==2.6.0\n[pip3] torchdata==0.11.0\n[pip3] torchvision==0.21.0\n[pip3] triton==3.2.0\n[conda] cuda-cudart 12.4.99 hd3aeb46_0 conda-forge\n[conda] cuda-cudart_linux-64 12.4.99 h59595ed_0 conda-forge\n[conda] cuda-cupti 12.4.127 he02047a_2 conda-forge\n[conda] cuda-libraries 12.4.0 ha770c72_0 conda-forge\n[conda] cuda-nvrtc 12.4.99 hd3aeb46_0 conda-forge\n[conda] cuda-nvtx 12.4.127 he02047a_2 conda-forge\n[conda] cuda-opencl 12.4.99 h59595ed_0 conda-forge\n[conda] cuda-runtime 12.4.0 ha804496_0 conda-forge\n[conda] ffmpeg 4.3 hf484d3e_0 pytorch\n[conda] libcublas 12.4.2.65 hd3aeb46_0 conda-forge\n[conda] libcufft 11.2.0.44 hd3aeb46_0 conda-forge\n[conda] libcurand 10.3.5.119 hd3aeb46_0 conda-forge\n[conda] libcusolver 11.6.0.99 hd3aeb46_0 conda-forge\n[conda] libcusparse 12.3.0.142 hd3aeb46_0 conda-forge\n[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch\n[conda] libnvjitlink 12.4.99 hd3aeb46_0 conda-forge\n[conda] mkl 2023.1.0 h213fc3f_46344 defaults\n[conda] numpy 1.26.4 pypi_0 pypi\n[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi\n[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi\n[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi\n[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi\n[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi\n[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi\n[conda] nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi\n[conda] nvidia-nccl-cu11 2.21.5 pypi_0 pypi\n[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi\n[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi\n[conda] pytorch-cuda 12.4 hc786d27_6 pytorch\n[conda] pytorch-mutex 1.0 cuda pytorch\n[conda] torch 2.6.0 pypi_0 pypi\n[conda] torchaudio 2.6.0 pypi_0 pypi\n[conda] torchdata 0.11.0 pypi_0 pypi\n[conda] torchvision 0.21.0 pypi_0 pypi\n[conda] triton 3.2.0 pypi_0 pypi",
524
+ "transformers_version": "4.55.2",
525
+ "lm_eval_version": "0.4.9.1",
526
+ "upper_git_hash": null,
527
+ "tokenizer_pad_token": [
528
+ "<unk>",
529
+ "0"
530
+ ],
531
+ "tokenizer_eos_token": [
532
+ "<|im_end|>",
533
+ "73440"
534
+ ],
535
+ "tokenizer_bos_token": [
536
+ "<s>",
537
+ "1"
538
+ ],
539
+ "eot_token_id": 73440,
540
+ "max_length": 4096,
541
+ "task_hashes": {},
542
+ "model_source": "hf",
543
+ "model_name": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
544
+ "model_name_sanitized": "results__hf_ckpts__blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64__",
545
+ "system_instruction": null,
546
+ "system_instruction_sha": null,
547
+ "fewshot_as_multiturn": false,
548
+ "chat_template": null,
549
+ "chat_template_sha": null,
550
+ "start_time": 623980.687927145,
551
+ "end_time": 624273.9871482,
552
+ "total_evaluation_time_seconds": "293.2992210550001"
553
+ }
evaluation/results__hf_ckpts__blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64__/results_2026-01-22T22-17-51.225009.json ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "results": {
3
+ "arc_challenge": {
4
+ "alias": "arc_challenge",
5
+ "acc,none": 0.2022184300341297,
6
+ "acc_stderr,none": 0.011737454431872105,
7
+ "acc_norm,none": 0.23464163822525597,
8
+ "acc_norm_stderr,none": 0.012383873560768687
9
+ },
10
+ "arc_easy": {
11
+ "alias": "arc_easy",
12
+ "acc,none": 0.44402356902356904,
13
+ "acc_stderr,none": 0.010195285580783957,
14
+ "acc_norm,none": 0.40404040404040403,
15
+ "acc_norm_stderr,none": 0.010069061649549549
16
+ },
17
+ "boolq": {
18
+ "alias": "boolq",
19
+ "acc,none": 0.5715596330275229,
20
+ "acc_stderr,none": 0.00865502856151977
21
+ },
22
+ "hellaswag": {
23
+ "alias": "hellaswag",
24
+ "acc,none": 0.27962557259510057,
25
+ "acc_stderr,none": 0.004478979795506762,
26
+ "acc_norm,none": 0.29177454690300736,
27
+ "acc_norm_stderr,none": 0.0045365007141480096
28
+ },
29
+ "lambada_openai": {
30
+ "alias": "lambada_openai",
31
+ "perplexity,none": 166.09188392334596,
32
+ "perplexity_stderr,none": 7.549846792492689,
33
+ "acc,none": 0.21288569765185328,
34
+ "acc_stderr,none": 0.005703011105808082
35
+ },
36
+ "lambada_standard": {
37
+ "alias": "lambada_standard",
38
+ "perplexity,none": 536.8992699093774,
39
+ "perplexity_stderr,none": 25.490016733651576,
40
+ "acc,none": 0.15796623326217737,
41
+ "acc_stderr,none": 0.005081114222421346
42
+ },
43
+ "piqa": {
44
+ "alias": "piqa",
45
+ "acc,none": 0.6093579978237215,
46
+ "acc_stderr,none": 0.01138337776005358,
47
+ "acc_norm,none": 0.6131664853101197,
48
+ "acc_norm_stderr,none": 0.011363095931902854
49
+ },
50
+ "social_iqa": {
51
+ "alias": "social_iqa",
52
+ "acc,none": 0.368474923234391,
53
+ "acc_stderr,none": 0.010915613431252628
54
+ },
55
+ "wikitext": {
56
+ "alias": "wikitext",
57
+ "word_perplexity,none": 54.32450126248072,
58
+ "word_perplexity_stderr,none": "N/A",
59
+ "byte_perplexity,none": 2.1108281516522545,
60
+ "byte_perplexity_stderr,none": "N/A",
61
+ "bits_per_byte,none": 1.077809129679421,
62
+ "bits_per_byte_stderr,none": "N/A"
63
+ },
64
+ "winogrande": {
65
+ "alias": "winogrande",
66
+ "acc,none": 0.5067087608524072,
67
+ "acc_stderr,none": 0.014051220692330349
68
+ }
69
+ },
70
+ "group_subtasks": {
71
+ "arc_challenge": [],
72
+ "arc_easy": [],
73
+ "boolq": [],
74
+ "hellaswag": [],
75
+ "lambada_openai": [],
76
+ "lambada_standard": [],
77
+ "piqa": [],
78
+ "social_iqa": [],
79
+ "wikitext": [],
80
+ "winogrande": []
81
+ },
82
+ "configs": {
83
+ "arc_challenge": {
84
+ "task": "arc_challenge",
85
+ "tag": [
86
+ "ai2_arc"
87
+ ],
88
+ "dataset_path": "allenai/ai2_arc",
89
+ "dataset_name": "ARC-Challenge",
90
+ "training_split": "train",
91
+ "validation_split": "validation",
92
+ "test_split": "test",
93
+ "doc_to_text": "Question: {{question}}\nAnswer:",
94
+ "doc_to_target": "{{choices.label.index(answerKey)}}",
95
+ "unsafe_code": false,
96
+ "doc_to_choice": "{{choices.text}}",
97
+ "description": "",
98
+ "target_delimiter": " ",
99
+ "fewshot_delimiter": "\n\n",
100
+ "num_fewshot": 0,
101
+ "metric_list": [
102
+ {
103
+ "metric": "acc",
104
+ "aggregation": "mean",
105
+ "higher_is_better": true
106
+ },
107
+ {
108
+ "metric": "acc_norm",
109
+ "aggregation": "mean",
110
+ "higher_is_better": true
111
+ }
112
+ ],
113
+ "output_type": "multiple_choice",
114
+ "repeats": 1,
115
+ "should_decontaminate": true,
116
+ "doc_to_decontamination_query": "Question: {{question}}\nAnswer:",
117
+ "metadata": {
118
+ "version": 1.0,
119
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
120
+ "dtype": "bfloat16",
121
+ "trust_remote_code": true
122
+ }
123
+ },
124
+ "arc_easy": {
125
+ "task": "arc_easy",
126
+ "tag": [
127
+ "ai2_arc"
128
+ ],
129
+ "dataset_path": "allenai/ai2_arc",
130
+ "dataset_name": "ARC-Easy",
131
+ "training_split": "train",
132
+ "validation_split": "validation",
133
+ "test_split": "test",
134
+ "doc_to_text": "Question: {{question}}\nAnswer:",
135
+ "doc_to_target": "{{choices.label.index(answerKey)}}",
136
+ "unsafe_code": false,
137
+ "doc_to_choice": "{{choices.text}}",
138
+ "description": "",
139
+ "target_delimiter": " ",
140
+ "fewshot_delimiter": "\n\n",
141
+ "num_fewshot": 0,
142
+ "metric_list": [
143
+ {
144
+ "metric": "acc",
145
+ "aggregation": "mean",
146
+ "higher_is_better": true
147
+ },
148
+ {
149
+ "metric": "acc_norm",
150
+ "aggregation": "mean",
151
+ "higher_is_better": true
152
+ }
153
+ ],
154
+ "output_type": "multiple_choice",
155
+ "repeats": 1,
156
+ "should_decontaminate": true,
157
+ "doc_to_decontamination_query": "Question: {{question}}\nAnswer:",
158
+ "metadata": {
159
+ "version": 1.0,
160
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
161
+ "dtype": "bfloat16",
162
+ "trust_remote_code": true
163
+ }
164
+ },
165
+ "boolq": {
166
+ "task": "boolq",
167
+ "tag": [
168
+ "super-glue-lm-eval-v1"
169
+ ],
170
+ "dataset_path": "super_glue",
171
+ "dataset_name": "boolq",
172
+ "training_split": "train",
173
+ "validation_split": "validation",
174
+ "doc_to_text": "{{passage}}\nQuestion: {{question}}?\nAnswer:",
175
+ "doc_to_target": "label",
176
+ "unsafe_code": false,
177
+ "doc_to_choice": [
178
+ "no",
179
+ "yes"
180
+ ],
181
+ "description": "",
182
+ "target_delimiter": " ",
183
+ "fewshot_delimiter": "\n\n",
184
+ "num_fewshot": 0,
185
+ "metric_list": [
186
+ {
187
+ "metric": "acc"
188
+ }
189
+ ],
190
+ "output_type": "multiple_choice",
191
+ "repeats": 1,
192
+ "should_decontaminate": true,
193
+ "doc_to_decontamination_query": "passage",
194
+ "metadata": {
195
+ "version": 2.0,
196
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
197
+ "dtype": "bfloat16",
198
+ "trust_remote_code": true
199
+ }
200
+ },
201
+ "hellaswag": {
202
+ "task": "hellaswag",
203
+ "tag": [
204
+ "multiple_choice"
205
+ ],
206
+ "dataset_path": "Rowan/hellaswag",
207
+ "training_split": "train",
208
+ "validation_split": "validation",
209
+ "process_docs": "def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:\n def _process_doc(doc):\n ctx = doc[\"ctx_a\"] + \" \" + doc[\"ctx_b\"].capitalize()\n out_doc = {\n \"query\": preprocess(doc[\"activity_label\"] + \": \" + ctx),\n \"choices\": [preprocess(ending) for ending in doc[\"endings\"]],\n \"gold\": int(doc[\"label\"]),\n }\n return out_doc\n\n return dataset.map(_process_doc)\n",
210
+ "doc_to_text": "{{query}}",
211
+ "doc_to_target": "{{label}}",
212
+ "unsafe_code": false,
213
+ "doc_to_choice": "choices",
214
+ "description": "",
215
+ "target_delimiter": " ",
216
+ "fewshot_delimiter": "\n\n",
217
+ "num_fewshot": 0,
218
+ "metric_list": [
219
+ {
220
+ "metric": "acc",
221
+ "aggregation": "mean",
222
+ "higher_is_better": true
223
+ },
224
+ {
225
+ "metric": "acc_norm",
226
+ "aggregation": "mean",
227
+ "higher_is_better": true
228
+ }
229
+ ],
230
+ "output_type": "multiple_choice",
231
+ "repeats": 1,
232
+ "should_decontaminate": false,
233
+ "metadata": {
234
+ "version": 1.0,
235
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
236
+ "dtype": "bfloat16",
237
+ "trust_remote_code": true
238
+ }
239
+ },
240
+ "lambada_openai": {
241
+ "task": "lambada_openai",
242
+ "tag": [
243
+ "lambada"
244
+ ],
245
+ "dataset_path": "EleutherAI/lambada_openai",
246
+ "dataset_name": "default",
247
+ "test_split": "test",
248
+ "doc_to_text": "{{text.split(' ')[:-1]|join(' ')}}",
249
+ "doc_to_target": "{{' '+text.split(' ')[-1]}}",
250
+ "unsafe_code": false,
251
+ "description": "",
252
+ "target_delimiter": " ",
253
+ "fewshot_delimiter": "\n\n",
254
+ "num_fewshot": 0,
255
+ "metric_list": [
256
+ {
257
+ "metric": "perplexity",
258
+ "aggregation": "perplexity",
259
+ "higher_is_better": false
260
+ },
261
+ {
262
+ "metric": "acc",
263
+ "aggregation": "mean",
264
+ "higher_is_better": true
265
+ }
266
+ ],
267
+ "output_type": "loglikelihood",
268
+ "repeats": 1,
269
+ "should_decontaminate": true,
270
+ "doc_to_decontamination_query": "{{text}}",
271
+ "metadata": {
272
+ "version": 1.0,
273
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
274
+ "dtype": "bfloat16",
275
+ "trust_remote_code": true
276
+ }
277
+ },
278
+ "lambada_standard": {
279
+ "task": "lambada_standard",
280
+ "tag": [
281
+ "lambada"
282
+ ],
283
+ "dataset_path": "lambada",
284
+ "validation_split": "validation",
285
+ "test_split": "test",
286
+ "doc_to_text": "{{text.split(' ')[:-1]|join(' ')}}",
287
+ "doc_to_target": "{{' '+text.split(' ')[-1]}}",
288
+ "unsafe_code": false,
289
+ "description": "",
290
+ "target_delimiter": " ",
291
+ "fewshot_delimiter": "\n\n",
292
+ "num_fewshot": 0,
293
+ "metric_list": [
294
+ {
295
+ "metric": "perplexity",
296
+ "aggregation": "perplexity",
297
+ "higher_is_better": false
298
+ },
299
+ {
300
+ "metric": "acc",
301
+ "aggregation": "mean",
302
+ "higher_is_better": true
303
+ }
304
+ ],
305
+ "output_type": "loglikelihood",
306
+ "repeats": 1,
307
+ "should_decontaminate": true,
308
+ "doc_to_decontamination_query": "{{text}}",
309
+ "metadata": {
310
+ "version": 1.0,
311
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
312
+ "dtype": "bfloat16",
313
+ "trust_remote_code": true
314
+ }
315
+ },
316
+ "piqa": {
317
+ "task": "piqa",
318
+ "dataset_path": "baber/piqa",
319
+ "training_split": "train",
320
+ "validation_split": "validation",
321
+ "doc_to_text": "Question: {{goal}}\nAnswer:",
322
+ "doc_to_target": "label",
323
+ "unsafe_code": false,
324
+ "doc_to_choice": "{{[sol1, sol2]}}",
325
+ "description": "",
326
+ "target_delimiter": " ",
327
+ "fewshot_delimiter": "\n\n",
328
+ "num_fewshot": 0,
329
+ "metric_list": [
330
+ {
331
+ "metric": "acc",
332
+ "aggregation": "mean",
333
+ "higher_is_better": true
334
+ },
335
+ {
336
+ "metric": "acc_norm",
337
+ "aggregation": "mean",
338
+ "higher_is_better": true
339
+ }
340
+ ],
341
+ "output_type": "multiple_choice",
342
+ "repeats": 1,
343
+ "should_decontaminate": true,
344
+ "doc_to_decontamination_query": "goal",
345
+ "metadata": {
346
+ "version": 1.0,
347
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
348
+ "dtype": "bfloat16",
349
+ "trust_remote_code": true
350
+ }
351
+ },
352
+ "social_iqa": {
353
+ "task": "social_iqa",
354
+ "dataset_path": "social_i_qa",
355
+ "training_split": "train",
356
+ "validation_split": "validation",
357
+ "doc_to_text": "Q: {{context}} {{question}}\nA:",
358
+ "doc_to_target": "{{ (label|int) - 1 }}",
359
+ "unsafe_code": false,
360
+ "doc_to_choice": "{{[answerA, answerB, answerC]}}",
361
+ "description": "",
362
+ "target_delimiter": " ",
363
+ "fewshot_delimiter": "\n\n",
364
+ "num_fewshot": 0,
365
+ "metric_list": [
366
+ {
367
+ "metric": "acc",
368
+ "aggregation": "mean",
369
+ "higher_is_better": true
370
+ }
371
+ ],
372
+ "output_type": "multiple_choice",
373
+ "repeats": 1,
374
+ "should_decontaminate": false,
375
+ "metadata": {
376
+ "version": 0.0,
377
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
378
+ "dtype": "bfloat16",
379
+ "trust_remote_code": true
380
+ }
381
+ },
382
+ "wikitext": {
383
+ "task": "wikitext",
384
+ "dataset_path": "EleutherAI/wikitext_document_level",
385
+ "dataset_name": "wikitext-2-raw-v1",
386
+ "training_split": "train",
387
+ "validation_split": "validation",
388
+ "test_split": "test",
389
+ "doc_to_text": "",
390
+ "doc_to_target": "def wikitext_detokenizer(doc):\n string = doc[\"page\"]\n # contractions\n string = string.replace(\"s '\", \"s'\")\n string = re.sub(r\"/' [0-9]/\", r\"/'[0-9]/\", string)\n # number separators\n string = string.replace(\" @-@ \", \"-\")\n string = string.replace(\" @,@ \", \",\")\n string = string.replace(\" @.@ \", \".\")\n # punctuation\n string = string.replace(\" : \", \": \")\n string = string.replace(\" ; \", \"; \")\n string = string.replace(\" . \", \". \")\n string = string.replace(\" ! \", \"! \")\n string = string.replace(\" ? \", \"? \")\n string = string.replace(\" , \", \", \")\n # double brackets\n string = re.sub(r\"\\(\\s*([^\\)]*?)\\s*\\)\", r\"(\\1)\", string)\n string = re.sub(r\"\\[\\s*([^\\]]*?)\\s*\\]\", r\"[\\1]\", string)\n string = re.sub(r\"{\\s*([^}]*?)\\s*}\", r\"{\\1}\", string)\n string = re.sub(r\"\\\"\\s*([^\\\"]*?)\\s*\\\"\", r'\"\\1\"', string)\n string = re.sub(r\"'\\s*([^']*?)\\s*'\", r\"'\\1'\", string)\n # miscellaneous\n string = string.replace(\"= = = =\", \"====\")\n string = string.replace(\"= = =\", \"===\")\n string = string.replace(\"= =\", \"==\")\n string = string.replace(\" \" + chr(176) + \" \", chr(176))\n string = string.replace(\" \\n\", \"\\n\")\n string = string.replace(\"\\n \", \"\\n\")\n string = string.replace(\" N \", \" 1 \")\n string = string.replace(\" 's\", \"'s\")\n\n return string\n",
391
+ "unsafe_code": false,
392
+ "process_results": "def process_results(doc, results):\n (loglikelihood,) = results\n # IMPORTANT: wikitext counts number of words in *original doc before detokenization*\n _words = len(re.split(r\"\\s+\", doc[\"page\"]))\n _bytes = len(doc[\"page\"].encode(\"utf-8\"))\n return {\n \"word_perplexity\": (loglikelihood, _words),\n \"byte_perplexity\": (loglikelihood, _bytes),\n \"bits_per_byte\": (loglikelihood, _bytes),\n }\n",
393
+ "description": "",
394
+ "target_delimiter": " ",
395
+ "fewshot_delimiter": "\n\n",
396
+ "num_fewshot": 0,
397
+ "metric_list": [
398
+ {
399
+ "metric": "word_perplexity"
400
+ },
401
+ {
402
+ "metric": "byte_perplexity"
403
+ },
404
+ {
405
+ "metric": "bits_per_byte"
406
+ }
407
+ ],
408
+ "output_type": "loglikelihood_rolling",
409
+ "repeats": 1,
410
+ "should_decontaminate": true,
411
+ "doc_to_decontamination_query": "{{page}}",
412
+ "metadata": {
413
+ "version": 2.0,
414
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
415
+ "dtype": "bfloat16",
416
+ "trust_remote_code": true
417
+ }
418
+ },
419
+ "winogrande": {
420
+ "task": "winogrande",
421
+ "dataset_path": "winogrande",
422
+ "dataset_name": "winogrande_xl",
423
+ "training_split": "train",
424
+ "validation_split": "validation",
425
+ "doc_to_text": "def doc_to_text(doc):\n answer_to_num = {\"1\": 0, \"2\": 1}\n return answer_to_num[doc[\"answer\"]]\n",
426
+ "doc_to_target": "def doc_to_target(doc):\n idx = doc[\"sentence\"].index(\"_\") + 1\n return doc[\"sentence\"][idx:].strip()\n",
427
+ "unsafe_code": false,
428
+ "doc_to_choice": "def doc_to_choice(doc):\n idx = doc[\"sentence\"].index(\"_\")\n options = [doc[\"option1\"], doc[\"option2\"]]\n return [doc[\"sentence\"][:idx] + opt for opt in options]\n",
429
+ "description": "",
430
+ "target_delimiter": " ",
431
+ "fewshot_delimiter": "\n\n",
432
+ "num_fewshot": 0,
433
+ "metric_list": [
434
+ {
435
+ "metric": "acc",
436
+ "aggregation": "mean",
437
+ "higher_is_better": true
438
+ }
439
+ ],
440
+ "output_type": "multiple_choice",
441
+ "repeats": 1,
442
+ "should_decontaminate": true,
443
+ "doc_to_decontamination_query": "sentence",
444
+ "metadata": {
445
+ "version": 1.0,
446
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
447
+ "dtype": "bfloat16",
448
+ "trust_remote_code": true
449
+ }
450
+ }
451
+ },
452
+ "versions": {
453
+ "arc_challenge": 1.0,
454
+ "arc_easy": 1.0,
455
+ "boolq": 2.0,
456
+ "hellaswag": 1.0,
457
+ "lambada_openai": 1.0,
458
+ "lambada_standard": 1.0,
459
+ "piqa": 1.0,
460
+ "social_iqa": 0.0,
461
+ "wikitext": 2.0,
462
+ "winogrande": 1.0
463
+ },
464
+ "n-shot": {
465
+ "arc_challenge": 0,
466
+ "arc_easy": 0,
467
+ "boolq": 0,
468
+ "hellaswag": 0,
469
+ "lambada_openai": 0,
470
+ "lambada_standard": 0,
471
+ "piqa": 0,
472
+ "social_iqa": 0,
473
+ "wikitext": 0,
474
+ "winogrande": 0
475
+ },
476
+ "higher_is_better": {
477
+ "arc_challenge": {
478
+ "acc": true,
479
+ "acc_norm": true
480
+ },
481
+ "arc_easy": {
482
+ "acc": true,
483
+ "acc_norm": true
484
+ },
485
+ "boolq": {
486
+ "acc": true
487
+ },
488
+ "hellaswag": {
489
+ "acc": true,
490
+ "acc_norm": true
491
+ },
492
+ "lambada_openai": {
493
+ "perplexity": false,
494
+ "acc": true
495
+ },
496
+ "lambada_standard": {
497
+ "perplexity": false,
498
+ "acc": true
499
+ },
500
+ "piqa": {
501
+ "acc": true,
502
+ "acc_norm": true
503
+ },
504
+ "social_iqa": {
505
+ "acc": true
506
+ },
507
+ "wikitext": {
508
+ "word_perplexity": false,
509
+ "byte_perplexity": false,
510
+ "bits_per_byte": false
511
+ },
512
+ "winogrande": {
513
+ "acc": true
514
+ }
515
+ },
516
+ "n-samples": {
517
+ "winogrande": {
518
+ "original": 1267,
519
+ "effective": 1267
520
+ },
521
+ "wikitext": {
522
+ "original": 62,
523
+ "effective": 62
524
+ },
525
+ "social_iqa": {
526
+ "original": 1954,
527
+ "effective": 1954
528
+ },
529
+ "piqa": {
530
+ "original": 1838,
531
+ "effective": 1838
532
+ },
533
+ "lambada_standard": {
534
+ "original": 5153,
535
+ "effective": 5153
536
+ },
537
+ "lambada_openai": {
538
+ "original": 5153,
539
+ "effective": 5153
540
+ },
541
+ "hellaswag": {
542
+ "original": 10042,
543
+ "effective": 10042
544
+ },
545
+ "boolq": {
546
+ "original": 3270,
547
+ "effective": 3270
548
+ },
549
+ "arc_easy": {
550
+ "original": 2376,
551
+ "effective": 2376
552
+ },
553
+ "arc_challenge": {
554
+ "original": 1172,
555
+ "effective": 1172
556
+ }
557
+ },
558
+ "config": {
559
+ "model": "hf",
560
+ "model_args": "pretrained=results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/,dtype=bfloat16,trust_remote_code=True,trust_remote_code=True",
561
+ "model_num_parameters": 221854518,
562
+ "model_dtype": "torch.bfloat16",
563
+ "model_revision": "main",
564
+ "model_sha": "",
565
+ "batch_size": "8",
566
+ "batch_sizes": [],
567
+ "device": "cuda:0",
568
+ "use_cache": null,
569
+ "limit": null,
570
+ "bootstrap_iters": 100000,
571
+ "gen_kwargs": null,
572
+ "random_seed": 0,
573
+ "numpy_seed": 1234,
574
+ "torch_seed": 1234,
575
+ "fewshot_seed": 1234
576
+ },
577
+ "git_hash": "core_v0.12.0-110-g747c34dc4",
578
+ "date": 1769091164.868056,
579
+ "pretty_env_info": "PyTorch version: 2.6.0+cu124\nIs debug build: False\nCUDA used to build PyTorch: 12.4\nROCM used to build PyTorch: N/A\n\nOS: Ubuntu 22.04.4 LTS (x86_64)\nGCC version: (conda-forge gcc 9.5.0-19) 9.5.0\nClang version: Could not collect\nCMake version: version 3.30.1\nLibc version: glibc-2.35\n\nPython version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)\nPython platform: Linux-6.5.0-18-generic-x86_64-with-glibc2.35\nIs CUDA available: True\nCUDA runtime version: 12.4.131\nCUDA_MODULE_LOADING set to: LAZY\nGPU models and configuration: \nGPU 0: NVIDIA A800-SXM4-80GB\nGPU 1: NVIDIA A800-SXM4-80GB\nGPU 2: NVIDIA A800-SXM4-80GB\nGPU 3: NVIDIA A800-SXM4-80GB\nGPU 4: NVIDIA A800-SXM4-80GB\nGPU 5: NVIDIA A800-SXM4-80GB\nGPU 6: NVIDIA A800-SXM4-80GB\nGPU 7: NVIDIA A800-SXM4-80GB\n\nNvidia driver version: 550.54.15\ncuDNN version: Could not collect\nHIP runtime version: N/A\nMIOpen runtime version: N/A\nIs XNNPACK available: True\n\nCPU:\nArchitecture: x86_64\nCPU op-mode(s): 32-bit, 64-bit\nAddress sizes: 52 bits physical, 57 bits virtual\nByte Order: Little Endian\nCPU(s): 104\nOn-line CPU(s) list: 0-103\nVendor ID: GenuineIntel\nModel name: Intel(R) Xeon(R) Platinum 8470\nCPU family: 6\nModel: 143\nThread(s) per core: 1\nCore(s) per socket: 52\nSocket(s): 2\nStepping: 8\nCPU max MHz: 3800.0000\nCPU min MHz: 800.0000\nBogoMIPS: 4000.00\nFlags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities\nVirtualization: VT-x\nL1d cache: 4.9 MiB (104 instances)\nL1i cache: 3.3 MiB (104 instances)\nL2 cache: 208 MiB (104 instances)\nL3 cache: 210 MiB (2 instances)\nNUMA node(s): 2\nNUMA node0 CPU(s): 0-51\nNUMA node1 CPU(s): 52-103\nVulnerability Gather data sampling: Not affected\nVulnerability Itlb multihit: Not affected\nVulnerability L1tf: Not affected\nVulnerability Mds: Not affected\nVulnerability Meltdown: Not affected\nVulnerability Mmio stale data: Not affected\nVulnerability Retbleed: Not affected\nVulnerability Spec rstack overflow: Not affected\nVulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl\nVulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization\nVulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence\nVulnerability Srbds: Not affected\nVulnerability Tsx async abort: Not affected\n\nVersions of relevant libraries:\n[pip3] numpy==1.26.4\n[pip3] nvidia-cublas-cu12==12.4.5.8\n[pip3] nvidia-cuda-cupti-cu12==12.4.127\n[pip3] nvidia-cuda-nvrtc-cu12==12.4.127\n[pip3] nvidia-cuda-runtime-cu12==12.4.127\n[pip3] nvidia-cudnn-cu12==9.1.0.70\n[pip3] nvidia-cufft-cu12==11.2.1.3\n[pip3] nvidia-curand-cu12==10.3.5.147\n[pip3] nvidia-cusolver-cu12==11.6.1.9\n[pip3] nvidia-cusparse-cu12==12.3.1.170\n[pip3] nvidia-cusparselt-cu12==0.6.2\n[pip3] nvidia-nccl-cu11==2.21.5\n[pip3] nvidia-nccl-cu12==2.21.5\n[pip3] nvidia-nvjitlink-cu12==12.4.127\n[pip3] nvidia-nvtx-cu12==12.4.127\n[pip3] torch==2.6.0\n[pip3] torchaudio==2.6.0\n[pip3] torchdata==0.11.0\n[pip3] torchvision==0.21.0\n[pip3] triton==3.2.0\n[conda] cuda-cudart 12.4.99 hd3aeb46_0 conda-forge\n[conda] cuda-cudart_linux-64 12.4.99 h59595ed_0 conda-forge\n[conda] cuda-cupti 12.4.127 he02047a_2 conda-forge\n[conda] cuda-libraries 12.4.0 ha770c72_0 conda-forge\n[conda] cuda-nvrtc 12.4.99 hd3aeb46_0 conda-forge\n[conda] cuda-nvtx 12.4.127 he02047a_2 conda-forge\n[conda] cuda-opencl 12.4.99 h59595ed_0 conda-forge\n[conda] cuda-runtime 12.4.0 ha804496_0 conda-forge\n[conda] ffmpeg 4.3 hf484d3e_0 pytorch\n[conda] libcublas 12.4.2.65 hd3aeb46_0 conda-forge\n[conda] libcufft 11.2.0.44 hd3aeb46_0 conda-forge\n[conda] libcurand 10.3.5.119 hd3aeb46_0 conda-forge\n[conda] libcusolver 11.6.0.99 hd3aeb46_0 conda-forge\n[conda] libcusparse 12.3.0.142 hd3aeb46_0 conda-forge\n[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch\n[conda] libnvjitlink 12.4.99 hd3aeb46_0 conda-forge\n[conda] mkl 2023.1.0 h213fc3f_46344 defaults\n[conda] numpy 1.26.4 pypi_0 pypi\n[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi\n[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi\n[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi\n[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi\n[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi\n[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi\n[conda] nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi\n[conda] nvidia-nccl-cu11 2.21.5 pypi_0 pypi\n[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi\n[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi\n[conda] pytorch-cuda 12.4 hc786d27_6 pytorch\n[conda] pytorch-mutex 1.0 cuda pytorch\n[conda] torch 2.6.0 pypi_0 pypi\n[conda] torchaudio 2.6.0 pypi_0 pypi\n[conda] torchdata 0.11.0 pypi_0 pypi\n[conda] torchvision 0.21.0 pypi_0 pypi\n[conda] triton 3.2.0 pypi_0 pypi",
580
+ "transformers_version": "4.55.2",
581
+ "lm_eval_version": "0.4.9.1",
582
+ "upper_git_hash": null,
583
+ "tokenizer_pad_token": [
584
+ "<unk>",
585
+ "0"
586
+ ],
587
+ "tokenizer_eos_token": [
588
+ "<|im_end|>",
589
+ "73440"
590
+ ],
591
+ "tokenizer_bos_token": [
592
+ "<s>",
593
+ "1"
594
+ ],
595
+ "eot_token_id": 73440,
596
+ "max_length": 4096,
597
+ "task_hashes": {},
598
+ "model_source": "hf",
599
+ "model_name": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
600
+ "model_name_sanitized": "results__hf_ckpts__blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64__",
601
+ "system_instruction": null,
602
+ "system_instruction_sha": null,
603
+ "fewshot_as_multiturn": false,
604
+ "chat_template": null,
605
+ "chat_template_sha": null,
606
+ "start_time": 624482.978993429,
607
+ "end_time": 624797.261986388,
608
+ "total_evaluation_time_seconds": "314.2829929590225"
609
+ }
evaluation/results__hf_ckpts__blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64__/results_2026-01-23T12-47-10.690021.json ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "results": {
3
+ "arc_challenge": {
4
+ "alias": "arc_challenge",
5
+ "acc,none": 0.2022184300341297,
6
+ "acc_stderr,none": 0.011737454431872105,
7
+ "acc_norm,none": 0.23464163822525597,
8
+ "acc_norm_stderr,none": 0.012383873560768687
9
+ },
10
+ "arc_easy": {
11
+ "alias": "arc_easy",
12
+ "acc,none": 0.44402356902356904,
13
+ "acc_stderr,none": 0.010195285580783957,
14
+ "acc_norm,none": 0.40404040404040403,
15
+ "acc_norm_stderr,none": 0.010069061649549549
16
+ },
17
+ "boolq": {
18
+ "alias": "boolq",
19
+ "acc,none": 0.5715596330275229,
20
+ "acc_stderr,none": 0.00865502856151977
21
+ },
22
+ "hellaswag": {
23
+ "alias": "hellaswag",
24
+ "acc,none": 0.27962557259510057,
25
+ "acc_stderr,none": 0.004478979795506762,
26
+ "acc_norm,none": 0.29177454690300736,
27
+ "acc_norm_stderr,none": 0.0045365007141480096
28
+ },
29
+ "lambada_openai": {
30
+ "alias": "lambada_openai",
31
+ "perplexity,none": 166.09188392334596,
32
+ "perplexity_stderr,none": 7.549846792492689,
33
+ "acc,none": 0.21288569765185328,
34
+ "acc_stderr,none": 0.005703011105808082
35
+ },
36
+ "lambada_standard": {
37
+ "alias": "lambada_standard",
38
+ "perplexity,none": 536.8992699093774,
39
+ "perplexity_stderr,none": 25.490016733651576,
40
+ "acc,none": 0.15796623326217737,
41
+ "acc_stderr,none": 0.005081114222421346
42
+ },
43
+ "piqa": {
44
+ "alias": "piqa",
45
+ "acc,none": 0.6093579978237215,
46
+ "acc_stderr,none": 0.01138337776005358,
47
+ "acc_norm,none": 0.6131664853101197,
48
+ "acc_norm_stderr,none": 0.011363095931902854
49
+ },
50
+ "social_iqa": {
51
+ "alias": "social_iqa",
52
+ "acc,none": 0.368474923234391,
53
+ "acc_stderr,none": 0.010915613431252628
54
+ },
55
+ "wikitext": {
56
+ "alias": "wikitext",
57
+ "word_perplexity,none": 54.32450126248072,
58
+ "word_perplexity_stderr,none": "N/A",
59
+ "byte_perplexity,none": 2.1108281516522545,
60
+ "byte_perplexity_stderr,none": "N/A",
61
+ "bits_per_byte,none": 1.077809129679421,
62
+ "bits_per_byte_stderr,none": "N/A"
63
+ },
64
+ "winogrande": {
65
+ "alias": "winogrande",
66
+ "acc,none": 0.5067087608524072,
67
+ "acc_stderr,none": 0.014051220692330349
68
+ }
69
+ },
70
+ "group_subtasks": {
71
+ "arc_challenge": [],
72
+ "arc_easy": [],
73
+ "boolq": [],
74
+ "hellaswag": [],
75
+ "lambada_openai": [],
76
+ "lambada_standard": [],
77
+ "piqa": [],
78
+ "social_iqa": [],
79
+ "wikitext": [],
80
+ "winogrande": []
81
+ },
82
+ "configs": {
83
+ "arc_challenge": {
84
+ "task": "arc_challenge",
85
+ "tag": [
86
+ "ai2_arc"
87
+ ],
88
+ "dataset_path": "allenai/ai2_arc",
89
+ "dataset_name": "ARC-Challenge",
90
+ "training_split": "train",
91
+ "validation_split": "validation",
92
+ "test_split": "test",
93
+ "doc_to_text": "Question: {{question}}\nAnswer:",
94
+ "doc_to_target": "{{choices.label.index(answerKey)}}",
95
+ "unsafe_code": false,
96
+ "doc_to_choice": "{{choices.text}}",
97
+ "description": "",
98
+ "target_delimiter": " ",
99
+ "fewshot_delimiter": "\n\n",
100
+ "num_fewshot": 0,
101
+ "metric_list": [
102
+ {
103
+ "metric": "acc",
104
+ "aggregation": "mean",
105
+ "higher_is_better": true
106
+ },
107
+ {
108
+ "metric": "acc_norm",
109
+ "aggregation": "mean",
110
+ "higher_is_better": true
111
+ }
112
+ ],
113
+ "output_type": "multiple_choice",
114
+ "repeats": 1,
115
+ "should_decontaminate": true,
116
+ "doc_to_decontamination_query": "Question: {{question}}\nAnswer:",
117
+ "metadata": {
118
+ "version": 1.0,
119
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
120
+ "dtype": "bfloat16",
121
+ "trust_remote_code": true
122
+ }
123
+ },
124
+ "arc_easy": {
125
+ "task": "arc_easy",
126
+ "tag": [
127
+ "ai2_arc"
128
+ ],
129
+ "dataset_path": "allenai/ai2_arc",
130
+ "dataset_name": "ARC-Easy",
131
+ "training_split": "train",
132
+ "validation_split": "validation",
133
+ "test_split": "test",
134
+ "doc_to_text": "Question: {{question}}\nAnswer:",
135
+ "doc_to_target": "{{choices.label.index(answerKey)}}",
136
+ "unsafe_code": false,
137
+ "doc_to_choice": "{{choices.text}}",
138
+ "description": "",
139
+ "target_delimiter": " ",
140
+ "fewshot_delimiter": "\n\n",
141
+ "num_fewshot": 0,
142
+ "metric_list": [
143
+ {
144
+ "metric": "acc",
145
+ "aggregation": "mean",
146
+ "higher_is_better": true
147
+ },
148
+ {
149
+ "metric": "acc_norm",
150
+ "aggregation": "mean",
151
+ "higher_is_better": true
152
+ }
153
+ ],
154
+ "output_type": "multiple_choice",
155
+ "repeats": 1,
156
+ "should_decontaminate": true,
157
+ "doc_to_decontamination_query": "Question: {{question}}\nAnswer:",
158
+ "metadata": {
159
+ "version": 1.0,
160
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
161
+ "dtype": "bfloat16",
162
+ "trust_remote_code": true
163
+ }
164
+ },
165
+ "boolq": {
166
+ "task": "boolq",
167
+ "tag": [
168
+ "super-glue-lm-eval-v1"
169
+ ],
170
+ "dataset_path": "super_glue",
171
+ "dataset_name": "boolq",
172
+ "training_split": "train",
173
+ "validation_split": "validation",
174
+ "doc_to_text": "{{passage}}\nQuestion: {{question}}?\nAnswer:",
175
+ "doc_to_target": "label",
176
+ "unsafe_code": false,
177
+ "doc_to_choice": [
178
+ "no",
179
+ "yes"
180
+ ],
181
+ "description": "",
182
+ "target_delimiter": " ",
183
+ "fewshot_delimiter": "\n\n",
184
+ "num_fewshot": 0,
185
+ "metric_list": [
186
+ {
187
+ "metric": "acc"
188
+ }
189
+ ],
190
+ "output_type": "multiple_choice",
191
+ "repeats": 1,
192
+ "should_decontaminate": true,
193
+ "doc_to_decontamination_query": "passage",
194
+ "metadata": {
195
+ "version": 2.0,
196
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
197
+ "dtype": "bfloat16",
198
+ "trust_remote_code": true
199
+ }
200
+ },
201
+ "hellaswag": {
202
+ "task": "hellaswag",
203
+ "tag": [
204
+ "multiple_choice"
205
+ ],
206
+ "dataset_path": "Rowan/hellaswag",
207
+ "training_split": "train",
208
+ "validation_split": "validation",
209
+ "process_docs": "def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:\n def _process_doc(doc):\n ctx = doc[\"ctx_a\"] + \" \" + doc[\"ctx_b\"].capitalize()\n out_doc = {\n \"query\": preprocess(doc[\"activity_label\"] + \": \" + ctx),\n \"choices\": [preprocess(ending) for ending in doc[\"endings\"]],\n \"gold\": int(doc[\"label\"]),\n }\n return out_doc\n\n return dataset.map(_process_doc)\n",
210
+ "doc_to_text": "{{query}}",
211
+ "doc_to_target": "{{label}}",
212
+ "unsafe_code": false,
213
+ "doc_to_choice": "choices",
214
+ "description": "",
215
+ "target_delimiter": " ",
216
+ "fewshot_delimiter": "\n\n",
217
+ "num_fewshot": 0,
218
+ "metric_list": [
219
+ {
220
+ "metric": "acc",
221
+ "aggregation": "mean",
222
+ "higher_is_better": true
223
+ },
224
+ {
225
+ "metric": "acc_norm",
226
+ "aggregation": "mean",
227
+ "higher_is_better": true
228
+ }
229
+ ],
230
+ "output_type": "multiple_choice",
231
+ "repeats": 1,
232
+ "should_decontaminate": false,
233
+ "metadata": {
234
+ "version": 1.0,
235
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
236
+ "dtype": "bfloat16",
237
+ "trust_remote_code": true
238
+ }
239
+ },
240
+ "lambada_openai": {
241
+ "task": "lambada_openai",
242
+ "tag": [
243
+ "lambada"
244
+ ],
245
+ "dataset_path": "EleutherAI/lambada_openai",
246
+ "dataset_name": "default",
247
+ "test_split": "test",
248
+ "doc_to_text": "{{text.split(' ')[:-1]|join(' ')}}",
249
+ "doc_to_target": "{{' '+text.split(' ')[-1]}}",
250
+ "unsafe_code": false,
251
+ "description": "",
252
+ "target_delimiter": " ",
253
+ "fewshot_delimiter": "\n\n",
254
+ "num_fewshot": 0,
255
+ "metric_list": [
256
+ {
257
+ "metric": "perplexity",
258
+ "aggregation": "perplexity",
259
+ "higher_is_better": false
260
+ },
261
+ {
262
+ "metric": "acc",
263
+ "aggregation": "mean",
264
+ "higher_is_better": true
265
+ }
266
+ ],
267
+ "output_type": "loglikelihood",
268
+ "repeats": 1,
269
+ "should_decontaminate": true,
270
+ "doc_to_decontamination_query": "{{text}}",
271
+ "metadata": {
272
+ "version": 1.0,
273
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
274
+ "dtype": "bfloat16",
275
+ "trust_remote_code": true
276
+ }
277
+ },
278
+ "lambada_standard": {
279
+ "task": "lambada_standard",
280
+ "tag": [
281
+ "lambada"
282
+ ],
283
+ "dataset_path": "lambada",
284
+ "validation_split": "validation",
285
+ "test_split": "test",
286
+ "doc_to_text": "{{text.split(' ')[:-1]|join(' ')}}",
287
+ "doc_to_target": "{{' '+text.split(' ')[-1]}}",
288
+ "unsafe_code": false,
289
+ "description": "",
290
+ "target_delimiter": " ",
291
+ "fewshot_delimiter": "\n\n",
292
+ "num_fewshot": 0,
293
+ "metric_list": [
294
+ {
295
+ "metric": "perplexity",
296
+ "aggregation": "perplexity",
297
+ "higher_is_better": false
298
+ },
299
+ {
300
+ "metric": "acc",
301
+ "aggregation": "mean",
302
+ "higher_is_better": true
303
+ }
304
+ ],
305
+ "output_type": "loglikelihood",
306
+ "repeats": 1,
307
+ "should_decontaminate": true,
308
+ "doc_to_decontamination_query": "{{text}}",
309
+ "metadata": {
310
+ "version": 1.0,
311
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
312
+ "dtype": "bfloat16",
313
+ "trust_remote_code": true
314
+ }
315
+ },
316
+ "piqa": {
317
+ "task": "piqa",
318
+ "dataset_path": "baber/piqa",
319
+ "training_split": "train",
320
+ "validation_split": "validation",
321
+ "doc_to_text": "Question: {{goal}}\nAnswer:",
322
+ "doc_to_target": "label",
323
+ "unsafe_code": false,
324
+ "doc_to_choice": "{{[sol1, sol2]}}",
325
+ "description": "",
326
+ "target_delimiter": " ",
327
+ "fewshot_delimiter": "\n\n",
328
+ "num_fewshot": 0,
329
+ "metric_list": [
330
+ {
331
+ "metric": "acc",
332
+ "aggregation": "mean",
333
+ "higher_is_better": true
334
+ },
335
+ {
336
+ "metric": "acc_norm",
337
+ "aggregation": "mean",
338
+ "higher_is_better": true
339
+ }
340
+ ],
341
+ "output_type": "multiple_choice",
342
+ "repeats": 1,
343
+ "should_decontaminate": true,
344
+ "doc_to_decontamination_query": "goal",
345
+ "metadata": {
346
+ "version": 1.0,
347
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
348
+ "dtype": "bfloat16",
349
+ "trust_remote_code": true
350
+ }
351
+ },
352
+ "social_iqa": {
353
+ "task": "social_iqa",
354
+ "dataset_path": "social_i_qa",
355
+ "training_split": "train",
356
+ "validation_split": "validation",
357
+ "doc_to_text": "Q: {{context}} {{question}}\nA:",
358
+ "doc_to_target": "{{ (label|int) - 1 }}",
359
+ "unsafe_code": false,
360
+ "doc_to_choice": "{{[answerA, answerB, answerC]}}",
361
+ "description": "",
362
+ "target_delimiter": " ",
363
+ "fewshot_delimiter": "\n\n",
364
+ "num_fewshot": 0,
365
+ "metric_list": [
366
+ {
367
+ "metric": "acc",
368
+ "aggregation": "mean",
369
+ "higher_is_better": true
370
+ }
371
+ ],
372
+ "output_type": "multiple_choice",
373
+ "repeats": 1,
374
+ "should_decontaminate": false,
375
+ "metadata": {
376
+ "version": 0.0,
377
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
378
+ "dtype": "bfloat16",
379
+ "trust_remote_code": true
380
+ }
381
+ },
382
+ "wikitext": {
383
+ "task": "wikitext",
384
+ "dataset_path": "EleutherAI/wikitext_document_level",
385
+ "dataset_name": "wikitext-2-raw-v1",
386
+ "training_split": "train",
387
+ "validation_split": "validation",
388
+ "test_split": "test",
389
+ "doc_to_text": "",
390
+ "doc_to_target": "def wikitext_detokenizer(doc):\n string = doc[\"page\"]\n # contractions\n string = string.replace(\"s '\", \"s'\")\n string = re.sub(r\"/' [0-9]/\", r\"/'[0-9]/\", string)\n # number separators\n string = string.replace(\" @-@ \", \"-\")\n string = string.replace(\" @,@ \", \",\")\n string = string.replace(\" @.@ \", \".\")\n # punctuation\n string = string.replace(\" : \", \": \")\n string = string.replace(\" ; \", \"; \")\n string = string.replace(\" . \", \". \")\n string = string.replace(\" ! \", \"! \")\n string = string.replace(\" ? \", \"? \")\n string = string.replace(\" , \", \", \")\n # double brackets\n string = re.sub(r\"\\(\\s*([^\\)]*?)\\s*\\)\", r\"(\\1)\", string)\n string = re.sub(r\"\\[\\s*([^\\]]*?)\\s*\\]\", r\"[\\1]\", string)\n string = re.sub(r\"{\\s*([^}]*?)\\s*}\", r\"{\\1}\", string)\n string = re.sub(r\"\\\"\\s*([^\\\"]*?)\\s*\\\"\", r'\"\\1\"', string)\n string = re.sub(r\"'\\s*([^']*?)\\s*'\", r\"'\\1'\", string)\n # miscellaneous\n string = string.replace(\"= = = =\", \"====\")\n string = string.replace(\"= = =\", \"===\")\n string = string.replace(\"= =\", \"==\")\n string = string.replace(\" \" + chr(176) + \" \", chr(176))\n string = string.replace(\" \\n\", \"\\n\")\n string = string.replace(\"\\n \", \"\\n\")\n string = string.replace(\" N \", \" 1 \")\n string = string.replace(\" 's\", \"'s\")\n\n return string\n",
391
+ "unsafe_code": false,
392
+ "process_results": "def process_results(doc, results):\n (loglikelihood,) = results\n # IMPORTANT: wikitext counts number of words in *original doc before detokenization*\n _words = len(re.split(r\"\\s+\", doc[\"page\"]))\n _bytes = len(doc[\"page\"].encode(\"utf-8\"))\n return {\n \"word_perplexity\": (loglikelihood, _words),\n \"byte_perplexity\": (loglikelihood, _bytes),\n \"bits_per_byte\": (loglikelihood, _bytes),\n }\n",
393
+ "description": "",
394
+ "target_delimiter": " ",
395
+ "fewshot_delimiter": "\n\n",
396
+ "num_fewshot": 0,
397
+ "metric_list": [
398
+ {
399
+ "metric": "word_perplexity"
400
+ },
401
+ {
402
+ "metric": "byte_perplexity"
403
+ },
404
+ {
405
+ "metric": "bits_per_byte"
406
+ }
407
+ ],
408
+ "output_type": "loglikelihood_rolling",
409
+ "repeats": 1,
410
+ "should_decontaminate": true,
411
+ "doc_to_decontamination_query": "{{page}}",
412
+ "metadata": {
413
+ "version": 2.0,
414
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
415
+ "dtype": "bfloat16",
416
+ "trust_remote_code": true
417
+ }
418
+ },
419
+ "winogrande": {
420
+ "task": "winogrande",
421
+ "dataset_path": "winogrande",
422
+ "dataset_name": "winogrande_xl",
423
+ "training_split": "train",
424
+ "validation_split": "validation",
425
+ "doc_to_text": "def doc_to_text(doc):\n answer_to_num = {\"1\": 0, \"2\": 1}\n return answer_to_num[doc[\"answer\"]]\n",
426
+ "doc_to_target": "def doc_to_target(doc):\n idx = doc[\"sentence\"].index(\"_\") + 1\n return doc[\"sentence\"][idx:].strip()\n",
427
+ "unsafe_code": false,
428
+ "doc_to_choice": "def doc_to_choice(doc):\n idx = doc[\"sentence\"].index(\"_\")\n options = [doc[\"option1\"], doc[\"option2\"]]\n return [doc[\"sentence\"][:idx] + opt for opt in options]\n",
429
+ "description": "",
430
+ "target_delimiter": " ",
431
+ "fewshot_delimiter": "\n\n",
432
+ "num_fewshot": 0,
433
+ "metric_list": [
434
+ {
435
+ "metric": "acc",
436
+ "aggregation": "mean",
437
+ "higher_is_better": true
438
+ }
439
+ ],
440
+ "output_type": "multiple_choice",
441
+ "repeats": 1,
442
+ "should_decontaminate": true,
443
+ "doc_to_decontamination_query": "sentence",
444
+ "metadata": {
445
+ "version": 1.0,
446
+ "pretrained": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
447
+ "dtype": "bfloat16",
448
+ "trust_remote_code": true
449
+ }
450
+ }
451
+ },
452
+ "versions": {
453
+ "arc_challenge": 1.0,
454
+ "arc_easy": 1.0,
455
+ "boolq": 2.0,
456
+ "hellaswag": 1.0,
457
+ "lambada_openai": 1.0,
458
+ "lambada_standard": 1.0,
459
+ "piqa": 1.0,
460
+ "social_iqa": 0.0,
461
+ "wikitext": 2.0,
462
+ "winogrande": 1.0
463
+ },
464
+ "n-shot": {
465
+ "arc_challenge": 0,
466
+ "arc_easy": 0,
467
+ "boolq": 0,
468
+ "hellaswag": 0,
469
+ "lambada_openai": 0,
470
+ "lambada_standard": 0,
471
+ "piqa": 0,
472
+ "social_iqa": 0,
473
+ "wikitext": 0,
474
+ "winogrande": 0
475
+ },
476
+ "higher_is_better": {
477
+ "arc_challenge": {
478
+ "acc": true,
479
+ "acc_norm": true
480
+ },
481
+ "arc_easy": {
482
+ "acc": true,
483
+ "acc_norm": true
484
+ },
485
+ "boolq": {
486
+ "acc": true
487
+ },
488
+ "hellaswag": {
489
+ "acc": true,
490
+ "acc_norm": true
491
+ },
492
+ "lambada_openai": {
493
+ "perplexity": false,
494
+ "acc": true
495
+ },
496
+ "lambada_standard": {
497
+ "perplexity": false,
498
+ "acc": true
499
+ },
500
+ "piqa": {
501
+ "acc": true,
502
+ "acc_norm": true
503
+ },
504
+ "social_iqa": {
505
+ "acc": true
506
+ },
507
+ "wikitext": {
508
+ "word_perplexity": false,
509
+ "byte_perplexity": false,
510
+ "bits_per_byte": false
511
+ },
512
+ "winogrande": {
513
+ "acc": true
514
+ }
515
+ },
516
+ "n-samples": {
517
+ "winogrande": {
518
+ "original": 1267,
519
+ "effective": 1267
520
+ },
521
+ "wikitext": {
522
+ "original": 62,
523
+ "effective": 62
524
+ },
525
+ "social_iqa": {
526
+ "original": 1954,
527
+ "effective": 1954
528
+ },
529
+ "piqa": {
530
+ "original": 1838,
531
+ "effective": 1838
532
+ },
533
+ "lambada_standard": {
534
+ "original": 5153,
535
+ "effective": 5153
536
+ },
537
+ "lambada_openai": {
538
+ "original": 5153,
539
+ "effective": 5153
540
+ },
541
+ "hellaswag": {
542
+ "original": 10042,
543
+ "effective": 10042
544
+ },
545
+ "boolq": {
546
+ "original": 3270,
547
+ "effective": 3270
548
+ },
549
+ "arc_easy": {
550
+ "original": 2376,
551
+ "effective": 2376
552
+ },
553
+ "arc_challenge": {
554
+ "original": 1172,
555
+ "effective": 1172
556
+ }
557
+ },
558
+ "config": {
559
+ "model": "hf",
560
+ "model_args": "pretrained=results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/,dtype=bfloat16,trust_remote_code=True,trust_remote_code=True",
561
+ "model_num_parameters": 221854518,
562
+ "model_dtype": "torch.bfloat16",
563
+ "model_revision": "main",
564
+ "model_sha": "",
565
+ "batch_size": "8",
566
+ "batch_sizes": [],
567
+ "device": "cuda:0",
568
+ "use_cache": null,
569
+ "limit": null,
570
+ "bootstrap_iters": 100000,
571
+ "gen_kwargs": null,
572
+ "random_seed": 0,
573
+ "numpy_seed": 1234,
574
+ "torch_seed": 1234,
575
+ "fewshot_seed": 1234
576
+ },
577
+ "git_hash": "core_v0.12.0-111-g418d5cb59",
578
+ "date": 1769143351.6943488,
579
+ "pretty_env_info": "PyTorch version: 2.6.0+cu124\nIs debug build: False\nCUDA used to build PyTorch: 12.4\nROCM used to build PyTorch: N/A\n\nOS: Ubuntu 22.04.4 LTS (x86_64)\nGCC version: (conda-forge gcc 9.5.0-19) 9.5.0\nClang version: Could not collect\nCMake version: version 3.30.1\nLibc version: glibc-2.35\n\nPython version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)\nPython platform: Linux-6.5.0-18-generic-x86_64-with-glibc2.35\nIs CUDA available: True\nCUDA runtime version: 12.4.131\nCUDA_MODULE_LOADING set to: LAZY\nGPU models and configuration: \nGPU 0: NVIDIA A800-SXM4-80GB\nGPU 1: NVIDIA A800-SXM4-80GB\nGPU 2: NVIDIA A800-SXM4-80GB\nGPU 3: NVIDIA A800-SXM4-80GB\nGPU 4: NVIDIA A800-SXM4-80GB\nGPU 5: NVIDIA A800-SXM4-80GB\nGPU 6: NVIDIA A800-SXM4-80GB\nGPU 7: NVIDIA A800-SXM4-80GB\n\nNvidia driver version: 550.54.15\ncuDNN version: Could not collect\nHIP runtime version: N/A\nMIOpen runtime version: N/A\nIs XNNPACK available: True\n\nCPU:\nArchitecture: x86_64\nCPU op-mode(s): 32-bit, 64-bit\nAddress sizes: 52 bits physical, 57 bits virtual\nByte Order: Little Endian\nCPU(s): 104\nOn-line CPU(s) list: 0-103\nVendor ID: GenuineIntel\nModel name: Intel(R) Xeon(R) Platinum 8470\nCPU family: 6\nModel: 143\nThread(s) per core: 1\nCore(s) per socket: 52\nSocket(s): 2\nStepping: 8\nCPU max MHz: 3800.0000\nCPU min MHz: 800.0000\nBogoMIPS: 4000.00\nFlags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities\nVirtualization: VT-x\nL1d cache: 4.9 MiB (104 instances)\nL1i cache: 3.3 MiB (104 instances)\nL2 cache: 208 MiB (104 instances)\nL3 cache: 210 MiB (2 instances)\nNUMA node(s): 2\nNUMA node0 CPU(s): 0-51\nNUMA node1 CPU(s): 52-103\nVulnerability Gather data sampling: Not affected\nVulnerability Itlb multihit: Not affected\nVulnerability L1tf: Not affected\nVulnerability Mds: Not affected\nVulnerability Meltdown: Not affected\nVulnerability Mmio stale data: Not affected\nVulnerability Retbleed: Not affected\nVulnerability Spec rstack overflow: Not affected\nVulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl\nVulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization\nVulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence\nVulnerability Srbds: Not affected\nVulnerability Tsx async abort: Not affected\n\nVersions of relevant libraries:\n[pip3] numpy==1.26.4\n[pip3] nvidia-cublas-cu12==12.4.5.8\n[pip3] nvidia-cuda-cupti-cu12==12.4.127\n[pip3] nvidia-cuda-nvrtc-cu12==12.4.127\n[pip3] nvidia-cuda-runtime-cu12==12.4.127\n[pip3] nvidia-cudnn-cu12==9.1.0.70\n[pip3] nvidia-cufft-cu12==11.2.1.3\n[pip3] nvidia-curand-cu12==10.3.5.147\n[pip3] nvidia-cusolver-cu12==11.6.1.9\n[pip3] nvidia-cusparse-cu12==12.3.1.170\n[pip3] nvidia-cusparselt-cu12==0.6.2\n[pip3] nvidia-nccl-cu11==2.21.5\n[pip3] nvidia-nccl-cu12==2.21.5\n[pip3] nvidia-nvjitlink-cu12==12.4.127\n[pip3] nvidia-nvtx-cu12==12.4.127\n[pip3] torch==2.6.0\n[pip3] torchaudio==2.6.0\n[pip3] torchdata==0.11.0\n[pip3] torchvision==0.21.0\n[pip3] triton==3.2.0\n[conda] cuda-cudart 12.4.99 hd3aeb46_0 conda-forge\n[conda] cuda-cudart_linux-64 12.4.99 h59595ed_0 conda-forge\n[conda] cuda-cupti 12.4.127 he02047a_2 conda-forge\n[conda] cuda-libraries 12.4.0 ha770c72_0 conda-forge\n[conda] cuda-nvrtc 12.4.99 hd3aeb46_0 conda-forge\n[conda] cuda-nvtx 12.4.127 he02047a_2 conda-forge\n[conda] cuda-opencl 12.4.99 h59595ed_0 conda-forge\n[conda] cuda-runtime 12.4.0 ha804496_0 conda-forge\n[conda] ffmpeg 4.3 hf484d3e_0 pytorch\n[conda] libcublas 12.4.2.65 hd3aeb46_0 conda-forge\n[conda] libcufft 11.2.0.44 hd3aeb46_0 conda-forge\n[conda] libcurand 10.3.5.119 hd3aeb46_0 conda-forge\n[conda] libcusolver 11.6.0.99 hd3aeb46_0 conda-forge\n[conda] libcusparse 12.3.0.142 hd3aeb46_0 conda-forge\n[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch\n[conda] libnvjitlink 12.4.99 hd3aeb46_0 conda-forge\n[conda] mkl 2023.1.0 h213fc3f_46344 defaults\n[conda] numpy 1.26.4 pypi_0 pypi\n[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi\n[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi\n[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi\n[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi\n[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi\n[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi\n[conda] nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi\n[conda] nvidia-nccl-cu11 2.21.5 pypi_0 pypi\n[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi\n[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi\n[conda] pytorch-cuda 12.4 hc786d27_6 pytorch\n[conda] pytorch-mutex 1.0 cuda pytorch\n[conda] torch 2.6.0 pypi_0 pypi\n[conda] torchaudio 2.6.0 pypi_0 pypi\n[conda] torchdata 0.11.0 pypi_0 pypi\n[conda] torchvision 0.21.0 pypi_0 pypi\n[conda] triton 3.2.0 pypi_0 pypi",
580
+ "transformers_version": "4.55.2",
581
+ "lm_eval_version": "0.4.9.1",
582
+ "upper_git_hash": null,
583
+ "tokenizer_pad_token": [
584
+ "<unk>",
585
+ "0"
586
+ ],
587
+ "tokenizer_eos_token": [
588
+ "<|im_end|>",
589
+ "73440"
590
+ ],
591
+ "tokenizer_bos_token": [
592
+ "<s>",
593
+ "1"
594
+ ],
595
+ "eot_token_id": 73440,
596
+ "max_length": 4096,
597
+ "task_hashes": {},
598
+ "model_source": "hf",
599
+ "model_name": "results/hf_ckpts/blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64/",
600
+ "model_name_sanitized": "results__hf_ckpts__blockffn_01b_mul1002_withmean_d64_s128_lr1175e3_b64__",
601
+ "system_instruction": null,
602
+ "system_instruction_sha": null,
603
+ "fewshot_as_multiturn": false,
604
+ "chat_template": null,
605
+ "chat_template_sha": null,
606
+ "start_time": 676668.460022961,
607
+ "end_time": 676956.72660195,
608
+ "total_evaluation_time_seconds": "288.266578988987"
609
+ }
generation_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_sample": true,
3
+ "top_p": 0.8,
4
+ "temperature": 0.8,
5
+ "bos_token_id": 1,
6
+ "eos_token_id": [2,73440],
7
+ "pad_token_id": 2
8
+ }
modeling_blockffn.py ADDED
@@ -0,0 +1,1014 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ from typing import Callable, Optional, Union
21
+
22
+ import math
23
+ import torch
24
+ from torch import nn
25
+
26
+ import tree
27
+ from abc import ABC, abstractmethod
28
+ from fmoe.linear import MOELinear
29
+ from fmoe.functions import prepare_forward, MOEScatter, MOEGather
30
+
31
+ from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import Cache, DynamicCache
33
+ from transformers.generation import GenerationMixin
34
+ from transformers.integrations import use_kernel_forward_from_hub
35
+ from transformers.masking_utils import create_causal_mask
36
+ from transformers.modeling_layers import GradientCheckpointingLayer
37
+ from transformers.modeling_outputs import (
38
+ BaseModelOutputWithPast,
39
+ CausalLMOutputWithPast,
40
+ )
41
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
42
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
43
+ from transformers.processing_utils import Unpack
44
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
45
+ from transformers.utils.generic import check_model_inputs
46
+ from .configuration_blockffn import BlockFFNConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ @use_kernel_forward_from_hub("RMSNorm")
53
+ class BlockFFNRMSNorm(nn.Module):
54
+ def __init__(self, hidden_size, eps=1e-6):
55
+ super().__init__()
56
+ self.weight = nn.Parameter(torch.ones(hidden_size))
57
+ self.variance_epsilon = eps
58
+
59
+ def forward(self, hidden_states):
60
+ input_dtype = hidden_states.dtype
61
+ hidden_states = hidden_states.to(torch.float32)
62
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
63
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
64
+ return self.weight * hidden_states.to(input_dtype)
65
+
66
+ def extra_repr(self):
67
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
68
+
69
+
70
+ class BlockFFNRotaryEmbedding(nn.Module):
71
+ def __init__(self, config: BlockFFNConfig, device=None):
72
+ super().__init__()
73
+ # BC: "rope_type" was originally "type"
74
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
75
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
76
+ else:
77
+ self.rope_type = "default"
78
+ self.max_seq_len_cached = config.max_position_embeddings
79
+ self.original_max_seq_len = config.max_position_embeddings
80
+
81
+ self.config = config
82
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
83
+
84
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
85
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
86
+ self.original_inv_freq = self.inv_freq
87
+
88
+ @torch.no_grad()
89
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
90
+ def forward(self, x, position_ids):
91
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
92
+ position_ids_expanded = position_ids[:, None, :].float()
93
+
94
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
95
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
96
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
97
+ emb = torch.cat((freqs, freqs), dim=-1)
98
+ cos = emb.cos() * self.attention_scaling
99
+ sin = emb.sin() * self.attention_scaling
100
+
101
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
102
+
103
+
104
+ def rotate_half(x):
105
+ """Rotates half the hidden dims of the input."""
106
+ x1 = x[..., : x.shape[-1] // 2]
107
+ x2 = x[..., x.shape[-1] // 2 :]
108
+ return torch.cat((-x2, x1), dim=-1)
109
+
110
+
111
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
112
+ """Applies Rotary Position Embedding to the query and key tensors.
113
+
114
+ Args:
115
+ q (`torch.Tensor`): The query tensor.
116
+ k (`torch.Tensor`): The key tensor.
117
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
118
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
119
+ position_ids (`torch.Tensor`, *optional*):
120
+ Deprecated and unused.
121
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
122
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
123
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
124
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
125
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
126
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
127
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
128
+ Returns:
129
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
130
+ """
131
+ cos = cos.unsqueeze(unsqueeze_dim)
132
+ sin = sin.unsqueeze(unsqueeze_dim)
133
+ q_embed = (q * cos) + (rotate_half(q) * sin)
134
+ k_embed = (k * cos) + (rotate_half(k) * sin)
135
+ return q_embed, k_embed
136
+
137
+
138
+ class SimpleLayerNorm(nn.Module):
139
+ def __init__(self, dim_norm: int):
140
+ super().__init__()
141
+ self.dim_norm = dim_norm
142
+ self.weight = torch.nn.Parameter(torch.empty(self.dim_norm))
143
+
144
+ @torch.compile
145
+ def forward(self, x: torch.Tensor):
146
+ return x * self.weight
147
+
148
+
149
+ class BlockFFNMLP(nn.Module):
150
+ def __init__(self, config: BlockFFNConfig, intermediate_size: int = None):
151
+ super().__init__()
152
+ self.config = config
153
+ self.hidden_size = config.hidden_size
154
+ self.intermediate_size = config.ffn_hidden_size if intermediate_size is None else intermediate_size
155
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
156
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
157
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
158
+ self.act_fn = ACT2FN[config.hidden_act]
159
+
160
+ def forward(self, x):
161
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
162
+ return down_proj
163
+
164
+
165
+ class BlockFFNRouter(nn.Module):
166
+ def __init__(self, config: BlockFFNConfig):
167
+ super().__init__()
168
+ self.config = config
169
+ self.num_experts = self.config.num_experts
170
+
171
+ if self.config.moe_router_dtype == "fp32":
172
+ self.router_dtype = torch.float32
173
+ elif self.config.moe_router_dtype == "fp64":
174
+ self.router_dtype = torch.float64
175
+ elif self.config.moe_router_dtype == "bf16":
176
+ self.router_dtype = torch.bfloat16
177
+ else:
178
+ raise NotImplementedError(f"{self.config.moe_router_dtype} is not supported.")
179
+
180
+ self.weight = torch.nn.Parameter(
181
+ torch.empty((self.config.num_experts, self.config.hidden_size), dtype=self.router_dtype)
182
+ )
183
+
184
+ def forward(self, x: torch.Tensor):
185
+ return nn.functional.linear(x.to(self.router_dtype), self.weight)
186
+
187
+
188
+ class NormSiLU(nn.Module):
189
+ def __init__(self, config: BlockFFNConfig):
190
+ super().__init__()
191
+ self.num_blocks, self.block_size = config.num_experts, config.moe_ffn_hidden_size
192
+ self.activate_fn_type = config.expert_act_func
193
+ assert self.activate_fn_type in ["norm_silu", "norm_silu_norms", "norm_silu_nomean"]
194
+
195
+ self.rms_norm = None
196
+ if self.activate_fn_type != "norm_silu_norms":
197
+ self.rms_norm = BlockFFNRMSNorm(config.moe_ffn_hidden_size, eps=config.norm_epsilon)
198
+ self.silu = torch.nn.SiLU()
199
+
200
+ @torch.compile
201
+ def forward(self, hidden: torch.Tensor) -> torch.Tensor:
202
+ assert hidden.ndim == 2
203
+ if self.activate_fn_type != "norm_silu_nomean":
204
+ hidden = hidden - torch.mean(hidden, dim=-1, keepdim=True)
205
+ if self.activate_fn_type != "norm_silu_norms":
206
+ return self.silu(self.rms_norm(hidden.view(hidden.shape[0], self.num_blocks, self.block_size)))
207
+ else:
208
+ return self.silu(hidden)
209
+
210
+
211
+ class BlockFFNLayer(nn.Module):
212
+ def __init__(self, config: BlockFFNConfig):
213
+ super(BlockFFNLayer, self).__init__()
214
+ self.config = config
215
+ self.num_experts, self.dim_expert, self.hidden_size = \
216
+ config.num_experts, config.moe_ffn_hidden_size, config.hidden_size
217
+ self.dim_shared_expert = config.moe_shared_expert_intermediate_size
218
+ self.router_norm_type = config.router_norm_type
219
+
220
+ self.moe_router = BlockFFNRouter(self.config)
221
+ assert config.router_act_func == "relu"
222
+ self.router_act = nn.ReLU()
223
+ if config.router_norm_type == "simple":
224
+ self.router_norm = SimpleLayerNorm(self.config.num_experts)
225
+ elif config.router_norm_type == "rms":
226
+ self.router_norm = BlockFFNRMSNorm(self.config.num_experts, eps=config.norm_epsilon)
227
+ else:
228
+ raise NotImplementedError
229
+
230
+ self.expert_gated = not config.expert_not_gated
231
+ if self.expert_gated:
232
+ self.expert_gate_proj = nn.Linear(self.hidden_size, self.num_experts * self.dim_expert, bias=config.mlp_bias)
233
+
234
+ self.expert_up_proj = nn.Linear(self.hidden_size, self.num_experts * self.dim_expert, bias=config.mlp_bias)
235
+ assert config.expert_act_norm_type == "normal"
236
+ if config.expert_act_func == "norm_silu":
237
+ self.expert_act = NormSiLU(self.config)
238
+ elif config.expert_act_func == "silu":
239
+ self.expert_act = nn.SiLU()
240
+ else:
241
+ raise NotImplementedError
242
+ self.expert_down_proj = nn.Linear(self.num_experts * self.dim_expert, self.hidden_size, bias=config.mlp_bias)
243
+
244
+ self.use_shared_expert = self.dim_shared_expert is not None and self.dim_shared_expert > 0
245
+ if self.use_shared_expert:
246
+ self.shared_experts = BlockFFNMLP(self.config, intermediate_size=self.dim_shared_expert)
247
+
248
+ def forward(self, hidden_states: torch.Tensor):
249
+ ori_shape = hidden_states.shape
250
+ hidden_states = hidden_states.view(-1, self.hidden_size)
251
+ seq_len = hidden_states.shape[0]
252
+
253
+ # router module forward
254
+ raw_router_score = self.moe_router(hidden_states) # [seq_len, num_experts]
255
+ router_score = self.router_act(raw_router_score)
256
+ router_score = self.router_norm(router_score)
257
+
258
+ # expert module forward
259
+ x_in = self.expert_up_proj(hidden_states) # [seq_len, num_experts * dim_expert]
260
+ if self.expert_gated:
261
+ x_gate = self.expert_gate_proj(hidden_states)
262
+ x_in = x_in * self.expert_act(x_gate)
263
+ else:
264
+ x_in = self.expert_act(x_in)
265
+ if x_in.ndim == 3:
266
+ scored_x_in = x_in * router_score.type_as(hidden_states).unsqueeze(-1)
267
+ else:
268
+ scored_x_in = x_in.view(seq_len, self.num_experts, self.dim_expert) * router_score.type_as(hidden_states).unsqueeze(-1)
269
+ output = self.expert_down_proj(scored_x_in.view(seq_len, self.num_experts * self.dim_expert))
270
+
271
+ if self.use_shared_expert:
272
+ output = output + self.shared_experts(hidden_states)
273
+ return output.view(*ori_shape)
274
+
275
+
276
+ class BaseRouter(ABC, nn.Module):
277
+ """Base Router class"""
278
+ def __init__(self, config: BlockFFNConfig) -> None:
279
+ super().__init__()
280
+ self.config = config
281
+ self.num_experts = self.config.num_experts
282
+
283
+ if self.config.moe_router_dtype == "fp32":
284
+ self.router_dtype = torch.float32
285
+ elif self.config.moe_router_dtype == "fp64":
286
+ self.router_dtype = torch.float64
287
+ elif self.config.moe_router_dtype == "bf16":
288
+ self.router_dtype = torch.bfloat16
289
+ else:
290
+ raise NotImplementedError(f"{self.config.moe_router_dtype} is not supported.")
291
+
292
+ self.weight = torch.nn.Parameter(
293
+ torch.empty((self.num_experts, self.config.hidden_size), dtype=self.router_dtype)
294
+ )
295
+
296
+ def gating(self, input: torch.Tensor):
297
+ return torch.nn.functional.linear(input.to(self.router_dtype), self.weight.to(self.router_dtype))
298
+
299
+ @abstractmethod
300
+ def routing(self, logits: torch.Tensor):
301
+ """Routing function.
302
+
303
+ Args:
304
+ logits (torch.Tensor): Logits tensor.
305
+
306
+ Returns:
307
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment
308
+ probabilities and mapping.
309
+ """
310
+ raise NotImplementedError("Routing function not implemented.")
311
+
312
+ @abstractmethod
313
+ def forward(self, input: torch.Tensor):
314
+ """
315
+ Forward pass of the router.
316
+
317
+ Args:
318
+ input (torch.Tensor): Input tensor.
319
+ """
320
+ raise NotImplementedError("Forward function not implemented.")
321
+
322
+
323
+ class TopKRouter(BaseRouter):
324
+ """Route each token to the top-k experts."""
325
+
326
+ def __init__(self, config: BlockFFNConfig) -> None:
327
+ super().__init__(config)
328
+ self.config = config
329
+ self.topk = self.config.moe_router_topk
330
+ self.score_function = self.config.moe_router_score_function
331
+ self.use_pre_softmax = self.config.moe_router_pre_softmax
332
+ self.scaling_factor = self.config.moe_router_topk_scaling_factor
333
+
334
+ self.enable_expert_bias = self.config.moe_router_enable_expert_bias
335
+ if self.enable_expert_bias:
336
+ self.expert_bias = torch.nn.Parameter(torch.zeros(self.num_experts, dtype=torch.float32))
337
+ else:
338
+ self.expert_bias = None
339
+
340
+ def _maintain_float32_expert_bias(self):
341
+ """
342
+ Maintain the expert bias in float32.
343
+
344
+ When using bf16/fp16, the expert bias gets converted to lower precision in Float16Module.
345
+ We keep it in float32 to avoid routing errors when updating the expert_bias.
346
+ """
347
+ if hasattr(self, 'expert_bias') and self.expert_bias is not None:
348
+ if self.expert_bias.dtype != torch.float32:
349
+ self.expert_bias.data = self.expert_bias.data.to(torch.float32)
350
+
351
+ def routing(self, logits: torch.Tensor):
352
+ """Top-k routing function
353
+
354
+ Args:
355
+ logits (torch.Tensor): Logits tensor after gating.
356
+
357
+ Returns:
358
+ probs (torch.Tensor): The probabilities of token to experts assignment.
359
+ routing_map (torch.Tensor): The mapping of token to experts assignment,
360
+ with shape [num_tokens, num_experts].
361
+ """
362
+ logits = logits.view(-1, self.num_experts)
363
+
364
+ if self.score_function == "softmax":
365
+ if self.use_pre_softmax:
366
+ scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
367
+ probs, top_indices = torch.topk(scores, k=self.topk, dim=1)
368
+ else:
369
+ scores, top_indices = torch.topk(logits, k=self.topk, dim=1)
370
+ probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
371
+ elif self.score_function == "sigmoid":
372
+ scores = torch.sigmoid(logits.float()).type_as(logits)
373
+ if self.expert_bias is not None:
374
+ scores_for_routing = scores + self.expert_bias
375
+ _, top_indices = torch.topk(scores_for_routing, k=self.topk, dim=1)
376
+ scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
377
+ else:
378
+ scores, top_indices = torch.topk(scores, k=self.topk, dim=1)
379
+ probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.topk > 1 else scores
380
+ else:
381
+ raise ValueError(f"Invalid score_function: {self.score_function}")
382
+
383
+ if self.scaling_factor:
384
+ probs = probs * self.scaling_factor
385
+
386
+ return probs, top_indices
387
+
388
+ def forward(self, input: torch.Tensor):
389
+ """
390
+ Forward pass of the router.
391
+
392
+ Args:
393
+ input (torch.Tensor): Input tensor.
394
+ """
395
+ self._maintain_float32_expert_bias()
396
+ logits = self.gating(input)
397
+ top_scores, top_indices = self.routing(logits)
398
+ return top_scores, top_indices
399
+
400
+
401
+ class ReMoERouter(BaseRouter):
402
+ def __init__(self, config: BlockFFNConfig) -> None:
403
+ super().__init__(config)
404
+ self.config = config
405
+ self.router_act = torch.nn.ReLU()
406
+
407
+ def routing(self, logits: torch.Tensor):
408
+ """Top-k routing function
409
+
410
+ Args:
411
+ logits (torch.Tensor): Logits tensor after gating.
412
+
413
+ Returns:
414
+ probs (torch.Tensor): The probabilities of token to experts assignment.
415
+ routing_map (torch.Tensor): The mapping of token to experts assignment,
416
+ with shape [num_tokens, num_experts].
417
+ """
418
+ logits = logits.view(-1, self.num_experts)
419
+
420
+ router_score = self.router_act(logits)
421
+ routing_map = router_score > 0
422
+
423
+ sorted_probs, sorted_indices = torch.sort(router_score, descending=True, dim=-1)
424
+ sorted_map = sorted_probs <= 0
425
+ sorted_indices = torch.where(sorted_map, -1, sorted_indices)
426
+ max_valid_num = max(sorted_probs.size(-1) - torch.min(torch.sum(sorted_map, dim=-1)).item(), 1)
427
+ assert torch.all(sorted_map[:, max_valid_num:])
428
+ sorted_probs = sorted_probs[:, :max_valid_num]
429
+ sorted_indices = sorted_indices[:, :max_valid_num]
430
+ assert torch.sum(routing_map) == torch.sum(sorted_indices != -1)
431
+ return sorted_probs, sorted_indices
432
+
433
+ def forward(self, input: torch.Tensor):
434
+ """
435
+ Forward pass of the router.
436
+
437
+ Args:
438
+ input (torch.Tensor): Input tensor.
439
+ """
440
+ logits = self.gating(input)
441
+ top_scores, top_indices = self.routing(logits)
442
+ return top_scores, top_indices
443
+
444
+
445
+ class TopPRouter(BaseRouter):
446
+ def __init__(self, config: BlockFFNConfig) -> None:
447
+ super().__init__(config)
448
+ self.config = config
449
+ self.top_p = config.moe_router_topp
450
+
451
+ def routing(self, logits: torch.Tensor):
452
+ """Top-k routing function
453
+
454
+ Args:
455
+ logits (torch.Tensor): Logits tensor after gating.
456
+
457
+ Returns:
458
+ probs (torch.Tensor): The probabilities of token to experts assignment.
459
+ routing_map (torch.Tensor): The mapping of token to experts assignment,
460
+ with shape [num_tokens, num_experts].
461
+ """
462
+ logits = logits.view(-1, self.num_experts)
463
+
464
+ router_score = torch.abs(logits)
465
+ router_score = router_score / (router_score.sum(dim=-1, keepdim=True) + 1e-20)
466
+
467
+ sorted_probs, sorted_indices = torch.sort(router_score, descending=True, dim=-1)
468
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
469
+ mask = cumulative_probs > self.top_p
470
+
471
+ threshold_indices = mask.long().argmax(dim=-1)
472
+ threshold_mask = torch.nn.functional.one_hot(threshold_indices, num_classes=sorted_indices.size(-1)).bool()
473
+
474
+ mask = mask & ~threshold_mask
475
+ sorted_indices = torch.where(mask, -1, sorted_indices)
476
+ sorted_probs = torch.where(mask, 0.0, sorted_probs)
477
+
478
+ max_valid_num = max(mask.size(-1) - torch.min(torch.sum(mask, dim=-1)).item(), 1)
479
+ assert torch.all(mask[:, max_valid_num:])
480
+
481
+ sorted_indices = sorted_indices[:, :max_valid_num]
482
+ sorted_probs = sorted_probs[:, :max_valid_num]
483
+ sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
484
+ return sorted_probs, sorted_indices
485
+
486
+ def forward(self, input: torch.Tensor):
487
+ """
488
+ Forward pass of the router.
489
+
490
+ Args:
491
+ input (torch.Tensor): Input tensor.
492
+ """
493
+ logits = self.gating(input)
494
+ top_scores, top_indices = self.routing(logits)
495
+ return top_scores, top_indices
496
+
497
+
498
+ class FastTopKCalculator:
499
+ def __init__(self, num_experts: int):
500
+ self.num_experts = num_experts
501
+
502
+ def fmoe_sparse_topk_forward(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, experts: torch.nn.Module):
503
+ (
504
+ pos,
505
+ local_expert_count,
506
+ global_expert_count,
507
+ fwd_expert_count,
508
+ fwd_batch_size,
509
+ ) = prepare_forward(topk_indices, self.num_experts, 1)
510
+ topk = 1
511
+ if len(topk_indices.shape) == 2:
512
+ topk = topk_indices.shape[1]
513
+
514
+ def scatter_func(tensor):
515
+ return MOEScatter.apply(
516
+ tensor,
517
+ torch.div(pos, topk, rounding_mode='floor'),
518
+ local_expert_count,
519
+ global_expert_count,
520
+ fwd_batch_size,
521
+ 1,
522
+ )
523
+
524
+ x = tree.map_structure(scatter_func, hidden_states)
525
+ x = experts(x, fwd_expert_count, topk_indices=topk_indices)
526
+
527
+ out_batch_size = tree.flatten(hidden_states)[0].shape[0]
528
+ if len(topk_indices.shape) == 2:
529
+ out_batch_size *= topk_indices.shape[1]
530
+
531
+ def gather_func(tensor):
532
+ return MOEGather.apply(
533
+ tensor,
534
+ pos,
535
+ local_expert_count,
536
+ global_expert_count,
537
+ out_batch_size,
538
+ 1,
539
+ )
540
+
541
+ outp = tree.map_structure(gather_func, x)
542
+ return outp
543
+
544
+ def forward(self, hidden_states, topk_indices, topk_weights, experts):
545
+ assert topk_indices.shape == topk_weights.shape
546
+ top_k = topk_indices.shape[-1]
547
+ dim3 = hidden_states.ndim == 3
548
+ if dim3:
549
+ batch_size, seq_len, dim = hidden_states.shape
550
+ hidden_states = hidden_states.view(batch_size * seq_len, dim)
551
+ else:
552
+ assert hidden_states.ndim == 2
553
+ batch_size, (seq_len, dim) = -1, hidden_states.shape
554
+ fwd = self.fmoe_sparse_topk_forward(hidden_states, topk_indices, experts)
555
+
556
+ def view_func(tensor):
557
+ n_dim = tensor.shape[-1]
558
+ tensor = tensor.view(-1, top_k, n_dim)
559
+ return tensor
560
+
561
+ moe_output = tree.map_structure(view_func, fwd)
562
+ topk_weights = topk_weights.unsqueeze(1)
563
+
564
+ def bmm_func(tensor):
565
+ n_dim = tensor.shape[-1]
566
+ tensor = torch.bmm(topk_weights, tensor).reshape(-1, n_dim)
567
+ return tensor
568
+
569
+ moe_output = tree.map_structure(bmm_func, moe_output)
570
+ if dim3:
571
+ moe_output = moe_output.view(batch_size, seq_len, -1)
572
+ return moe_output
573
+
574
+
575
+ class MoELinearExperts(nn.Module):
576
+ def __init__(
577
+ self,
578
+ dim_in: int,
579
+ dim_out: int,
580
+ num_experts: int,
581
+ ffn_bias: bool,
582
+ ):
583
+ super().__init__()
584
+ self.dim_in = self.in_features = dim_in
585
+ self.dim_out = self.out_features = dim_out
586
+ self.weight = torch.nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
587
+ self.bias = None
588
+ if ffn_bias:
589
+ self.bias = torch.nn.Parameter(torch.empty(num_experts, dim_out))
590
+
591
+ def forward(self, x: torch.Tensor, fwd_expert_count: torch.Tensor):
592
+ x = MOELinear.apply(x, fwd_expert_count, self.weight, self.bias)
593
+ return x
594
+
595
+
596
+ class MoEGatedExperts(nn.Module):
597
+ def __init__(
598
+ self,
599
+ dim_in: int,
600
+ dim_ff: int,
601
+ is_gated: bool,
602
+ act_name: str,
603
+ num_experts: int,
604
+ ffn_bias: bool = False,
605
+ ):
606
+ super().__init__()
607
+ self.is_gated = is_gated
608
+ self.dim_in, self.dim_ff, self.num_experts = dim_in, dim_ff, num_experts
609
+ if self.is_gated:
610
+ self.gate_proj = MoELinearExperts(dim_in, dim_ff, num_experts, ffn_bias)
611
+ self.up_proj = MoELinearExperts(dim_in, dim_ff, num_experts, ffn_bias)
612
+ self.down_proj = MoELinearExperts(dim_ff, dim_in, num_experts, ffn_bias)
613
+
614
+ self.act_fn = ACT2FN[act_name]
615
+
616
+ def forward(self, x: torch.Tensor, fwd_expert_count: torch.Tensor, **kwargs) -> torch.Tensor:
617
+ if self.is_gated:
618
+ gate_score = self.gate_proj(x, fwd_expert_count)
619
+ up_proj = self.up_proj(x, fwd_expert_count)
620
+ x = up_proj * self.act_fn(gate_score)
621
+ else:
622
+ up_score = self.up_proj(x, fwd_expert_count)
623
+ x = self.act_fn(up_score)
624
+ x = self.down_proj(x, fwd_expert_count)
625
+ return x
626
+
627
+
628
+ class VanillaMoELayer(nn.Module):
629
+ def __init__(self, config: BlockFFNConfig):
630
+ super(VanillaMoELayer, self).__init__()
631
+ self.config = config
632
+
633
+ # Initialize router
634
+ if config.router_type == "topk":
635
+ self.router = TopKRouter(config=self.config)
636
+ elif config.router_type == "remoe":
637
+ self.router = ReMoERouter(config=self.config)
638
+ elif config.router_type == "topp":
639
+ self.router = TopPRouter(config=self.config)
640
+ else:
641
+ raise NotImplementedError(f"Router type {config.router_type} not implemented.")
642
+
643
+ self.mix_calculator = FastTopKCalculator(num_experts=self.config.num_experts)
644
+
645
+ # Initialize experts
646
+ self.experts = MoEGatedExperts(
647
+ dim_in=self.config.hidden_size,
648
+ dim_ff=self.config.moe_ffn_hidden_size,
649
+ is_gated=not self.config.expert_not_gated,
650
+ act_name="silu",
651
+ num_experts=self.config.num_experts,
652
+ )
653
+
654
+ self.dim_shared_expert = self.config.moe_shared_expert_intermediate_size
655
+ self.use_shared_expert = self.dim_shared_expert is not None and self.dim_shared_expert > 0
656
+ if self.use_shared_expert:
657
+ self.shared_experts = BlockFFNMLP(self.config, intermediate_size=self.dim_shared_expert)
658
+
659
+ def forward(self, hidden_states: torch.Tensor):
660
+ top_scores, top_indices = self.router(hidden_states)
661
+ y = self.mix_calculator.forward(
662
+ hidden_states=hidden_states,
663
+ topk_indices=top_indices.contiguous(),
664
+ topk_weights=top_scores.type_as(hidden_states),
665
+ experts=self.experts,
666
+ )
667
+ if self.shared_experts is not None:
668
+ y = y + self.shared_experts(hidden_states)
669
+ return y
670
+
671
+
672
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
673
+ """
674
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
675
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
676
+ """
677
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
678
+ if n_rep == 1:
679
+ return hidden_states
680
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
681
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
682
+
683
+
684
+ def eager_attention_forward(
685
+ module: nn.Module,
686
+ query: torch.Tensor,
687
+ key: torch.Tensor,
688
+ value: torch.Tensor,
689
+ attention_mask: Optional[torch.Tensor],
690
+ scaling: float,
691
+ dropout: float = 0.0,
692
+ ):
693
+ key_states = repeat_kv(key, module.num_key_value_groups)
694
+ value_states = repeat_kv(value, module.num_key_value_groups)
695
+
696
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
697
+ if attention_mask is not None:
698
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
699
+ attn_weights = attn_weights + causal_mask
700
+
701
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
702
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
703
+ attn_output = torch.matmul(attn_weights, value_states)
704
+ attn_output = attn_output.transpose(1, 2).contiguous()
705
+
706
+ return attn_output, attn_weights
707
+
708
+
709
+ class BlockFFNAttention(nn.Module):
710
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
711
+
712
+ def __init__(self, config: BlockFFNConfig, layer_idx: int):
713
+ super().__init__()
714
+ self.config = config
715
+ self.layer_idx = layer_idx
716
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
717
+ self.num_key_value_groups = config.num_attention_heads // config.num_query_groups
718
+ self.scaling = self.head_dim**-0.5
719
+ self.attention_dropout = config.attention_dropout
720
+ self.is_causal = True
721
+
722
+ self.q_proj = nn.Linear(
723
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
724
+ )
725
+ self.k_proj = nn.Linear(
726
+ config.hidden_size, config.num_query_groups * self.head_dim, bias=config.attention_bias
727
+ )
728
+ self.v_proj = nn.Linear(
729
+ config.hidden_size, config.num_query_groups * self.head_dim, bias=config.attention_bias
730
+ )
731
+ self.o_proj = nn.Linear(
732
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
733
+ )
734
+
735
+ def forward(
736
+ self,
737
+ hidden_states: torch.Tensor,
738
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
739
+ attention_mask: Optional[torch.Tensor],
740
+ past_key_value: Optional[Cache] = None,
741
+ cache_position: Optional[torch.LongTensor] = None,
742
+ **kwargs: Unpack[TransformersKwargs],
743
+ ) -> tuple[torch.Tensor, torch.Tensor]:
744
+ input_shape = hidden_states.shape[:-1]
745
+ hidden_shape = (*input_shape, -1, self.head_dim)
746
+
747
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
748
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
749
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
750
+
751
+ cos, sin = position_embeddings
752
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
753
+
754
+ if past_key_value is not None:
755
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
756
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
757
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
758
+
759
+ attention_interface: Callable = eager_attention_forward
760
+ if self.config._attn_implementation != "eager":
761
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
762
+
763
+ attn_output, attn_weights = attention_interface(
764
+ self,
765
+ query_states,
766
+ key_states,
767
+ value_states,
768
+ attention_mask,
769
+ dropout=0.0 if not self.training else self.attention_dropout,
770
+ scaling=self.scaling,
771
+ **kwargs,
772
+ )
773
+
774
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
775
+ attn_output = self.o_proj(attn_output)
776
+ return attn_output, attn_weights
777
+
778
+
779
+ class BlockFFNDecoderLayer(GradientCheckpointingLayer):
780
+ def __init__(self, config: BlockFFNConfig, layer_idx: int, is_moe_layer: bool):
781
+ super().__init__()
782
+ self.config = config
783
+ self.hidden_size = config.hidden_size
784
+
785
+ self.self_attn = BlockFFNAttention(config=config, layer_idx=layer_idx)
786
+
787
+ if is_moe_layer:
788
+ if config.use_blockffn:
789
+ self.mlp = BlockFFNLayer(config)
790
+ elif config.router_type in ["topk", "remoe", "topp"]:
791
+ self.mlp = VanillaMoELayer(config)
792
+ else:
793
+ raise NotImplementedError
794
+ else:
795
+ self.mlp = BlockFFNMLP(config)
796
+ self.input_layernorm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
797
+ self.post_attention_layernorm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
798
+
799
+ def forward(
800
+ self,
801
+ hidden_states: torch.Tensor,
802
+ attention_mask: Optional[torch.Tensor] = None,
803
+ position_ids: Optional[torch.LongTensor] = None,
804
+ past_key_value: Optional[Cache] = None,
805
+ use_cache: Optional[bool] = False,
806
+ cache_position: Optional[torch.LongTensor] = None,
807
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
808
+ **kwargs: Unpack[TransformersKwargs],
809
+ ) -> tuple[torch.Tensor]:
810
+ residual = hidden_states
811
+ hidden_states = self.input_layernorm(hidden_states)
812
+ # Self Attention
813
+ hidden_states, _ = self.self_attn(
814
+ hidden_states=hidden_states,
815
+ attention_mask=attention_mask,
816
+ position_ids=position_ids,
817
+ past_key_value=past_key_value,
818
+ use_cache=use_cache,
819
+ cache_position=cache_position,
820
+ position_embeddings=position_embeddings,
821
+ **kwargs,
822
+ )
823
+ if self.config.use_mup:
824
+ hidden_states = residual + hidden_states * (self.config.mup_depth_scale / math.sqrt(self.config.num_layers))
825
+ else:
826
+ hidden_states = residual + hidden_states
827
+
828
+ # Fully Connected
829
+ residual = hidden_states
830
+ hidden_states = self.post_attention_layernorm(hidden_states)
831
+ hidden_states = self.mlp(hidden_states)
832
+ if self.config.use_mup:
833
+ hidden_states = residual + hidden_states * (self.config.mup_depth_scale / math.sqrt(self.config.num_layers))
834
+ else:
835
+ hidden_states = residual + hidden_states
836
+ return hidden_states
837
+
838
+
839
+ @auto_docstring
840
+ class BlockFFNPreTrainedModel(PreTrainedModel):
841
+ config: BlockFFNConfig
842
+ base_model_prefix = "model"
843
+ supports_gradient_checkpointing = True
844
+ _no_split_modules = ["BlockFFNDecoderLayer"]
845
+ _skip_keys_device_placement = ["past_key_values"]
846
+ _supports_flash_attn = True
847
+ _supports_sdpa = True
848
+ _supports_flex_attn = True
849
+
850
+ _can_compile_fullgraph = True
851
+ _supports_attention_backend = True
852
+ _can_record_outputs = {
853
+ "hidden_states": BlockFFNDecoderLayer,
854
+ "attentions": BlockFFNAttention,
855
+ }
856
+
857
+
858
+ @auto_docstring
859
+ class BlockFFNModel(BlockFFNPreTrainedModel):
860
+ def __init__(self, config: BlockFFNConfig):
861
+ super().__init__(config)
862
+ self.config = config
863
+ self.padding_idx = config.pad_token_id
864
+ self.vocab_size = config.vocab_size
865
+
866
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
867
+ self.moe_layer_freq = eval(config.moe_layer_freq) if isinstance(config.moe_layer_freq, str) else config.moe_layer_freq
868
+ assert len(self.moe_layer_freq) == config.num_layers
869
+ self.layers = nn.ModuleList(
870
+ [BlockFFNDecoderLayer(config, layer_idx, bool(self.moe_layer_freq[layer_idx])) for layer_idx in range(config.num_layers)]
871
+ )
872
+ self.norm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
873
+ self.rotary_emb = BlockFFNRotaryEmbedding(config=config)
874
+ self.gradient_checkpointing = False
875
+
876
+ # Initialize weights and apply final processing
877
+ self.post_init()
878
+
879
+ @check_model_inputs
880
+ @auto_docstring
881
+ def forward(
882
+ self,
883
+ input_ids: Optional[torch.LongTensor] = None,
884
+ attention_mask: Optional[torch.Tensor] = None,
885
+ position_ids: Optional[torch.LongTensor] = None,
886
+ past_key_values: Optional[Cache] = None,
887
+ inputs_embeds: Optional[torch.FloatTensor] = None,
888
+ cache_position: Optional[torch.LongTensor] = None,
889
+ use_cache: Optional[bool] = None,
890
+ **kwargs: Unpack[TransformersKwargs],
891
+ ) -> BaseModelOutputWithPast:
892
+ if (input_ids is None) ^ (inputs_embeds is not None):
893
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
894
+
895
+ if inputs_embeds is None:
896
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
897
+ if self.config.use_mup:
898
+ inputs_embeds = inputs_embeds * self.config.mup_emb_scale
899
+
900
+ if use_cache and past_key_values is None:
901
+ past_key_values = DynamicCache()
902
+
903
+ if cache_position is None:
904
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
905
+ cache_position: torch.Tensor = torch.arange(
906
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
907
+ )
908
+
909
+ if position_ids is None:
910
+ position_ids = cache_position.unsqueeze(0)
911
+
912
+ causal_mask = create_causal_mask(
913
+ config=self.config,
914
+ input_embeds=inputs_embeds,
915
+ attention_mask=attention_mask,
916
+ cache_position=cache_position,
917
+ past_key_values=past_key_values,
918
+ position_ids=position_ids,
919
+ )
920
+
921
+ hidden_states = inputs_embeds
922
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
923
+
924
+ for decoder_layer in self.layers[: self.config.num_layers]:
925
+ hidden_states = decoder_layer(
926
+ hidden_states,
927
+ attention_mask=causal_mask,
928
+ position_ids=position_ids,
929
+ past_key_value=past_key_values,
930
+ cache_position=cache_position,
931
+ position_embeddings=position_embeddings,
932
+ **kwargs,
933
+ )
934
+
935
+ hidden_states = self.norm(hidden_states)
936
+ return BaseModelOutputWithPast(
937
+ last_hidden_state=hidden_states,
938
+ past_key_values=past_key_values,
939
+ )
940
+
941
+
942
+ @auto_docstring
943
+ class BlockFFNForCausalLM(BlockFFNPreTrainedModel, GenerationMixin):
944
+ _tied_weights_keys = ["lm_head.weight"]
945
+ _tp_plan = {"lm_head": "colwise_rep"}
946
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
947
+
948
+ def __init__(self, config: BlockFFNConfig):
949
+ super().__init__(config)
950
+ self.config = config
951
+ self.model = BlockFFNModel(config)
952
+ self.vocab_size = config.vocab_size
953
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
954
+
955
+ # Initialize weights and apply final processing
956
+ self.post_init()
957
+
958
+ def set_decoder(self, decoder):
959
+ self.model = decoder
960
+
961
+ def get_decoder(self):
962
+ return self.model
963
+
964
+ @can_return_tuple
965
+ @auto_docstring
966
+ def forward(
967
+ self,
968
+ input_ids: Optional[torch.LongTensor] = None,
969
+ attention_mask: Optional[torch.Tensor] = None,
970
+ position_ids: Optional[torch.LongTensor] = None,
971
+ past_key_values: Optional[Cache] = None,
972
+ inputs_embeds: Optional[torch.FloatTensor] = None,
973
+ labels: Optional[torch.LongTensor] = None,
974
+ use_cache: Optional[bool] = None,
975
+ cache_position: Optional[torch.LongTensor] = None,
976
+ logits_to_keep: Union[int, torch.Tensor] = 0,
977
+ **kwargs: Unpack[TransformersKwargs],
978
+ ) -> CausalLMOutputWithPast:
979
+ outputs: BaseModelOutputWithPast = self.model(
980
+ input_ids=input_ids,
981
+ attention_mask=attention_mask,
982
+ position_ids=position_ids,
983
+ past_key_values=past_key_values,
984
+ inputs_embeds=inputs_embeds,
985
+ use_cache=use_cache,
986
+ cache_position=cache_position,
987
+ **kwargs,
988
+ )
989
+
990
+ hidden_states = outputs.last_hidden_state
991
+ if self.config.use_mup:
992
+ hidden_states = hidden_states / self.config.mup_width_scale
993
+
994
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
995
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
996
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
997
+
998
+ loss = None
999
+ if labels is not None:
1000
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1001
+
1002
+ return CausalLMOutputWithPast(
1003
+ loss=loss,
1004
+ logits=logits,
1005
+ past_key_values=outputs.past_key_values,
1006
+ hidden_states=outputs.hidden_states,
1007
+ attentions=outputs.attentions,
1008
+ )
1009
+
1010
+ __all__ = [
1011
+ "BlockFFNForCausalLM",
1012
+ "BlockFFNModel",
1013
+ "BlockFFNPreTrainedModel",
1014
+ ]
modeling_blockffn.py.bak ADDED
@@ -0,0 +1,1019 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ from typing import Callable, Optional, Union
21
+
22
+ import math
23
+ import torch
24
+ from torch import nn
25
+
26
+ import tree
27
+ from abc import ABC, abstractmethod
28
+ from fmoe.linear import MOELinear
29
+ from fmoe.functions import prepare_forward, MOEScatter, MOEGather
30
+
31
+ from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import Cache, DynamicCache
33
+ from transformers.generation import GenerationMixin
34
+ from transformers.integrations import use_kernel_forward_from_hub
35
+ from transformers.masking_utils import create_causal_mask
36
+ from transformers.modeling_layers import GradientCheckpointingLayer
37
+ from transformers.modeling_outputs import (
38
+ BaseModelOutputWithPast,
39
+ CausalLMOutputWithPast,
40
+ )
41
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
42
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
43
+ from transformers.processing_utils import Unpack
44
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
45
+ from transformers.utils.generic import check_model_inputs
46
+ from .configuration_blockffn import BlockFFNConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ @use_kernel_forward_from_hub("RMSNorm")
53
+ class BlockFFNRMSNorm(nn.Module):
54
+ def __init__(self, hidden_size, eps=1e-6):
55
+ super().__init__()
56
+ self.weight = nn.Parameter(torch.ones(hidden_size))
57
+ self.variance_epsilon = eps
58
+
59
+ def forward(self, hidden_states):
60
+ input_dtype = hidden_states.dtype
61
+ hidden_states = hidden_states.to(torch.float32)
62
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
63
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
64
+ return self.weight * hidden_states.to(input_dtype)
65
+
66
+ def extra_repr(self):
67
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
68
+
69
+
70
+ class BlockFFNRotaryEmbedding(nn.Module):
71
+ def __init__(self, config: BlockFFNConfig, device=None):
72
+ super().__init__()
73
+ # BC: "rope_type" was originally "type"
74
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
75
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
76
+ else:
77
+ self.rope_type = "default"
78
+ self.max_seq_len_cached = config.max_position_embeddings
79
+ self.original_max_seq_len = config.max_position_embeddings
80
+
81
+ self.config = config
82
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
83
+
84
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
85
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
86
+ self.original_inv_freq = self.inv_freq
87
+
88
+ @torch.no_grad()
89
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
90
+ def forward(self, x, position_ids):
91
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
92
+ position_ids_expanded = position_ids[:, None, :].float()
93
+
94
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
95
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
96
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
97
+ emb = torch.cat((freqs, freqs), dim=-1)
98
+ cos = emb.cos() * self.attention_scaling
99
+ sin = emb.sin() * self.attention_scaling
100
+
101
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
102
+
103
+
104
+ def rotate_half(x):
105
+ """Rotates half the hidden dims of the input."""
106
+ x1 = x[..., : x.shape[-1] // 2]
107
+ x2 = x[..., x.shape[-1] // 2 :]
108
+ return torch.cat((-x2, x1), dim=-1)
109
+
110
+
111
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
112
+ """Applies Rotary Position Embedding to the query and key tensors.
113
+
114
+ Args:
115
+ q (`torch.Tensor`): The query tensor.
116
+ k (`torch.Tensor`): The key tensor.
117
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
118
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
119
+ position_ids (`torch.Tensor`, *optional*):
120
+ Deprecated and unused.
121
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
122
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
123
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
124
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
125
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
126
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
127
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
128
+ Returns:
129
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
130
+ """
131
+ cos = cos.unsqueeze(unsqueeze_dim)
132
+ sin = sin.unsqueeze(unsqueeze_dim)
133
+ q_embed = (q * cos) + (rotate_half(q) * sin)
134
+ k_embed = (k * cos) + (rotate_half(k) * sin)
135
+ return q_embed, k_embed
136
+
137
+
138
+ class SimpleLayerNorm(nn.Module):
139
+ def __init__(self, dim_norm: int):
140
+ super().__init__()
141
+ self.dim_norm = dim_norm
142
+ self.weight = torch.nn.Parameter(torch.empty(self.dim_norm))
143
+
144
+ @torch.compile
145
+ def forward(self, x: torch.Tensor):
146
+ return x * self.weight
147
+
148
+
149
+ class BlockFFNMLP(nn.Module):
150
+ def __init__(self, config: BlockFFNConfig, intermediate_size: int = None):
151
+ super().__init__()
152
+ self.config = config
153
+ self.hidden_size = config.hidden_size
154
+ self.intermediate_size = config.ffn_hidden_size if intermediate_size is None else intermediate_size
155
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
156
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
157
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
158
+ self.act_fn = ACT2FN[config.hidden_act]
159
+
160
+ def forward(self, x):
161
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
162
+ return down_proj
163
+
164
+
165
+ class BlockFFNRouter(nn.Module):
166
+ def __init__(self, config: BlockFFNConfig):
167
+ super().__init__()
168
+ self.config = config
169
+ self.num_experts = self.config.num_experts
170
+
171
+ if self.config.moe_router_dtype == "fp32":
172
+ self.router_dtype = torch.float32
173
+ elif self.config.moe_router_dtype == "fp64":
174
+ self.router_dtype = torch.float64
175
+ elif self.config.moe_router_dtype == "bf16":
176
+ self.router_dtype = torch.bfloat16
177
+ else:
178
+ raise NotImplementedError(f"{self.config.moe_router_dtype} is not supported.")
179
+
180
+ self.weight = torch.nn.Parameter(
181
+ torch.empty((self.config.num_experts, self.config.hidden_size), dtype=self.router_dtype)
182
+ )
183
+
184
+ def forward(self, x: torch.Tensor):
185
+ return nn.functional.linear(x.to(self.router_dtype), self.weight)
186
+
187
+
188
+ class NormSiLU(nn.Module):
189
+ def __init__(self, config: BlockFFNConfig):
190
+ super().__init__()
191
+ self.num_blocks, self.block_size = config.num_experts, config.moe_ffn_hidden_size
192
+ self.activate_fn_type = config.expert_act_func
193
+ assert self.activate_fn_type in ["norm_silu", "norm_silu_norms", "norm_silu_nomean"]
194
+
195
+ self.rms_norm = None
196
+ if self.activate_fn_type != "norm_silu_norms":
197
+ self.rms_norm = BlockFFNRMSNorm(config.moe_ffn_hidden_size, eps=config.norm_epsilon)
198
+ self.silu = torch.nn.SiLU()
199
+
200
+ @torch.compile
201
+ def forward(self, hidden: torch.Tensor) -> torch.Tensor:
202
+ assert hidden.ndim == 2
203
+ if self.activate_fn_type != "norm_silu_nomean":
204
+ hidden = hidden - torch.mean(hidden, dim=-1, keepdim=True)
205
+ if self.activate_fn_type != "norm_silu_norms":
206
+ return self.silu(self.rms_norm(hidden.view(hidden.shape[0], self.num_blocks, self.block_size)))
207
+ else:
208
+ return self.silu(hidden)
209
+
210
+
211
+ class BlockFFNLayer(nn.Module):
212
+ def __init__(self, config: BlockFFNConfig):
213
+ super(BlockFFNLayer, self).__init__()
214
+ self.config = config
215
+ self.num_experts, self.dim_expert, self.hidden_size = \
216
+ config.num_experts, config.moe_ffn_hidden_size, config.hidden_size
217
+ self.dim_shared_expert = config.moe_shared_expert_intermediate_size
218
+ self.router_norm_type = config.router_norm_type
219
+
220
+ self.moe_router = BlockFFNRouter(self.config)
221
+ assert config.router_act_func == "relu"
222
+ self.router_act = nn.ReLU()
223
+ if config.router_norm_type == "simple":
224
+ self.router_norm = SimpleLayerNorm(self.config.num_experts)
225
+ elif config.router_norm_type == "rms":
226
+ self.router_norm = BlockFFNRMSNorm(self.config.num_experts, eps=config.norm_epsilon)
227
+ else:
228
+ raise NotImplementedError
229
+
230
+ self.expert_gated = not config.expert_not_gated
231
+ if self.expert_gated:
232
+ self.expert_gate_proj = nn.Linear(self.hidden_size, self.num_experts * self.dim_expert, bias=config.mlp_bias)
233
+
234
+ self.expert_up_proj = nn.Linear(self.hidden_size, self.num_experts * self.dim_expert, bias=config.mlp_bias)
235
+ assert config.expert_act_norm_type == "normal"
236
+ if config.expert_act_func == "norm_silu":
237
+ self.expert_act = NormSiLU(self.config)
238
+ elif config.expert_act_func == "silu":
239
+ self.expert_act = nn.SiLU()
240
+ else:
241
+ raise NotImplementedError
242
+ self.expert_down_proj = nn.Linear(self.num_experts * self.dim_expert, self.hidden_size, bias=config.mlp_bias)
243
+
244
+ self.use_shared_expert = self.dim_shared_expert is not None and self.dim_shared_expert > 0
245
+ if self.use_shared_expert:
246
+ self.shared_experts = BlockFFNMLP(self.config, intermediate_size=self.dim_shared_expert)
247
+
248
+ self.post_act_values = []
249
+
250
+ def forward(self, hidden_states: torch.Tensor):
251
+ ori_shape = hidden_states.shape
252
+ hidden_states = hidden_states.view(-1, self.hidden_size)
253
+ seq_len = hidden_states.shape[0]
254
+
255
+ # router module forward
256
+ raw_router_score = self.moe_router(hidden_states) # [seq_len, num_experts]
257
+ router_score = self.router_act(raw_router_score)
258
+ router_score = self.router_norm(router_score)
259
+
260
+ # expert module forward
261
+ x_in = self.expert_up_proj(hidden_states) # [seq_len, num_experts * dim_expert]
262
+ if self.expert_gated:
263
+ x_gate = self.expert_gate_proj(hidden_states)
264
+ x_in = x_in * self.expert_act(x_gate)
265
+ else:
266
+ x_in = self.expert_act(x_in)
267
+ with torch.no_grad():
268
+ x_in_np = x_in.cpu().float().numpy()
269
+ self.post_act_values.append(x_in_np)
270
+ if x_in.ndim == 3:
271
+ scored_x_in = x_in * router_score.type_as(hidden_states).unsqueeze(-1)
272
+ else:
273
+ scored_x_in = x_in.view(seq_len, self.num_experts, self.dim_expert) * router_score.type_as(hidden_states).unsqueeze(-1)
274
+ output = self.expert_down_proj(scored_x_in.view(seq_len, self.num_experts * self.dim_expert))
275
+
276
+ if self.use_shared_expert:
277
+ output = output + self.shared_experts(hidden_states)
278
+ return output.view(*ori_shape)
279
+
280
+
281
+ class BaseRouter(ABC, nn.Module):
282
+ """Base Router class"""
283
+ def __init__(self, config: BlockFFNConfig) -> None:
284
+ super().__init__()
285
+ self.config = config
286
+ self.num_experts = self.config.num_experts
287
+
288
+ if self.config.moe_router_dtype == "fp32":
289
+ self.router_dtype = torch.float32
290
+ elif self.config.moe_router_dtype == "fp64":
291
+ self.router_dtype = torch.float64
292
+ elif self.config.moe_router_dtype == "bf16":
293
+ self.router_dtype = torch.bfloat16
294
+ else:
295
+ raise NotImplementedError(f"{self.config.moe_router_dtype} is not supported.")
296
+
297
+ self.weight = torch.nn.Parameter(
298
+ torch.empty((self.num_experts, self.config.hidden_size), dtype=self.router_dtype)
299
+ )
300
+
301
+ def gating(self, input: torch.Tensor):
302
+ return torch.nn.functional.linear(input.to(self.router_dtype), self.weight.to(self.router_dtype))
303
+
304
+ @abstractmethod
305
+ def routing(self, logits: torch.Tensor):
306
+ """Routing function.
307
+
308
+ Args:
309
+ logits (torch.Tensor): Logits tensor.
310
+
311
+ Returns:
312
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment
313
+ probabilities and mapping.
314
+ """
315
+ raise NotImplementedError("Routing function not implemented.")
316
+
317
+ @abstractmethod
318
+ def forward(self, input: torch.Tensor):
319
+ """
320
+ Forward pass of the router.
321
+
322
+ Args:
323
+ input (torch.Tensor): Input tensor.
324
+ """
325
+ raise NotImplementedError("Forward function not implemented.")
326
+
327
+
328
+ class TopKRouter(BaseRouter):
329
+ """Route each token to the top-k experts."""
330
+
331
+ def __init__(self, config: BlockFFNConfig) -> None:
332
+ super().__init__(config)
333
+ self.config = config
334
+ self.topk = self.config.moe_router_topk
335
+ self.score_function = self.config.moe_router_score_function
336
+ self.use_pre_softmax = self.config.moe_router_pre_softmax
337
+ self.scaling_factor = self.config.moe_router_topk_scaling_factor
338
+
339
+ self.enable_expert_bias = self.config.moe_router_enable_expert_bias
340
+ if self.enable_expert_bias:
341
+ self.expert_bias = torch.nn.Parameter(torch.zeros(self.num_experts, dtype=torch.float32))
342
+ else:
343
+ self.expert_bias = None
344
+
345
+ def _maintain_float32_expert_bias(self):
346
+ """
347
+ Maintain the expert bias in float32.
348
+
349
+ When using bf16/fp16, the expert bias gets converted to lower precision in Float16Module.
350
+ We keep it in float32 to avoid routing errors when updating the expert_bias.
351
+ """
352
+ if hasattr(self, 'expert_bias') and self.expert_bias is not None:
353
+ if self.expert_bias.dtype != torch.float32:
354
+ self.expert_bias.data = self.expert_bias.data.to(torch.float32)
355
+
356
+ def routing(self, logits: torch.Tensor):
357
+ """Top-k routing function
358
+
359
+ Args:
360
+ logits (torch.Tensor): Logits tensor after gating.
361
+
362
+ Returns:
363
+ probs (torch.Tensor): The probabilities of token to experts assignment.
364
+ routing_map (torch.Tensor): The mapping of token to experts assignment,
365
+ with shape [num_tokens, num_experts].
366
+ """
367
+ logits = logits.view(-1, self.num_experts)
368
+
369
+ if self.score_function == "softmax":
370
+ if self.use_pre_softmax:
371
+ scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
372
+ probs, top_indices = torch.topk(scores, k=self.topk, dim=1)
373
+ else:
374
+ scores, top_indices = torch.topk(logits, k=self.topk, dim=1)
375
+ probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
376
+ elif self.score_function == "sigmoid":
377
+ scores = torch.sigmoid(logits.float()).type_as(logits)
378
+ if self.expert_bias is not None:
379
+ scores_for_routing = scores + self.expert_bias
380
+ _, top_indices = torch.topk(scores_for_routing, k=self.topk, dim=1)
381
+ scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
382
+ else:
383
+ scores, top_indices = torch.topk(scores, k=self.topk, dim=1)
384
+ probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.topk > 1 else scores
385
+ else:
386
+ raise ValueError(f"Invalid score_function: {self.score_function}")
387
+
388
+ if self.scaling_factor:
389
+ probs = probs * self.scaling_factor
390
+
391
+ return probs, top_indices
392
+
393
+ def forward(self, input: torch.Tensor):
394
+ """
395
+ Forward pass of the router.
396
+
397
+ Args:
398
+ input (torch.Tensor): Input tensor.
399
+ """
400
+ self._maintain_float32_expert_bias()
401
+ logits = self.gating(input)
402
+ top_scores, top_indices = self.routing(logits)
403
+ return top_scores, top_indices
404
+
405
+
406
+ class ReMoERouter(BaseRouter):
407
+ def __init__(self, config: BlockFFNConfig) -> None:
408
+ super().__init__(config)
409
+ self.config = config
410
+ self.router_act = torch.nn.ReLU()
411
+
412
+ def routing(self, logits: torch.Tensor):
413
+ """Top-k routing function
414
+
415
+ Args:
416
+ logits (torch.Tensor): Logits tensor after gating.
417
+
418
+ Returns:
419
+ probs (torch.Tensor): The probabilities of token to experts assignment.
420
+ routing_map (torch.Tensor): The mapping of token to experts assignment,
421
+ with shape [num_tokens, num_experts].
422
+ """
423
+ logits = logits.view(-1, self.num_experts)
424
+
425
+ router_score = self.router_act(logits)
426
+ routing_map = router_score > 0
427
+
428
+ sorted_probs, sorted_indices = torch.sort(router_score, descending=True, dim=-1)
429
+ sorted_map = sorted_probs <= 0
430
+ sorted_indices = torch.where(sorted_map, -1, sorted_indices)
431
+ max_valid_num = max(sorted_probs.size(-1) - torch.min(torch.sum(sorted_map, dim=-1)).item(), 1)
432
+ assert torch.all(sorted_map[:, max_valid_num:])
433
+ sorted_probs = sorted_probs[:, :max_valid_num]
434
+ sorted_indices = sorted_indices[:, :max_valid_num]
435
+ assert torch.sum(routing_map) == torch.sum(sorted_indices != -1)
436
+ return sorted_probs, sorted_indices
437
+
438
+ def forward(self, input: torch.Tensor):
439
+ """
440
+ Forward pass of the router.
441
+
442
+ Args:
443
+ input (torch.Tensor): Input tensor.
444
+ """
445
+ logits = self.gating(input)
446
+ top_scores, top_indices = self.routing(logits)
447
+ return top_scores, top_indices
448
+
449
+
450
+ class TopPRouter(BaseRouter):
451
+ def __init__(self, config: BlockFFNConfig) -> None:
452
+ super().__init__(config)
453
+ self.config = config
454
+ self.top_p = config.moe_router_topp
455
+
456
+ def routing(self, logits: torch.Tensor):
457
+ """Top-k routing function
458
+
459
+ Args:
460
+ logits (torch.Tensor): Logits tensor after gating.
461
+
462
+ Returns:
463
+ probs (torch.Tensor): The probabilities of token to experts assignment.
464
+ routing_map (torch.Tensor): The mapping of token to experts assignment,
465
+ with shape [num_tokens, num_experts].
466
+ """
467
+ logits = logits.view(-1, self.num_experts)
468
+
469
+ router_score = torch.abs(logits)
470
+ router_score = router_score / (router_score.sum(dim=-1, keepdim=True) + 1e-20)
471
+
472
+ sorted_probs, sorted_indices = torch.sort(router_score, descending=True, dim=-1)
473
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
474
+ mask = cumulative_probs > self.top_p
475
+
476
+ threshold_indices = mask.long().argmax(dim=-1)
477
+ threshold_mask = torch.nn.functional.one_hot(threshold_indices, num_classes=sorted_indices.size(-1)).bool()
478
+
479
+ mask = mask & ~threshold_mask
480
+ sorted_indices = torch.where(mask, -1, sorted_indices)
481
+ sorted_probs = torch.where(mask, 0.0, sorted_probs)
482
+
483
+ max_valid_num = max(mask.size(-1) - torch.min(torch.sum(mask, dim=-1)).item(), 1)
484
+ assert torch.all(mask[:, max_valid_num:])
485
+
486
+ sorted_indices = sorted_indices[:, :max_valid_num]
487
+ sorted_probs = sorted_probs[:, :max_valid_num]
488
+ sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
489
+ return sorted_probs, sorted_indices
490
+
491
+ def forward(self, input: torch.Tensor):
492
+ """
493
+ Forward pass of the router.
494
+
495
+ Args:
496
+ input (torch.Tensor): Input tensor.
497
+ """
498
+ logits = self.gating(input)
499
+ top_scores, top_indices = self.routing(logits)
500
+ return top_scores, top_indices
501
+
502
+
503
+ class FastTopKCalculator:
504
+ def __init__(self, num_experts: int):
505
+ self.num_experts = num_experts
506
+
507
+ def fmoe_sparse_topk_forward(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, experts: torch.nn.Module):
508
+ (
509
+ pos,
510
+ local_expert_count,
511
+ global_expert_count,
512
+ fwd_expert_count,
513
+ fwd_batch_size,
514
+ ) = prepare_forward(topk_indices, self.num_experts, 1)
515
+ topk = 1
516
+ if len(topk_indices.shape) == 2:
517
+ topk = topk_indices.shape[1]
518
+
519
+ def scatter_func(tensor):
520
+ return MOEScatter.apply(
521
+ tensor,
522
+ torch.div(pos, topk, rounding_mode='floor'),
523
+ local_expert_count,
524
+ global_expert_count,
525
+ fwd_batch_size,
526
+ 1,
527
+ )
528
+
529
+ x = tree.map_structure(scatter_func, hidden_states)
530
+ x = experts(x, fwd_expert_count, topk_indices=topk_indices)
531
+
532
+ out_batch_size = tree.flatten(hidden_states)[0].shape[0]
533
+ if len(topk_indices.shape) == 2:
534
+ out_batch_size *= topk_indices.shape[1]
535
+
536
+ def gather_func(tensor):
537
+ return MOEGather.apply(
538
+ tensor,
539
+ pos,
540
+ local_expert_count,
541
+ global_expert_count,
542
+ out_batch_size,
543
+ 1,
544
+ )
545
+
546
+ outp = tree.map_structure(gather_func, x)
547
+ return outp
548
+
549
+ def forward(self, hidden_states, topk_indices, topk_weights, experts):
550
+ assert topk_indices.shape == topk_weights.shape
551
+ top_k = topk_indices.shape[-1]
552
+ dim3 = hidden_states.ndim == 3
553
+ if dim3:
554
+ batch_size, seq_len, dim = hidden_states.shape
555
+ hidden_states = hidden_states.view(batch_size * seq_len, dim)
556
+ else:
557
+ assert hidden_states.ndim == 2
558
+ batch_size, (seq_len, dim) = -1, hidden_states.shape
559
+ fwd = self.fmoe_sparse_topk_forward(hidden_states, topk_indices, experts)
560
+
561
+ def view_func(tensor):
562
+ n_dim = tensor.shape[-1]
563
+ tensor = tensor.view(-1, top_k, n_dim)
564
+ return tensor
565
+
566
+ moe_output = tree.map_structure(view_func, fwd)
567
+ topk_weights = topk_weights.unsqueeze(1)
568
+
569
+ def bmm_func(tensor):
570
+ n_dim = tensor.shape[-1]
571
+ tensor = torch.bmm(topk_weights, tensor).reshape(-1, n_dim)
572
+ return tensor
573
+
574
+ moe_output = tree.map_structure(bmm_func, moe_output)
575
+ if dim3:
576
+ moe_output = moe_output.view(batch_size, seq_len, -1)
577
+ return moe_output
578
+
579
+
580
+ class MoELinearExperts(nn.Module):
581
+ def __init__(
582
+ self,
583
+ dim_in: int,
584
+ dim_out: int,
585
+ num_experts: int,
586
+ ffn_bias: bool,
587
+ ):
588
+ super().__init__()
589
+ self.dim_in = self.in_features = dim_in
590
+ self.dim_out = self.out_features = dim_out
591
+ self.weight = torch.nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
592
+ self.bias = None
593
+ if ffn_bias:
594
+ self.bias = torch.nn.Parameter(torch.empty(num_experts, dim_out))
595
+
596
+ def forward(self, x: torch.Tensor, fwd_expert_count: torch.Tensor):
597
+ x = MOELinear.apply(x, fwd_expert_count, self.weight, self.bias)
598
+ return x
599
+
600
+
601
+ class MoEGatedExperts(nn.Module):
602
+ def __init__(
603
+ self,
604
+ dim_in: int,
605
+ dim_ff: int,
606
+ is_gated: bool,
607
+ act_name: str,
608
+ num_experts: int,
609
+ ffn_bias: bool = False,
610
+ ):
611
+ super().__init__()
612
+ self.is_gated = is_gated
613
+ self.dim_in, self.dim_ff, self.num_experts = dim_in, dim_ff, num_experts
614
+ if self.is_gated:
615
+ self.gate_proj = MoELinearExperts(dim_in, dim_ff, num_experts, ffn_bias)
616
+ self.up_proj = MoELinearExperts(dim_in, dim_ff, num_experts, ffn_bias)
617
+ self.down_proj = MoELinearExperts(dim_ff, dim_in, num_experts, ffn_bias)
618
+
619
+ self.act_fn = ACT2FN[act_name]
620
+
621
+ def forward(self, x: torch.Tensor, fwd_expert_count: torch.Tensor, **kwargs) -> torch.Tensor:
622
+ if self.is_gated:
623
+ gate_score = self.gate_proj(x, fwd_expert_count)
624
+ up_proj = self.up_proj(x, fwd_expert_count)
625
+ x = up_proj * self.act_fn(gate_score)
626
+ else:
627
+ up_score = self.up_proj(x, fwd_expert_count)
628
+ x = self.act_fn(up_score)
629
+ x = self.down_proj(x, fwd_expert_count)
630
+ return x
631
+
632
+
633
+ class VanillaMoELayer(nn.Module):
634
+ def __init__(self, config: BlockFFNConfig):
635
+ super(VanillaMoELayer, self).__init__()
636
+ self.config = config
637
+
638
+ # Initialize router
639
+ if config.router_type == "topk":
640
+ self.router = TopKRouter(config=self.config)
641
+ elif config.router_type == "remoe":
642
+ self.router = ReMoERouter(config=self.config)
643
+ elif config.router_type == "topp":
644
+ self.router = TopPRouter(config=self.config)
645
+ else:
646
+ raise NotImplementedError(f"Router type {config.router_type} not implemented.")
647
+
648
+ self.mix_calculator = FastTopKCalculator(num_experts=self.config.num_experts)
649
+
650
+ # Initialize experts
651
+ self.experts = MoEGatedExperts(
652
+ dim_in=self.config.hidden_size,
653
+ dim_ff=self.config.moe_ffn_hidden_size,
654
+ is_gated=not self.config.expert_not_gated,
655
+ act_name="silu",
656
+ num_experts=self.config.num_experts,
657
+ )
658
+
659
+ self.dim_shared_expert = self.config.moe_shared_expert_intermediate_size
660
+ self.use_shared_expert = self.dim_shared_expert is not None and self.dim_shared_expert > 0
661
+ if self.use_shared_expert:
662
+ self.shared_experts = BlockFFNMLP(self.config, intermediate_size=self.dim_shared_expert)
663
+
664
+ def forward(self, hidden_states: torch.Tensor):
665
+ top_scores, top_indices = self.router(hidden_states)
666
+ y = self.mix_calculator.forward(
667
+ hidden_states=hidden_states,
668
+ topk_indices=top_indices.contiguous(),
669
+ topk_weights=top_scores.type_as(hidden_states),
670
+ experts=self.experts,
671
+ )
672
+ if self.shared_experts is not None:
673
+ y = y + self.shared_experts(hidden_states)
674
+ return y
675
+
676
+
677
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
678
+ """
679
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
680
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
681
+ """
682
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
683
+ if n_rep == 1:
684
+ return hidden_states
685
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
686
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
687
+
688
+
689
+ def eager_attention_forward(
690
+ module: nn.Module,
691
+ query: torch.Tensor,
692
+ key: torch.Tensor,
693
+ value: torch.Tensor,
694
+ attention_mask: Optional[torch.Tensor],
695
+ scaling: float,
696
+ dropout: float = 0.0,
697
+ ):
698
+ key_states = repeat_kv(key, module.num_key_value_groups)
699
+ value_states = repeat_kv(value, module.num_key_value_groups)
700
+
701
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
702
+ if attention_mask is not None:
703
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
704
+ attn_weights = attn_weights + causal_mask
705
+
706
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
707
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
708
+ attn_output = torch.matmul(attn_weights, value_states)
709
+ attn_output = attn_output.transpose(1, 2).contiguous()
710
+
711
+ return attn_output, attn_weights
712
+
713
+
714
+ class BlockFFNAttention(nn.Module):
715
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
716
+
717
+ def __init__(self, config: BlockFFNConfig, layer_idx: int):
718
+ super().__init__()
719
+ self.config = config
720
+ self.layer_idx = layer_idx
721
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
722
+ self.num_key_value_groups = config.num_attention_heads // config.num_query_groups
723
+ self.scaling = self.head_dim**-0.5
724
+ self.attention_dropout = config.attention_dropout
725
+ self.is_causal = True
726
+
727
+ self.q_proj = nn.Linear(
728
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
729
+ )
730
+ self.k_proj = nn.Linear(
731
+ config.hidden_size, config.num_query_groups * self.head_dim, bias=config.attention_bias
732
+ )
733
+ self.v_proj = nn.Linear(
734
+ config.hidden_size, config.num_query_groups * self.head_dim, bias=config.attention_bias
735
+ )
736
+ self.o_proj = nn.Linear(
737
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
738
+ )
739
+
740
+ def forward(
741
+ self,
742
+ hidden_states: torch.Tensor,
743
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
744
+ attention_mask: Optional[torch.Tensor],
745
+ past_key_value: Optional[Cache] = None,
746
+ cache_position: Optional[torch.LongTensor] = None,
747
+ **kwargs: Unpack[TransformersKwargs],
748
+ ) -> tuple[torch.Tensor, torch.Tensor]:
749
+ input_shape = hidden_states.shape[:-1]
750
+ hidden_shape = (*input_shape, -1, self.head_dim)
751
+
752
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
753
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
754
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
755
+
756
+ cos, sin = position_embeddings
757
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
758
+
759
+ if past_key_value is not None:
760
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
761
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
762
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
763
+
764
+ attention_interface: Callable = eager_attention_forward
765
+ if self.config._attn_implementation != "eager":
766
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
767
+
768
+ attn_output, attn_weights = attention_interface(
769
+ self,
770
+ query_states,
771
+ key_states,
772
+ value_states,
773
+ attention_mask,
774
+ dropout=0.0 if not self.training else self.attention_dropout,
775
+ scaling=self.scaling,
776
+ **kwargs,
777
+ )
778
+
779
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
780
+ attn_output = self.o_proj(attn_output)
781
+ return attn_output, attn_weights
782
+
783
+
784
+ class BlockFFNDecoderLayer(GradientCheckpointingLayer):
785
+ def __init__(self, config: BlockFFNConfig, layer_idx: int, is_moe_layer: bool):
786
+ super().__init__()
787
+ self.config = config
788
+ self.hidden_size = config.hidden_size
789
+
790
+ self.self_attn = BlockFFNAttention(config=config, layer_idx=layer_idx)
791
+
792
+ if is_moe_layer:
793
+ if config.use_blockffn:
794
+ self.mlp = BlockFFNLayer(config)
795
+ elif config.router_type in ["topk", "remoe", "topp"]:
796
+ self.mlp = VanillaMoELayer(config)
797
+ else:
798
+ raise NotImplementedError
799
+ else:
800
+ self.mlp = BlockFFNMLP(config)
801
+ self.input_layernorm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
802
+ self.post_attention_layernorm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
803
+
804
+ def forward(
805
+ self,
806
+ hidden_states: torch.Tensor,
807
+ attention_mask: Optional[torch.Tensor] = None,
808
+ position_ids: Optional[torch.LongTensor] = None,
809
+ past_key_value: Optional[Cache] = None,
810
+ use_cache: Optional[bool] = False,
811
+ cache_position: Optional[torch.LongTensor] = None,
812
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
813
+ **kwargs: Unpack[TransformersKwargs],
814
+ ) -> tuple[torch.Tensor]:
815
+ residual = hidden_states
816
+ hidden_states = self.input_layernorm(hidden_states)
817
+ # Self Attention
818
+ hidden_states, _ = self.self_attn(
819
+ hidden_states=hidden_states,
820
+ attention_mask=attention_mask,
821
+ position_ids=position_ids,
822
+ past_key_value=past_key_value,
823
+ use_cache=use_cache,
824
+ cache_position=cache_position,
825
+ position_embeddings=position_embeddings,
826
+ **kwargs,
827
+ )
828
+ if self.config.use_mup:
829
+ hidden_states = residual + hidden_states * (self.config.mup_depth_scale / math.sqrt(self.config.num_layers))
830
+ else:
831
+ hidden_states = residual + hidden_states
832
+
833
+ # Fully Connected
834
+ residual = hidden_states
835
+ hidden_states = self.post_attention_layernorm(hidden_states)
836
+ hidden_states = self.mlp(hidden_states)
837
+ if self.config.use_mup:
838
+ hidden_states = residual + hidden_states * (self.config.mup_depth_scale / math.sqrt(self.config.num_layers))
839
+ else:
840
+ hidden_states = residual + hidden_states
841
+ return hidden_states
842
+
843
+
844
+ @auto_docstring
845
+ class BlockFFNPreTrainedModel(PreTrainedModel):
846
+ config: BlockFFNConfig
847
+ base_model_prefix = "model"
848
+ supports_gradient_checkpointing = True
849
+ _no_split_modules = ["BlockFFNDecoderLayer"]
850
+ _skip_keys_device_placement = ["past_key_values"]
851
+ _supports_flash_attn = True
852
+ _supports_sdpa = True
853
+ _supports_flex_attn = True
854
+
855
+ _can_compile_fullgraph = True
856
+ _supports_attention_backend = True
857
+ _can_record_outputs = {
858
+ "hidden_states": BlockFFNDecoderLayer,
859
+ "attentions": BlockFFNAttention,
860
+ }
861
+
862
+
863
+ @auto_docstring
864
+ class BlockFFNModel(BlockFFNPreTrainedModel):
865
+ def __init__(self, config: BlockFFNConfig):
866
+ super().__init__(config)
867
+ self.config = config
868
+ self.padding_idx = config.pad_token_id
869
+ self.vocab_size = config.vocab_size
870
+
871
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
872
+ self.moe_layer_freq = eval(config.moe_layer_freq) if isinstance(config.moe_layer_freq, str) else config.moe_layer_freq
873
+ assert len(self.moe_layer_freq) == config.num_layers
874
+ self.layers = nn.ModuleList(
875
+ [BlockFFNDecoderLayer(config, layer_idx, bool(self.moe_layer_freq[layer_idx])) for layer_idx in range(config.num_layers)]
876
+ )
877
+ self.norm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
878
+ self.rotary_emb = BlockFFNRotaryEmbedding(config=config)
879
+ self.gradient_checkpointing = False
880
+
881
+ # Initialize weights and apply final processing
882
+ self.post_init()
883
+
884
+ @check_model_inputs
885
+ @auto_docstring
886
+ def forward(
887
+ self,
888
+ input_ids: Optional[torch.LongTensor] = None,
889
+ attention_mask: Optional[torch.Tensor] = None,
890
+ position_ids: Optional[torch.LongTensor] = None,
891
+ past_key_values: Optional[Cache] = None,
892
+ inputs_embeds: Optional[torch.FloatTensor] = None,
893
+ cache_position: Optional[torch.LongTensor] = None,
894
+ use_cache: Optional[bool] = None,
895
+ **kwargs: Unpack[TransformersKwargs],
896
+ ) -> BaseModelOutputWithPast:
897
+ if (input_ids is None) ^ (inputs_embeds is not None):
898
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
899
+
900
+ if inputs_embeds is None:
901
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
902
+ if self.config.use_mup:
903
+ inputs_embeds = inputs_embeds * self.config.mup_emb_scale
904
+
905
+ if use_cache and past_key_values is None:
906
+ past_key_values = DynamicCache()
907
+
908
+ if cache_position is None:
909
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
910
+ cache_position: torch.Tensor = torch.arange(
911
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
912
+ )
913
+
914
+ if position_ids is None:
915
+ position_ids = cache_position.unsqueeze(0)
916
+
917
+ causal_mask = create_causal_mask(
918
+ config=self.config,
919
+ input_embeds=inputs_embeds,
920
+ attention_mask=attention_mask,
921
+ cache_position=cache_position,
922
+ past_key_values=past_key_values,
923
+ position_ids=position_ids,
924
+ )
925
+
926
+ hidden_states = inputs_embeds
927
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
928
+
929
+ for decoder_layer in self.layers[: self.config.num_layers]:
930
+ hidden_states = decoder_layer(
931
+ hidden_states,
932
+ attention_mask=causal_mask,
933
+ position_ids=position_ids,
934
+ past_key_value=past_key_values,
935
+ cache_position=cache_position,
936
+ position_embeddings=position_embeddings,
937
+ **kwargs,
938
+ )
939
+
940
+ hidden_states = self.norm(hidden_states)
941
+ return BaseModelOutputWithPast(
942
+ last_hidden_state=hidden_states,
943
+ past_key_values=past_key_values,
944
+ )
945
+
946
+
947
+ @auto_docstring
948
+ class BlockFFNForCausalLM(BlockFFNPreTrainedModel, GenerationMixin):
949
+ _tied_weights_keys = ["lm_head.weight"]
950
+ _tp_plan = {"lm_head": "colwise_rep"}
951
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
952
+
953
+ def __init__(self, config: BlockFFNConfig):
954
+ super().__init__(config)
955
+ self.config = config
956
+ self.model = BlockFFNModel(config)
957
+ self.vocab_size = config.vocab_size
958
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
959
+
960
+ # Initialize weights and apply final processing
961
+ self.post_init()
962
+
963
+ def set_decoder(self, decoder):
964
+ self.model = decoder
965
+
966
+ def get_decoder(self):
967
+ return self.model
968
+
969
+ @can_return_tuple
970
+ @auto_docstring
971
+ def forward(
972
+ self,
973
+ input_ids: Optional[torch.LongTensor] = None,
974
+ attention_mask: Optional[torch.Tensor] = None,
975
+ position_ids: Optional[torch.LongTensor] = None,
976
+ past_key_values: Optional[Cache] = None,
977
+ inputs_embeds: Optional[torch.FloatTensor] = None,
978
+ labels: Optional[torch.LongTensor] = None,
979
+ use_cache: Optional[bool] = None,
980
+ cache_position: Optional[torch.LongTensor] = None,
981
+ logits_to_keep: Union[int, torch.Tensor] = 0,
982
+ **kwargs: Unpack[TransformersKwargs],
983
+ ) -> CausalLMOutputWithPast:
984
+ outputs: BaseModelOutputWithPast = self.model(
985
+ input_ids=input_ids,
986
+ attention_mask=attention_mask,
987
+ position_ids=position_ids,
988
+ past_key_values=past_key_values,
989
+ inputs_embeds=inputs_embeds,
990
+ use_cache=use_cache,
991
+ cache_position=cache_position,
992
+ **kwargs,
993
+ )
994
+
995
+ hidden_states = outputs.last_hidden_state
996
+ if self.config.use_mup:
997
+ hidden_states = hidden_states / self.config.mup_width_scale
998
+
999
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1000
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1001
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1002
+
1003
+ loss = None
1004
+ if labels is not None:
1005
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1006
+
1007
+ return CausalLMOutputWithPast(
1008
+ loss=loss,
1009
+ logits=logits,
1010
+ past_key_values=outputs.past_key_values,
1011
+ hidden_states=outputs.hidden_states,
1012
+ attentions=outputs.attentions,
1013
+ )
1014
+
1015
+ __all__ = [
1016
+ "BlockFFNForCausalLM",
1017
+ "BlockFFNModel",
1018
+ "BlockFFNPreTrainedModel",
1019
+ ]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2685206ec453d7b43292cf86cd09d6e7a53bbb8a8e008325a9f1ee38d03e8958
3
+ size 443784525
special_tokens_map.json ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<|im_end|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ {
11
+ "content": "<|im_start|>",
12
+ "lstrip": false,
13
+ "normalized": false,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ },
17
+ {
18
+ "content": "<|tool_call|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ {
25
+ "content": "<|execute_start|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ {
32
+ "content": "<|execute_end|>",
33
+ "lstrip": false,
34
+ "normalized": false,
35
+ "rstrip": false,
36
+ "single_word": false
37
+ },
38
+ {
39
+ "content": "<|fim_prefix|>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false
44
+ },
45
+ {
46
+ "content": "<|fim_middle|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false
51
+ },
52
+ {
53
+ "content": "<|fim_suffix|>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false
58
+ }
59
+ ],
60
+ "bos_token": {
61
+ "content": "<s>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false
66
+ },
67
+ "eos_token": {
68
+ "content": "</s>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false
73
+ },
74
+ "unk_token": {
75
+ "content": "<unk>",
76
+ "lstrip": false,
77
+ "normalized": false,
78
+ "rstrip": false,
79
+ "single_word": false
80
+ }
81
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb74d51116831c3bf65db812c553f94ab0c88dcf97a5bbb37e3504f6d359c530
3
+ size 1181204
tokenizer_config.json ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "73440": {
30
+ "content": "<|im_end|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "73441": {
38
+ "content": "<|im_start|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "73442": {
46
+ "content": "<|tool_call|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "73443": {
54
+ "content": "<|execute_start|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "73444": {
62
+ "content": "<|execute_end|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "73445": {
70
+ "content": "<|fim_prefix|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "73446": {
78
+ "content": "<|fim_middle|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "73447": {
86
+ "content": "<|fim_suffix|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ }
93
+ },
94
+ "additional_special_tokens": [
95
+ "<|im_end|>",
96
+ "<|im_start|>",
97
+ "<|tool_call|>",
98
+ "<|execute_start|>",
99
+ "<|execute_end|>",
100
+ "<|fim_prefix|>",
101
+ "<|fim_middle|>",
102
+ "<|fim_suffix|>"
103
+ ],
104
+ "bos_token": "<s>",
105
+ "clean_up_tokenization_spaces": false,
106
+ "eos_token": "<|im_end|>",
107
+ "legacy": true,
108
+ "model_max_length": 1000000000000000019884624838656,
109
+ "pad_token": null,
110
+ "sp_model_kwargs": {},
111
+ "spaces_between_special_tokens": false,
112
+ "tokenizer_class": "LlamaTokenizer",
113
+ "unk_token": "<unk>",
114
+ "use_default_system_prompt": false,
115
+ "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
116
+ }