Upload folder using huggingface_hub
Browse files- c4_validation.json +0 -0
- config.json +77 -0
- configuration_blockffn.py +161 -0
- evaluation.log +0 -0
- evaluation.log.bak +0 -0
- evaluation/results__hf_ckpts__blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128__/results_2026-01-23T14-52-19.555032.json +609 -0
- evaluation/results__hf_ckpts__blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128__/results_2026-03-30T20-55-02.181492.json +609 -0
- evaluation2.log +0 -0
- generation_config.json +8 -0
- modeling_blockffn.py +1014 -0
- modeling_blockffn.py.bak +1024 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +81 -0
- tokenizer.json +0 -0
- tokenizer.model +3 -0
- tokenizer_config.json +116 -0
c4_validation.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
config.json
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BlockFFNForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_blockffn.BlockFFNConfig",
|
| 7 |
+
"AutoModel": "modeling_blockffn.BlockFFNModel",
|
| 8 |
+
"AutoModelForCausalLM": "modeling_blockffn.BlockFFNForCausalLM"
|
| 9 |
+
},
|
| 10 |
+
"bos_token_id": 1,
|
| 11 |
+
"eos_token_id": [
|
| 12 |
+
2,
|
| 13 |
+
73440
|
| 14 |
+
],
|
| 15 |
+
"pad_token_id": 2,
|
| 16 |
+
"hidden_act": "silu",
|
| 17 |
+
"hidden_size": 1024,
|
| 18 |
+
"initializer_range": 0.1,
|
| 19 |
+
"intermediate_size": 10240,
|
| 20 |
+
"head_dim": 128,
|
| 21 |
+
"max_position_embeddings": 4096,
|
| 22 |
+
"model_type": "blockffn",
|
| 23 |
+
"num_attention_heads": 8,
|
| 24 |
+
"num_hidden_layers": 32,
|
| 25 |
+
"num_key_value_heads": 2,
|
| 26 |
+
"rms_norm_eps": 1e-05,
|
| 27 |
+
"rope_scaling": null,
|
| 28 |
+
"rope_theta": 10000.0,
|
| 29 |
+
"torch_dtype": "bfloat16",
|
| 30 |
+
"transformers_version": "4.36.0",
|
| 31 |
+
"use_cache": true,
|
| 32 |
+
"vocab_size": 73448,
|
| 33 |
+
"use_mup": false,
|
| 34 |
+
"num_experts": 57,
|
| 35 |
+
"moe_ffn_hidden_size": 64,
|
| 36 |
+
"moe_shared_expert_intermediate_size": 128,
|
| 37 |
+
"moe_layer_freq": [
|
| 38 |
+
0,
|
| 39 |
+
1,
|
| 40 |
+
1,
|
| 41 |
+
1,
|
| 42 |
+
1,
|
| 43 |
+
1,
|
| 44 |
+
1,
|
| 45 |
+
1,
|
| 46 |
+
1,
|
| 47 |
+
1,
|
| 48 |
+
1,
|
| 49 |
+
1,
|
| 50 |
+
1,
|
| 51 |
+
1,
|
| 52 |
+
1,
|
| 53 |
+
1,
|
| 54 |
+
1,
|
| 55 |
+
1,
|
| 56 |
+
1,
|
| 57 |
+
1
|
| 58 |
+
],
|
| 59 |
+
"moe_router_dtype": "fp32",
|
| 60 |
+
"router_act_func": "relu",
|
| 61 |
+
"router_norm_type": "simple",
|
| 62 |
+
"expert_act_func": "norm_silu",
|
| 63 |
+
"expert_act_norm_type": "normal",
|
| 64 |
+
"num_layers": 20,
|
| 65 |
+
"ffn_hidden_size": 2560,
|
| 66 |
+
"num_query_groups": 8,
|
| 67 |
+
"norm_epsilon": 1e-05,
|
| 68 |
+
"use_blockffn": true,
|
| 69 |
+
"router_type": "topk",
|
| 70 |
+
"moe_router_enable_expert_bias": false,
|
| 71 |
+
"expert_not_gated": true,
|
| 72 |
+
"moe_router_pre_softmax": false,
|
| 73 |
+
"moe_router_topk": 2,
|
| 74 |
+
"moe_router_topp": 0.5,
|
| 75 |
+
"moe_router_score_function": "softmax",
|
| 76 |
+
"moe_router_topk_scaling_factor": null
|
| 77 |
+
}
|
configuration_blockffn.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""BlockFFN model configuration"""
|
| 21 |
+
|
| 22 |
+
from transformers import PretrainedConfig
|
| 23 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class BlockFFNConfig(PretrainedConfig):
|
| 27 |
+
|
| 28 |
+
model_type = "blockffn"
|
| 29 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 30 |
+
# Default tensor parallel plan for base model `BlockFFNModel`
|
| 31 |
+
base_model_tp_plan = {
|
| 32 |
+
"layers.*.self_attn.q_proj": "colwise",
|
| 33 |
+
"layers.*.self_attn.k_proj": "colwise",
|
| 34 |
+
"layers.*.self_attn.v_proj": "colwise",
|
| 35 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 36 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 37 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 38 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 39 |
+
}
|
| 40 |
+
base_model_pp_plan = {
|
| 41 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 42 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 43 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
vocab_size=32000,
|
| 49 |
+
hidden_size=4096,
|
| 50 |
+
ffn_hidden_size=11008,
|
| 51 |
+
num_layers=32,
|
| 52 |
+
num_attention_heads=32,
|
| 53 |
+
num_query_groups=None,
|
| 54 |
+
hidden_act="silu",
|
| 55 |
+
max_position_embeddings=2048,
|
| 56 |
+
initializer_range=0.02,
|
| 57 |
+
norm_epsilon=1e-6,
|
| 58 |
+
use_cache=True,
|
| 59 |
+
pad_token_id=None,
|
| 60 |
+
bos_token_id=1,
|
| 61 |
+
eos_token_id=2,
|
| 62 |
+
pretraining_tp=1,
|
| 63 |
+
tie_word_embeddings=False,
|
| 64 |
+
rope_theta=10000.0,
|
| 65 |
+
rope_scaling=None,
|
| 66 |
+
attention_bias=False,
|
| 67 |
+
attention_dropout=0.0,
|
| 68 |
+
mlp_bias=False,
|
| 69 |
+
head_dim=None,
|
| 70 |
+
use_mup=True,
|
| 71 |
+
mup_emb_scale=12,
|
| 72 |
+
mup_depth_scale=1.4,
|
| 73 |
+
mup_base_hidden_size=256,
|
| 74 |
+
num_experts=180,
|
| 75 |
+
moe_ffn_hidden_size=128,
|
| 76 |
+
moe_shared_expert_intermediate_size=128,
|
| 77 |
+
moe_layer_freq="([0]*3+[1]*29)",
|
| 78 |
+
moe_router_dtype="fp32",
|
| 79 |
+
router_act_func="relu",
|
| 80 |
+
router_norm_type="simple",
|
| 81 |
+
expert_act_func="norm_silu",
|
| 82 |
+
expert_act_norm_type="normal",
|
| 83 |
+
use_blockffn=False,
|
| 84 |
+
router_type="topk",
|
| 85 |
+
moe_router_topk=0,
|
| 86 |
+
moe_router_topp=0,
|
| 87 |
+
moe_router_enable_expert_bias=False,
|
| 88 |
+
moe_router_score_function="sigmoid",
|
| 89 |
+
moe_router_topk_scaling_factor=2.5,
|
| 90 |
+
expert_not_gated=False,
|
| 91 |
+
moe_router_pre_softmax=False,
|
| 92 |
+
**kwargs,
|
| 93 |
+
):
|
| 94 |
+
self.vocab_size = vocab_size
|
| 95 |
+
self.max_position_embeddings = max_position_embeddings
|
| 96 |
+
self.hidden_size = hidden_size
|
| 97 |
+
self.ffn_hidden_size = ffn_hidden_size
|
| 98 |
+
self.num_layers = num_layers
|
| 99 |
+
self.num_attention_heads = num_attention_heads
|
| 100 |
+
|
| 101 |
+
# for backward compatibility
|
| 102 |
+
if num_query_groups is None:
|
| 103 |
+
num_query_groups = num_attention_heads
|
| 104 |
+
|
| 105 |
+
self.num_query_groups = num_query_groups
|
| 106 |
+
self.hidden_act = hidden_act
|
| 107 |
+
self.initializer_range = initializer_range
|
| 108 |
+
self.norm_epsilon = norm_epsilon
|
| 109 |
+
self.pretraining_tp = pretraining_tp
|
| 110 |
+
self.use_cache = use_cache
|
| 111 |
+
self.rope_theta = rope_theta
|
| 112 |
+
self.rope_scaling = rope_scaling
|
| 113 |
+
self.attention_bias = attention_bias
|
| 114 |
+
self.attention_dropout = attention_dropout
|
| 115 |
+
self.mlp_bias = mlp_bias
|
| 116 |
+
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
|
| 117 |
+
self.use_mup = use_mup
|
| 118 |
+
self.mup_emb_scale = mup_emb_scale
|
| 119 |
+
self.mup_depth_scale = mup_depth_scale
|
| 120 |
+
self.mup_base_hidden_size = mup_base_hidden_size
|
| 121 |
+
|
| 122 |
+
self.num_experts = num_experts
|
| 123 |
+
self.moe_ffn_hidden_size = moe_ffn_hidden_size
|
| 124 |
+
self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size
|
| 125 |
+
self.moe_layer_freq = moe_layer_freq if isinstance(moe_layer_freq, (str, list)) else ([0] * num_layers)
|
| 126 |
+
self.moe_router_dtype = moe_router_dtype
|
| 127 |
+
self.router_act_func = router_act_func
|
| 128 |
+
self.router_norm_type = router_norm_type
|
| 129 |
+
self.expert_act_func = expert_act_func
|
| 130 |
+
self.expert_act_norm_type = expert_act_norm_type
|
| 131 |
+
|
| 132 |
+
self.use_blockffn = use_blockffn
|
| 133 |
+
self.router_type = router_type
|
| 134 |
+
self.moe_router_topk = moe_router_topk
|
| 135 |
+
self.moe_router_topp = moe_router_topp
|
| 136 |
+
self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
|
| 137 |
+
self.moe_router_score_function = moe_router_score_function
|
| 138 |
+
self.moe_router_topk_scaling_factor = moe_router_topk_scaling_factor
|
| 139 |
+
self.expert_not_gated = expert_not_gated
|
| 140 |
+
self.moe_router_pre_softmax = moe_router_pre_softmax
|
| 141 |
+
|
| 142 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 143 |
+
# BC: if there is a 'type' field, copy it it to 'rope_type'.
|
| 144 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 145 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 146 |
+
rope_config_validation(self)
|
| 147 |
+
|
| 148 |
+
super().__init__(
|
| 149 |
+
pad_token_id=pad_token_id,
|
| 150 |
+
bos_token_id=bos_token_id,
|
| 151 |
+
eos_token_id=eos_token_id,
|
| 152 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 153 |
+
**kwargs,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
@property
|
| 157 |
+
def mup_width_scale(self):
|
| 158 |
+
return (self.hidden_size / self.mup_base_hidden_size) if (self.use_mup and self.mup_base_hidden_size > 0) else 1
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
__all__ = ["BlockFFNConfig"]
|
evaluation.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
evaluation.log.bak
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
evaluation/results__hf_ckpts__blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128__/results_2026-01-23T14-52-19.555032.json
ADDED
|
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"arc_challenge": {
|
| 4 |
+
"alias": "arc_challenge",
|
| 5 |
+
"acc,none": 0.20733788395904437,
|
| 6 |
+
"acc_stderr,none": 0.011846905782971337,
|
| 7 |
+
"acc_norm,none": 0.24914675767918087,
|
| 8 |
+
"acc_norm_stderr,none": 0.012639407111926433
|
| 9 |
+
},
|
| 10 |
+
"arc_easy": {
|
| 11 |
+
"alias": "arc_easy",
|
| 12 |
+
"acc,none": 0.5197811447811448,
|
| 13 |
+
"acc_stderr,none": 0.010251751199542735,
|
| 14 |
+
"acc_norm,none": 0.44907407407407407,
|
| 15 |
+
"acc_norm_stderr,none": 0.010206428316323365
|
| 16 |
+
},
|
| 17 |
+
"boolq": {
|
| 18 |
+
"alias": "boolq",
|
| 19 |
+
"acc,none": 0.6113149847094801,
|
| 20 |
+
"acc_stderr,none": 0.008525580498982967
|
| 21 |
+
},
|
| 22 |
+
"hellaswag": {
|
| 23 |
+
"alias": "hellaswag",
|
| 24 |
+
"acc,none": 0.29466241784505076,
|
| 25 |
+
"acc_stderr,none": 0.004549591490046219,
|
| 26 |
+
"acc_norm,none": 0.3293168691495718,
|
| 27 |
+
"acc_norm_stderr,none": 0.004690047021719822
|
| 28 |
+
},
|
| 29 |
+
"lambada_openai": {
|
| 30 |
+
"alias": "lambada_openai",
|
| 31 |
+
"perplexity,none": 66.63752818011969,
|
| 32 |
+
"perplexity_stderr,none": 2.753797550415125,
|
| 33 |
+
"acc,none": 0.27653793906462254,
|
| 34 |
+
"acc_stderr,none": 0.006231567654090107
|
| 35 |
+
},
|
| 36 |
+
"lambada_standard": {
|
| 37 |
+
"alias": "lambada_standard",
|
| 38 |
+
"perplexity,none": 198.31426320573388,
|
| 39 |
+
"perplexity_stderr,none": 8.56348322119257,
|
| 40 |
+
"acc,none": 0.19483795847079372,
|
| 41 |
+
"acc_stderr,none": 0.005518111913121867
|
| 42 |
+
},
|
| 43 |
+
"piqa": {
|
| 44 |
+
"alias": "piqa",
|
| 45 |
+
"acc,none": 0.6436343852013058,
|
| 46 |
+
"acc_stderr,none": 0.011174109865864717,
|
| 47 |
+
"acc_norm,none": 0.6131664853101197,
|
| 48 |
+
"acc_norm_stderr,none": 0.011363095931902848
|
| 49 |
+
},
|
| 50 |
+
"social_iqa": {
|
| 51 |
+
"alias": "social_iqa",
|
| 52 |
+
"acc,none": 0.3618219037871034,
|
| 53 |
+
"acc_stderr,none": 0.010873447266941618
|
| 54 |
+
},
|
| 55 |
+
"wikitext": {
|
| 56 |
+
"alias": "wikitext",
|
| 57 |
+
"word_perplexity,none": 39.00052118822888,
|
| 58 |
+
"word_perplexity_stderr,none": "N/A",
|
| 59 |
+
"byte_perplexity,none": 1.9839837931891766,
|
| 60 |
+
"byte_perplexity_stderr,none": "N/A",
|
| 61 |
+
"bits_per_byte,none": 0.9884002406536678,
|
| 62 |
+
"bits_per_byte_stderr,none": "N/A"
|
| 63 |
+
},
|
| 64 |
+
"winogrande": {
|
| 65 |
+
"alias": "winogrande",
|
| 66 |
+
"acc,none": 0.5209155485398579,
|
| 67 |
+
"acc_stderr,none": 0.014040185494212945
|
| 68 |
+
}
|
| 69 |
+
},
|
| 70 |
+
"group_subtasks": {
|
| 71 |
+
"arc_challenge": [],
|
| 72 |
+
"arc_easy": [],
|
| 73 |
+
"boolq": [],
|
| 74 |
+
"hellaswag": [],
|
| 75 |
+
"lambada_openai": [],
|
| 76 |
+
"lambada_standard": [],
|
| 77 |
+
"piqa": [],
|
| 78 |
+
"social_iqa": [],
|
| 79 |
+
"wikitext": [],
|
| 80 |
+
"winogrande": []
|
| 81 |
+
},
|
| 82 |
+
"configs": {
|
| 83 |
+
"arc_challenge": {
|
| 84 |
+
"task": "arc_challenge",
|
| 85 |
+
"tag": [
|
| 86 |
+
"ai2_arc"
|
| 87 |
+
],
|
| 88 |
+
"dataset_path": "allenai/ai2_arc",
|
| 89 |
+
"dataset_name": "ARC-Challenge",
|
| 90 |
+
"training_split": "train",
|
| 91 |
+
"validation_split": "validation",
|
| 92 |
+
"test_split": "test",
|
| 93 |
+
"doc_to_text": "Question: {{question}}\nAnswer:",
|
| 94 |
+
"doc_to_target": "{{choices.label.index(answerKey)}}",
|
| 95 |
+
"unsafe_code": false,
|
| 96 |
+
"doc_to_choice": "{{choices.text}}",
|
| 97 |
+
"description": "",
|
| 98 |
+
"target_delimiter": " ",
|
| 99 |
+
"fewshot_delimiter": "\n\n",
|
| 100 |
+
"num_fewshot": 0,
|
| 101 |
+
"metric_list": [
|
| 102 |
+
{
|
| 103 |
+
"metric": "acc",
|
| 104 |
+
"aggregation": "mean",
|
| 105 |
+
"higher_is_better": true
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"metric": "acc_norm",
|
| 109 |
+
"aggregation": "mean",
|
| 110 |
+
"higher_is_better": true
|
| 111 |
+
}
|
| 112 |
+
],
|
| 113 |
+
"output_type": "multiple_choice",
|
| 114 |
+
"repeats": 1,
|
| 115 |
+
"should_decontaminate": true,
|
| 116 |
+
"doc_to_decontamination_query": "Question: {{question}}\nAnswer:",
|
| 117 |
+
"metadata": {
|
| 118 |
+
"version": 1.0,
|
| 119 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 120 |
+
"dtype": "bfloat16",
|
| 121 |
+
"trust_remote_code": true
|
| 122 |
+
}
|
| 123 |
+
},
|
| 124 |
+
"arc_easy": {
|
| 125 |
+
"task": "arc_easy",
|
| 126 |
+
"tag": [
|
| 127 |
+
"ai2_arc"
|
| 128 |
+
],
|
| 129 |
+
"dataset_path": "allenai/ai2_arc",
|
| 130 |
+
"dataset_name": "ARC-Easy",
|
| 131 |
+
"training_split": "train",
|
| 132 |
+
"validation_split": "validation",
|
| 133 |
+
"test_split": "test",
|
| 134 |
+
"doc_to_text": "Question: {{question}}\nAnswer:",
|
| 135 |
+
"doc_to_target": "{{choices.label.index(answerKey)}}",
|
| 136 |
+
"unsafe_code": false,
|
| 137 |
+
"doc_to_choice": "{{choices.text}}",
|
| 138 |
+
"description": "",
|
| 139 |
+
"target_delimiter": " ",
|
| 140 |
+
"fewshot_delimiter": "\n\n",
|
| 141 |
+
"num_fewshot": 0,
|
| 142 |
+
"metric_list": [
|
| 143 |
+
{
|
| 144 |
+
"metric": "acc",
|
| 145 |
+
"aggregation": "mean",
|
| 146 |
+
"higher_is_better": true
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"metric": "acc_norm",
|
| 150 |
+
"aggregation": "mean",
|
| 151 |
+
"higher_is_better": true
|
| 152 |
+
}
|
| 153 |
+
],
|
| 154 |
+
"output_type": "multiple_choice",
|
| 155 |
+
"repeats": 1,
|
| 156 |
+
"should_decontaminate": true,
|
| 157 |
+
"doc_to_decontamination_query": "Question: {{question}}\nAnswer:",
|
| 158 |
+
"metadata": {
|
| 159 |
+
"version": 1.0,
|
| 160 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 161 |
+
"dtype": "bfloat16",
|
| 162 |
+
"trust_remote_code": true
|
| 163 |
+
}
|
| 164 |
+
},
|
| 165 |
+
"boolq": {
|
| 166 |
+
"task": "boolq",
|
| 167 |
+
"tag": [
|
| 168 |
+
"super-glue-lm-eval-v1"
|
| 169 |
+
],
|
| 170 |
+
"dataset_path": "super_glue",
|
| 171 |
+
"dataset_name": "boolq",
|
| 172 |
+
"training_split": "train",
|
| 173 |
+
"validation_split": "validation",
|
| 174 |
+
"doc_to_text": "{{passage}}\nQuestion: {{question}}?\nAnswer:",
|
| 175 |
+
"doc_to_target": "label",
|
| 176 |
+
"unsafe_code": false,
|
| 177 |
+
"doc_to_choice": [
|
| 178 |
+
"no",
|
| 179 |
+
"yes"
|
| 180 |
+
],
|
| 181 |
+
"description": "",
|
| 182 |
+
"target_delimiter": " ",
|
| 183 |
+
"fewshot_delimiter": "\n\n",
|
| 184 |
+
"num_fewshot": 0,
|
| 185 |
+
"metric_list": [
|
| 186 |
+
{
|
| 187 |
+
"metric": "acc"
|
| 188 |
+
}
|
| 189 |
+
],
|
| 190 |
+
"output_type": "multiple_choice",
|
| 191 |
+
"repeats": 1,
|
| 192 |
+
"should_decontaminate": true,
|
| 193 |
+
"doc_to_decontamination_query": "passage",
|
| 194 |
+
"metadata": {
|
| 195 |
+
"version": 2.0,
|
| 196 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 197 |
+
"dtype": "bfloat16",
|
| 198 |
+
"trust_remote_code": true
|
| 199 |
+
}
|
| 200 |
+
},
|
| 201 |
+
"hellaswag": {
|
| 202 |
+
"task": "hellaswag",
|
| 203 |
+
"tag": [
|
| 204 |
+
"multiple_choice"
|
| 205 |
+
],
|
| 206 |
+
"dataset_path": "Rowan/hellaswag",
|
| 207 |
+
"training_split": "train",
|
| 208 |
+
"validation_split": "validation",
|
| 209 |
+
"process_docs": "def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:\n def _process_doc(doc):\n ctx = doc[\"ctx_a\"] + \" \" + doc[\"ctx_b\"].capitalize()\n out_doc = {\n \"query\": preprocess(doc[\"activity_label\"] + \": \" + ctx),\n \"choices\": [preprocess(ending) for ending in doc[\"endings\"]],\n \"gold\": int(doc[\"label\"]),\n }\n return out_doc\n\n return dataset.map(_process_doc)\n",
|
| 210 |
+
"doc_to_text": "{{query}}",
|
| 211 |
+
"doc_to_target": "{{label}}",
|
| 212 |
+
"unsafe_code": false,
|
| 213 |
+
"doc_to_choice": "choices",
|
| 214 |
+
"description": "",
|
| 215 |
+
"target_delimiter": " ",
|
| 216 |
+
"fewshot_delimiter": "\n\n",
|
| 217 |
+
"num_fewshot": 0,
|
| 218 |
+
"metric_list": [
|
| 219 |
+
{
|
| 220 |
+
"metric": "acc",
|
| 221 |
+
"aggregation": "mean",
|
| 222 |
+
"higher_is_better": true
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"metric": "acc_norm",
|
| 226 |
+
"aggregation": "mean",
|
| 227 |
+
"higher_is_better": true
|
| 228 |
+
}
|
| 229 |
+
],
|
| 230 |
+
"output_type": "multiple_choice",
|
| 231 |
+
"repeats": 1,
|
| 232 |
+
"should_decontaminate": false,
|
| 233 |
+
"metadata": {
|
| 234 |
+
"version": 1.0,
|
| 235 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 236 |
+
"dtype": "bfloat16",
|
| 237 |
+
"trust_remote_code": true
|
| 238 |
+
}
|
| 239 |
+
},
|
| 240 |
+
"lambada_openai": {
|
| 241 |
+
"task": "lambada_openai",
|
| 242 |
+
"tag": [
|
| 243 |
+
"lambada"
|
| 244 |
+
],
|
| 245 |
+
"dataset_path": "EleutherAI/lambada_openai",
|
| 246 |
+
"dataset_name": "default",
|
| 247 |
+
"test_split": "test",
|
| 248 |
+
"doc_to_text": "{{text.split(' ')[:-1]|join(' ')}}",
|
| 249 |
+
"doc_to_target": "{{' '+text.split(' ')[-1]}}",
|
| 250 |
+
"unsafe_code": false,
|
| 251 |
+
"description": "",
|
| 252 |
+
"target_delimiter": " ",
|
| 253 |
+
"fewshot_delimiter": "\n\n",
|
| 254 |
+
"num_fewshot": 0,
|
| 255 |
+
"metric_list": [
|
| 256 |
+
{
|
| 257 |
+
"metric": "perplexity",
|
| 258 |
+
"aggregation": "perplexity",
|
| 259 |
+
"higher_is_better": false
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"metric": "acc",
|
| 263 |
+
"aggregation": "mean",
|
| 264 |
+
"higher_is_better": true
|
| 265 |
+
}
|
| 266 |
+
],
|
| 267 |
+
"output_type": "loglikelihood",
|
| 268 |
+
"repeats": 1,
|
| 269 |
+
"should_decontaminate": true,
|
| 270 |
+
"doc_to_decontamination_query": "{{text}}",
|
| 271 |
+
"metadata": {
|
| 272 |
+
"version": 1.0,
|
| 273 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 274 |
+
"dtype": "bfloat16",
|
| 275 |
+
"trust_remote_code": true
|
| 276 |
+
}
|
| 277 |
+
},
|
| 278 |
+
"lambada_standard": {
|
| 279 |
+
"task": "lambada_standard",
|
| 280 |
+
"tag": [
|
| 281 |
+
"lambada"
|
| 282 |
+
],
|
| 283 |
+
"dataset_path": "lambada",
|
| 284 |
+
"validation_split": "validation",
|
| 285 |
+
"test_split": "test",
|
| 286 |
+
"doc_to_text": "{{text.split(' ')[:-1]|join(' ')}}",
|
| 287 |
+
"doc_to_target": "{{' '+text.split(' ')[-1]}}",
|
| 288 |
+
"unsafe_code": false,
|
| 289 |
+
"description": "",
|
| 290 |
+
"target_delimiter": " ",
|
| 291 |
+
"fewshot_delimiter": "\n\n",
|
| 292 |
+
"num_fewshot": 0,
|
| 293 |
+
"metric_list": [
|
| 294 |
+
{
|
| 295 |
+
"metric": "perplexity",
|
| 296 |
+
"aggregation": "perplexity",
|
| 297 |
+
"higher_is_better": false
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"metric": "acc",
|
| 301 |
+
"aggregation": "mean",
|
| 302 |
+
"higher_is_better": true
|
| 303 |
+
}
|
| 304 |
+
],
|
| 305 |
+
"output_type": "loglikelihood",
|
| 306 |
+
"repeats": 1,
|
| 307 |
+
"should_decontaminate": true,
|
| 308 |
+
"doc_to_decontamination_query": "{{text}}",
|
| 309 |
+
"metadata": {
|
| 310 |
+
"version": 1.0,
|
| 311 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 312 |
+
"dtype": "bfloat16",
|
| 313 |
+
"trust_remote_code": true
|
| 314 |
+
}
|
| 315 |
+
},
|
| 316 |
+
"piqa": {
|
| 317 |
+
"task": "piqa",
|
| 318 |
+
"dataset_path": "baber/piqa",
|
| 319 |
+
"training_split": "train",
|
| 320 |
+
"validation_split": "validation",
|
| 321 |
+
"doc_to_text": "Question: {{goal}}\nAnswer:",
|
| 322 |
+
"doc_to_target": "label",
|
| 323 |
+
"unsafe_code": false,
|
| 324 |
+
"doc_to_choice": "{{[sol1, sol2]}}",
|
| 325 |
+
"description": "",
|
| 326 |
+
"target_delimiter": " ",
|
| 327 |
+
"fewshot_delimiter": "\n\n",
|
| 328 |
+
"num_fewshot": 0,
|
| 329 |
+
"metric_list": [
|
| 330 |
+
{
|
| 331 |
+
"metric": "acc",
|
| 332 |
+
"aggregation": "mean",
|
| 333 |
+
"higher_is_better": true
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"metric": "acc_norm",
|
| 337 |
+
"aggregation": "mean",
|
| 338 |
+
"higher_is_better": true
|
| 339 |
+
}
|
| 340 |
+
],
|
| 341 |
+
"output_type": "multiple_choice",
|
| 342 |
+
"repeats": 1,
|
| 343 |
+
"should_decontaminate": true,
|
| 344 |
+
"doc_to_decontamination_query": "goal",
|
| 345 |
+
"metadata": {
|
| 346 |
+
"version": 1.0,
|
| 347 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 348 |
+
"dtype": "bfloat16",
|
| 349 |
+
"trust_remote_code": true
|
| 350 |
+
}
|
| 351 |
+
},
|
| 352 |
+
"social_iqa": {
|
| 353 |
+
"task": "social_iqa",
|
| 354 |
+
"dataset_path": "social_i_qa",
|
| 355 |
+
"training_split": "train",
|
| 356 |
+
"validation_split": "validation",
|
| 357 |
+
"doc_to_text": "Q: {{context}} {{question}}\nA:",
|
| 358 |
+
"doc_to_target": "{{ (label|int) - 1 }}",
|
| 359 |
+
"unsafe_code": false,
|
| 360 |
+
"doc_to_choice": "{{[answerA, answerB, answerC]}}",
|
| 361 |
+
"description": "",
|
| 362 |
+
"target_delimiter": " ",
|
| 363 |
+
"fewshot_delimiter": "\n\n",
|
| 364 |
+
"num_fewshot": 0,
|
| 365 |
+
"metric_list": [
|
| 366 |
+
{
|
| 367 |
+
"metric": "acc",
|
| 368 |
+
"aggregation": "mean",
|
| 369 |
+
"higher_is_better": true
|
| 370 |
+
}
|
| 371 |
+
],
|
| 372 |
+
"output_type": "multiple_choice",
|
| 373 |
+
"repeats": 1,
|
| 374 |
+
"should_decontaminate": false,
|
| 375 |
+
"metadata": {
|
| 376 |
+
"version": 0.0,
|
| 377 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 378 |
+
"dtype": "bfloat16",
|
| 379 |
+
"trust_remote_code": true
|
| 380 |
+
}
|
| 381 |
+
},
|
| 382 |
+
"wikitext": {
|
| 383 |
+
"task": "wikitext",
|
| 384 |
+
"dataset_path": "EleutherAI/wikitext_document_level",
|
| 385 |
+
"dataset_name": "wikitext-2-raw-v1",
|
| 386 |
+
"training_split": "train",
|
| 387 |
+
"validation_split": "validation",
|
| 388 |
+
"test_split": "test",
|
| 389 |
+
"doc_to_text": "",
|
| 390 |
+
"doc_to_target": "def wikitext_detokenizer(doc):\n string = doc[\"page\"]\n # contractions\n string = string.replace(\"s '\", \"s'\")\n string = re.sub(r\"/' [0-9]/\", r\"/'[0-9]/\", string)\n # number separators\n string = string.replace(\" @-@ \", \"-\")\n string = string.replace(\" @,@ \", \",\")\n string = string.replace(\" @.@ \", \".\")\n # punctuation\n string = string.replace(\" : \", \": \")\n string = string.replace(\" ; \", \"; \")\n string = string.replace(\" . \", \". \")\n string = string.replace(\" ! \", \"! \")\n string = string.replace(\" ? \", \"? \")\n string = string.replace(\" , \", \", \")\n # double brackets\n string = re.sub(r\"\\(\\s*([^\\)]*?)\\s*\\)\", r\"(\\1)\", string)\n string = re.sub(r\"\\[\\s*([^\\]]*?)\\s*\\]\", r\"[\\1]\", string)\n string = re.sub(r\"{\\s*([^}]*?)\\s*}\", r\"{\\1}\", string)\n string = re.sub(r\"\\\"\\s*([^\\\"]*?)\\s*\\\"\", r'\"\\1\"', string)\n string = re.sub(r\"'\\s*([^']*?)\\s*'\", r\"'\\1'\", string)\n # miscellaneous\n string = string.replace(\"= = = =\", \"====\")\n string = string.replace(\"= = =\", \"===\")\n string = string.replace(\"= =\", \"==\")\n string = string.replace(\" \" + chr(176) + \" \", chr(176))\n string = string.replace(\" \\n\", \"\\n\")\n string = string.replace(\"\\n \", \"\\n\")\n string = string.replace(\" N \", \" 1 \")\n string = string.replace(\" 's\", \"'s\")\n\n return string\n",
|
| 391 |
+
"unsafe_code": false,
|
| 392 |
+
"process_results": "def process_results(doc, results):\n (loglikelihood,) = results\n # IMPORTANT: wikitext counts number of words in *original doc before detokenization*\n _words = len(re.split(r\"\\s+\", doc[\"page\"]))\n _bytes = len(doc[\"page\"].encode(\"utf-8\"))\n return {\n \"word_perplexity\": (loglikelihood, _words),\n \"byte_perplexity\": (loglikelihood, _bytes),\n \"bits_per_byte\": (loglikelihood, _bytes),\n }\n",
|
| 393 |
+
"description": "",
|
| 394 |
+
"target_delimiter": " ",
|
| 395 |
+
"fewshot_delimiter": "\n\n",
|
| 396 |
+
"num_fewshot": 0,
|
| 397 |
+
"metric_list": [
|
| 398 |
+
{
|
| 399 |
+
"metric": "word_perplexity"
|
| 400 |
+
},
|
| 401 |
+
{
|
| 402 |
+
"metric": "byte_perplexity"
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"metric": "bits_per_byte"
|
| 406 |
+
}
|
| 407 |
+
],
|
| 408 |
+
"output_type": "loglikelihood_rolling",
|
| 409 |
+
"repeats": 1,
|
| 410 |
+
"should_decontaminate": true,
|
| 411 |
+
"doc_to_decontamination_query": "{{page}}",
|
| 412 |
+
"metadata": {
|
| 413 |
+
"version": 2.0,
|
| 414 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 415 |
+
"dtype": "bfloat16",
|
| 416 |
+
"trust_remote_code": true
|
| 417 |
+
}
|
| 418 |
+
},
|
| 419 |
+
"winogrande": {
|
| 420 |
+
"task": "winogrande",
|
| 421 |
+
"dataset_path": "winogrande",
|
| 422 |
+
"dataset_name": "winogrande_xl",
|
| 423 |
+
"training_split": "train",
|
| 424 |
+
"validation_split": "validation",
|
| 425 |
+
"doc_to_text": "def doc_to_text(doc):\n answer_to_num = {\"1\": 0, \"2\": 1}\n return answer_to_num[doc[\"answer\"]]\n",
|
| 426 |
+
"doc_to_target": "def doc_to_target(doc):\n idx = doc[\"sentence\"].index(\"_\") + 1\n return doc[\"sentence\"][idx:].strip()\n",
|
| 427 |
+
"unsafe_code": false,
|
| 428 |
+
"doc_to_choice": "def doc_to_choice(doc):\n idx = doc[\"sentence\"].index(\"_\")\n options = [doc[\"option1\"], doc[\"option2\"]]\n return [doc[\"sentence\"][:idx] + opt for opt in options]\n",
|
| 429 |
+
"description": "",
|
| 430 |
+
"target_delimiter": " ",
|
| 431 |
+
"fewshot_delimiter": "\n\n",
|
| 432 |
+
"num_fewshot": 0,
|
| 433 |
+
"metric_list": [
|
| 434 |
+
{
|
| 435 |
+
"metric": "acc",
|
| 436 |
+
"aggregation": "mean",
|
| 437 |
+
"higher_is_better": true
|
| 438 |
+
}
|
| 439 |
+
],
|
| 440 |
+
"output_type": "multiple_choice",
|
| 441 |
+
"repeats": 1,
|
| 442 |
+
"should_decontaminate": true,
|
| 443 |
+
"doc_to_decontamination_query": "sentence",
|
| 444 |
+
"metadata": {
|
| 445 |
+
"version": 1.0,
|
| 446 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 447 |
+
"dtype": "bfloat16",
|
| 448 |
+
"trust_remote_code": true
|
| 449 |
+
}
|
| 450 |
+
}
|
| 451 |
+
},
|
| 452 |
+
"versions": {
|
| 453 |
+
"arc_challenge": 1.0,
|
| 454 |
+
"arc_easy": 1.0,
|
| 455 |
+
"boolq": 2.0,
|
| 456 |
+
"hellaswag": 1.0,
|
| 457 |
+
"lambada_openai": 1.0,
|
| 458 |
+
"lambada_standard": 1.0,
|
| 459 |
+
"piqa": 1.0,
|
| 460 |
+
"social_iqa": 0.0,
|
| 461 |
+
"wikitext": 2.0,
|
| 462 |
+
"winogrande": 1.0
|
| 463 |
+
},
|
| 464 |
+
"n-shot": {
|
| 465 |
+
"arc_challenge": 0,
|
| 466 |
+
"arc_easy": 0,
|
| 467 |
+
"boolq": 0,
|
| 468 |
+
"hellaswag": 0,
|
| 469 |
+
"lambada_openai": 0,
|
| 470 |
+
"lambada_standard": 0,
|
| 471 |
+
"piqa": 0,
|
| 472 |
+
"social_iqa": 0,
|
| 473 |
+
"wikitext": 0,
|
| 474 |
+
"winogrande": 0
|
| 475 |
+
},
|
| 476 |
+
"higher_is_better": {
|
| 477 |
+
"arc_challenge": {
|
| 478 |
+
"acc": true,
|
| 479 |
+
"acc_norm": true
|
| 480 |
+
},
|
| 481 |
+
"arc_easy": {
|
| 482 |
+
"acc": true,
|
| 483 |
+
"acc_norm": true
|
| 484 |
+
},
|
| 485 |
+
"boolq": {
|
| 486 |
+
"acc": true
|
| 487 |
+
},
|
| 488 |
+
"hellaswag": {
|
| 489 |
+
"acc": true,
|
| 490 |
+
"acc_norm": true
|
| 491 |
+
},
|
| 492 |
+
"lambada_openai": {
|
| 493 |
+
"perplexity": false,
|
| 494 |
+
"acc": true
|
| 495 |
+
},
|
| 496 |
+
"lambada_standard": {
|
| 497 |
+
"perplexity": false,
|
| 498 |
+
"acc": true
|
| 499 |
+
},
|
| 500 |
+
"piqa": {
|
| 501 |
+
"acc": true,
|
| 502 |
+
"acc_norm": true
|
| 503 |
+
},
|
| 504 |
+
"social_iqa": {
|
| 505 |
+
"acc": true
|
| 506 |
+
},
|
| 507 |
+
"wikitext": {
|
| 508 |
+
"word_perplexity": false,
|
| 509 |
+
"byte_perplexity": false,
|
| 510 |
+
"bits_per_byte": false
|
| 511 |
+
},
|
| 512 |
+
"winogrande": {
|
| 513 |
+
"acc": true
|
| 514 |
+
}
|
| 515 |
+
},
|
| 516 |
+
"n-samples": {
|
| 517 |
+
"winogrande": {
|
| 518 |
+
"original": 1267,
|
| 519 |
+
"effective": 1267
|
| 520 |
+
},
|
| 521 |
+
"wikitext": {
|
| 522 |
+
"original": 62,
|
| 523 |
+
"effective": 62
|
| 524 |
+
},
|
| 525 |
+
"social_iqa": {
|
| 526 |
+
"original": 1954,
|
| 527 |
+
"effective": 1954
|
| 528 |
+
},
|
| 529 |
+
"piqa": {
|
| 530 |
+
"original": 1838,
|
| 531 |
+
"effective": 1838
|
| 532 |
+
},
|
| 533 |
+
"lambada_standard": {
|
| 534 |
+
"original": 5153,
|
| 535 |
+
"effective": 5153
|
| 536 |
+
},
|
| 537 |
+
"lambada_openai": {
|
| 538 |
+
"original": 5153,
|
| 539 |
+
"effective": 5153
|
| 540 |
+
},
|
| 541 |
+
"hellaswag": {
|
| 542 |
+
"original": 10042,
|
| 543 |
+
"effective": 10042
|
| 544 |
+
},
|
| 545 |
+
"boolq": {
|
| 546 |
+
"original": 3270,
|
| 547 |
+
"effective": 3270
|
| 548 |
+
},
|
| 549 |
+
"arc_easy": {
|
| 550 |
+
"original": 2376,
|
| 551 |
+
"effective": 2376
|
| 552 |
+
},
|
| 553 |
+
"arc_challenge": {
|
| 554 |
+
"original": 1172,
|
| 555 |
+
"effective": 1172
|
| 556 |
+
}
|
| 557 |
+
},
|
| 558 |
+
"config": {
|
| 559 |
+
"model": "hf",
|
| 560 |
+
"model_args": "pretrained=results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/,dtype=bfloat16,trust_remote_code=True,trust_remote_code=True",
|
| 561 |
+
"model_num_parameters": 392747259,
|
| 562 |
+
"model_dtype": "torch.bfloat16",
|
| 563 |
+
"model_revision": "main",
|
| 564 |
+
"model_sha": "",
|
| 565 |
+
"batch_size": "8",
|
| 566 |
+
"batch_sizes": [],
|
| 567 |
+
"device": "cuda:0",
|
| 568 |
+
"use_cache": null,
|
| 569 |
+
"limit": null,
|
| 570 |
+
"bootstrap_iters": 100000,
|
| 571 |
+
"gen_kwargs": null,
|
| 572 |
+
"random_seed": 0,
|
| 573 |
+
"numpy_seed": 1234,
|
| 574 |
+
"torch_seed": 1234,
|
| 575 |
+
"fewshot_seed": 1234
|
| 576 |
+
},
|
| 577 |
+
"git_hash": "core_v0.12.0-111-g418d5cb59",
|
| 578 |
+
"date": 1769150810.054304,
|
| 579 |
+
"pretty_env_info": "PyTorch version: 2.6.0+cu124\nIs debug build: False\nCUDA used to build PyTorch: 12.4\nROCM used to build PyTorch: N/A\n\nOS: Ubuntu 22.04.4 LTS (x86_64)\nGCC version: (conda-forge gcc 9.5.0-19) 9.5.0\nClang version: Could not collect\nCMake version: version 3.30.1\nLibc version: glibc-2.35\n\nPython version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)\nPython platform: Linux-6.5.0-18-generic-x86_64-with-glibc2.35\nIs CUDA available: True\nCUDA runtime version: 12.4.131\nCUDA_MODULE_LOADING set to: LAZY\nGPU models and configuration: \nGPU 0: NVIDIA A800-SXM4-80GB\nGPU 1: NVIDIA A800-SXM4-80GB\nGPU 2: NVIDIA A800-SXM4-80GB\nGPU 3: NVIDIA A800-SXM4-80GB\nGPU 4: NVIDIA A800-SXM4-80GB\nGPU 5: NVIDIA A800-SXM4-80GB\nGPU 6: NVIDIA A800-SXM4-80GB\nGPU 7: NVIDIA A800-SXM4-80GB\n\nNvidia driver version: 550.54.15\ncuDNN version: Could not collect\nHIP runtime version: N/A\nMIOpen runtime version: N/A\nIs XNNPACK available: True\n\nCPU:\nArchitecture: x86_64\nCPU op-mode(s): 32-bit, 64-bit\nAddress sizes: 52 bits physical, 57 bits virtual\nByte Order: Little Endian\nCPU(s): 104\nOn-line CPU(s) list: 0-103\nVendor ID: GenuineIntel\nModel name: Intel(R) Xeon(R) Platinum 8470\nCPU family: 6\nModel: 143\nThread(s) per core: 1\nCore(s) per socket: 52\nSocket(s): 2\nStepping: 8\nCPU max MHz: 3800.0000\nCPU min MHz: 800.0000\nBogoMIPS: 4000.00\nFlags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities\nVirtualization: VT-x\nL1d cache: 4.9 MiB (104 instances)\nL1i cache: 3.3 MiB (104 instances)\nL2 cache: 208 MiB (104 instances)\nL3 cache: 210 MiB (2 instances)\nNUMA node(s): 2\nNUMA node0 CPU(s): 0-51\nNUMA node1 CPU(s): 52-103\nVulnerability Gather data sampling: Not affected\nVulnerability Itlb multihit: Not affected\nVulnerability L1tf: Not affected\nVulnerability Mds: Not affected\nVulnerability Meltdown: Not affected\nVulnerability Mmio stale data: Not affected\nVulnerability Retbleed: Not affected\nVulnerability Spec rstack overflow: Not affected\nVulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl\nVulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization\nVulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence\nVulnerability Srbds: Not affected\nVulnerability Tsx async abort: Not affected\n\nVersions of relevant libraries:\n[pip3] numpy==1.26.4\n[pip3] nvidia-cublas-cu12==12.4.5.8\n[pip3] nvidia-cuda-cupti-cu12==12.4.127\n[pip3] nvidia-cuda-nvrtc-cu12==12.4.127\n[pip3] nvidia-cuda-runtime-cu12==12.4.127\n[pip3] nvidia-cudnn-cu12==9.1.0.70\n[pip3] nvidia-cufft-cu12==11.2.1.3\n[pip3] nvidia-curand-cu12==10.3.5.147\n[pip3] nvidia-cusolver-cu12==11.6.1.9\n[pip3] nvidia-cusparse-cu12==12.3.1.170\n[pip3] nvidia-cusparselt-cu12==0.6.2\n[pip3] nvidia-nccl-cu11==2.21.5\n[pip3] nvidia-nccl-cu12==2.21.5\n[pip3] nvidia-nvjitlink-cu12==12.4.127\n[pip3] nvidia-nvtx-cu12==12.4.127\n[pip3] torch==2.6.0\n[pip3] torchaudio==2.6.0\n[pip3] torchdata==0.11.0\n[pip3] torchvision==0.21.0\n[pip3] triton==3.2.0\n[conda] cuda-cudart 12.4.99 hd3aeb46_0 conda-forge\n[conda] cuda-cudart_linux-64 12.4.99 h59595ed_0 conda-forge\n[conda] cuda-cupti 12.4.127 he02047a_2 conda-forge\n[conda] cuda-libraries 12.4.0 ha770c72_0 conda-forge\n[conda] cuda-nvrtc 12.4.99 hd3aeb46_0 conda-forge\n[conda] cuda-nvtx 12.4.127 he02047a_2 conda-forge\n[conda] cuda-opencl 12.4.99 h59595ed_0 conda-forge\n[conda] cuda-runtime 12.4.0 ha804496_0 conda-forge\n[conda] ffmpeg 4.3 hf484d3e_0 pytorch\n[conda] libcublas 12.4.2.65 hd3aeb46_0 conda-forge\n[conda] libcufft 11.2.0.44 hd3aeb46_0 conda-forge\n[conda] libcurand 10.3.5.119 hd3aeb46_0 conda-forge\n[conda] libcusolver 11.6.0.99 hd3aeb46_0 conda-forge\n[conda] libcusparse 12.3.0.142 hd3aeb46_0 conda-forge\n[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch\n[conda] libnvjitlink 12.4.99 hd3aeb46_0 conda-forge\n[conda] mkl 2023.1.0 h213fc3f_46344 defaults\n[conda] numpy 1.26.4 pypi_0 pypi\n[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi\n[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi\n[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi\n[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi\n[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi\n[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi\n[conda] nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi\n[conda] nvidia-nccl-cu11 2.21.5 pypi_0 pypi\n[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi\n[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi\n[conda] pytorch-cuda 12.4 hc786d27_6 pytorch\n[conda] pytorch-mutex 1.0 cuda pytorch\n[conda] torch 2.6.0 pypi_0 pypi\n[conda] torchaudio 2.6.0 pypi_0 pypi\n[conda] torchdata 0.11.0 pypi_0 pypi\n[conda] torchvision 0.21.0 pypi_0 pypi\n[conda] triton 3.2.0 pypi_0 pypi",
|
| 580 |
+
"transformers_version": "4.55.2",
|
| 581 |
+
"lm_eval_version": "0.4.9.1",
|
| 582 |
+
"upper_git_hash": null,
|
| 583 |
+
"tokenizer_pad_token": [
|
| 584 |
+
"<unk>",
|
| 585 |
+
"0"
|
| 586 |
+
],
|
| 587 |
+
"tokenizer_eos_token": [
|
| 588 |
+
"<|im_end|>",
|
| 589 |
+
"73440"
|
| 590 |
+
],
|
| 591 |
+
"tokenizer_bos_token": [
|
| 592 |
+
"<s>",
|
| 593 |
+
"1"
|
| 594 |
+
],
|
| 595 |
+
"eot_token_id": 73440,
|
| 596 |
+
"max_length": 4096,
|
| 597 |
+
"task_hashes": {},
|
| 598 |
+
"model_source": "hf",
|
| 599 |
+
"model_name": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 600 |
+
"model_name_sanitized": "results__hf_ckpts__blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128__",
|
| 601 |
+
"system_instruction": null,
|
| 602 |
+
"system_instruction_sha": null,
|
| 603 |
+
"fewshot_as_multiturn": false,
|
| 604 |
+
"chat_template": null,
|
| 605 |
+
"chat_template_sha": null,
|
| 606 |
+
"start_time": 684014.828182741,
|
| 607 |
+
"end_time": 684465.591942707,
|
| 608 |
+
"total_evaluation_time_seconds": "450.76375996600837"
|
| 609 |
+
}
|
evaluation/results__hf_ckpts__blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128__/results_2026-03-30T20-55-02.181492.json
ADDED
|
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"arc_challenge": {
|
| 4 |
+
"alias": "arc_challenge",
|
| 5 |
+
"acc,none": 0.20733788395904437,
|
| 6 |
+
"acc_stderr,none": 0.011846905782971337,
|
| 7 |
+
"acc_norm,none": 0.24914675767918087,
|
| 8 |
+
"acc_norm_stderr,none": 0.012639407111926433
|
| 9 |
+
},
|
| 10 |
+
"arc_easy": {
|
| 11 |
+
"alias": "arc_easy",
|
| 12 |
+
"acc,none": 0.5197811447811448,
|
| 13 |
+
"acc_stderr,none": 0.010251751199542735,
|
| 14 |
+
"acc_norm,none": 0.44907407407407407,
|
| 15 |
+
"acc_norm_stderr,none": 0.010206428316323365
|
| 16 |
+
},
|
| 17 |
+
"boolq": {
|
| 18 |
+
"alias": "boolq",
|
| 19 |
+
"acc,none": 0.6113149847094801,
|
| 20 |
+
"acc_stderr,none": 0.008525580498982967
|
| 21 |
+
},
|
| 22 |
+
"hellaswag": {
|
| 23 |
+
"alias": "hellaswag",
|
| 24 |
+
"acc,none": 0.29466241784505076,
|
| 25 |
+
"acc_stderr,none": 0.004549591490046219,
|
| 26 |
+
"acc_norm,none": 0.3293168691495718,
|
| 27 |
+
"acc_norm_stderr,none": 0.004690047021719822
|
| 28 |
+
},
|
| 29 |
+
"lambada_openai": {
|
| 30 |
+
"alias": "lambada_openai",
|
| 31 |
+
"perplexity,none": 66.63752818011969,
|
| 32 |
+
"perplexity_stderr,none": 2.753797550415125,
|
| 33 |
+
"acc,none": 0.27653793906462254,
|
| 34 |
+
"acc_stderr,none": 0.006231567654090107
|
| 35 |
+
},
|
| 36 |
+
"lambada_standard": {
|
| 37 |
+
"alias": "lambada_standard",
|
| 38 |
+
"perplexity,none": 198.31426320573388,
|
| 39 |
+
"perplexity_stderr,none": 8.56348322119257,
|
| 40 |
+
"acc,none": 0.19483795847079372,
|
| 41 |
+
"acc_stderr,none": 0.005518111913121867
|
| 42 |
+
},
|
| 43 |
+
"piqa": {
|
| 44 |
+
"alias": "piqa",
|
| 45 |
+
"acc,none": 0.6436343852013058,
|
| 46 |
+
"acc_stderr,none": 0.011174109865864717,
|
| 47 |
+
"acc_norm,none": 0.6131664853101197,
|
| 48 |
+
"acc_norm_stderr,none": 0.011363095931902848
|
| 49 |
+
},
|
| 50 |
+
"social_iqa": {
|
| 51 |
+
"alias": "social_iqa",
|
| 52 |
+
"acc,none": 0.3618219037871034,
|
| 53 |
+
"acc_stderr,none": 0.010873447266941618
|
| 54 |
+
},
|
| 55 |
+
"wikitext": {
|
| 56 |
+
"alias": "wikitext",
|
| 57 |
+
"word_perplexity,none": 39.00052118822888,
|
| 58 |
+
"word_perplexity_stderr,none": "N/A",
|
| 59 |
+
"byte_perplexity,none": 1.9839837931891766,
|
| 60 |
+
"byte_perplexity_stderr,none": "N/A",
|
| 61 |
+
"bits_per_byte,none": 0.9884002406536678,
|
| 62 |
+
"bits_per_byte_stderr,none": "N/A"
|
| 63 |
+
},
|
| 64 |
+
"winogrande": {
|
| 65 |
+
"alias": "winogrande",
|
| 66 |
+
"acc,none": 0.5209155485398579,
|
| 67 |
+
"acc_stderr,none": 0.014040185494212945
|
| 68 |
+
}
|
| 69 |
+
},
|
| 70 |
+
"group_subtasks": {
|
| 71 |
+
"arc_challenge": [],
|
| 72 |
+
"arc_easy": [],
|
| 73 |
+
"boolq": [],
|
| 74 |
+
"hellaswag": [],
|
| 75 |
+
"lambada_openai": [],
|
| 76 |
+
"lambada_standard": [],
|
| 77 |
+
"piqa": [],
|
| 78 |
+
"social_iqa": [],
|
| 79 |
+
"wikitext": [],
|
| 80 |
+
"winogrande": []
|
| 81 |
+
},
|
| 82 |
+
"configs": {
|
| 83 |
+
"arc_challenge": {
|
| 84 |
+
"task": "arc_challenge",
|
| 85 |
+
"tag": [
|
| 86 |
+
"ai2_arc"
|
| 87 |
+
],
|
| 88 |
+
"dataset_path": "allenai/ai2_arc",
|
| 89 |
+
"dataset_name": "ARC-Challenge",
|
| 90 |
+
"training_split": "train",
|
| 91 |
+
"validation_split": "validation",
|
| 92 |
+
"test_split": "test",
|
| 93 |
+
"doc_to_text": "Question: {{question}}\nAnswer:",
|
| 94 |
+
"doc_to_target": "{{choices.label.index(answerKey)}}",
|
| 95 |
+
"unsafe_code": false,
|
| 96 |
+
"doc_to_choice": "{{choices.text}}",
|
| 97 |
+
"description": "",
|
| 98 |
+
"target_delimiter": " ",
|
| 99 |
+
"fewshot_delimiter": "\n\n",
|
| 100 |
+
"num_fewshot": 0,
|
| 101 |
+
"metric_list": [
|
| 102 |
+
{
|
| 103 |
+
"metric": "acc",
|
| 104 |
+
"aggregation": "mean",
|
| 105 |
+
"higher_is_better": true
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"metric": "acc_norm",
|
| 109 |
+
"aggregation": "mean",
|
| 110 |
+
"higher_is_better": true
|
| 111 |
+
}
|
| 112 |
+
],
|
| 113 |
+
"output_type": "multiple_choice",
|
| 114 |
+
"repeats": 1,
|
| 115 |
+
"should_decontaminate": true,
|
| 116 |
+
"doc_to_decontamination_query": "Question: {{question}}\nAnswer:",
|
| 117 |
+
"metadata": {
|
| 118 |
+
"version": 1.0,
|
| 119 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 120 |
+
"dtype": "bfloat16",
|
| 121 |
+
"trust_remote_code": true
|
| 122 |
+
}
|
| 123 |
+
},
|
| 124 |
+
"arc_easy": {
|
| 125 |
+
"task": "arc_easy",
|
| 126 |
+
"tag": [
|
| 127 |
+
"ai2_arc"
|
| 128 |
+
],
|
| 129 |
+
"dataset_path": "allenai/ai2_arc",
|
| 130 |
+
"dataset_name": "ARC-Easy",
|
| 131 |
+
"training_split": "train",
|
| 132 |
+
"validation_split": "validation",
|
| 133 |
+
"test_split": "test",
|
| 134 |
+
"doc_to_text": "Question: {{question}}\nAnswer:",
|
| 135 |
+
"doc_to_target": "{{choices.label.index(answerKey)}}",
|
| 136 |
+
"unsafe_code": false,
|
| 137 |
+
"doc_to_choice": "{{choices.text}}",
|
| 138 |
+
"description": "",
|
| 139 |
+
"target_delimiter": " ",
|
| 140 |
+
"fewshot_delimiter": "\n\n",
|
| 141 |
+
"num_fewshot": 0,
|
| 142 |
+
"metric_list": [
|
| 143 |
+
{
|
| 144 |
+
"metric": "acc",
|
| 145 |
+
"aggregation": "mean",
|
| 146 |
+
"higher_is_better": true
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"metric": "acc_norm",
|
| 150 |
+
"aggregation": "mean",
|
| 151 |
+
"higher_is_better": true
|
| 152 |
+
}
|
| 153 |
+
],
|
| 154 |
+
"output_type": "multiple_choice",
|
| 155 |
+
"repeats": 1,
|
| 156 |
+
"should_decontaminate": true,
|
| 157 |
+
"doc_to_decontamination_query": "Question: {{question}}\nAnswer:",
|
| 158 |
+
"metadata": {
|
| 159 |
+
"version": 1.0,
|
| 160 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 161 |
+
"dtype": "bfloat16",
|
| 162 |
+
"trust_remote_code": true
|
| 163 |
+
}
|
| 164 |
+
},
|
| 165 |
+
"boolq": {
|
| 166 |
+
"task": "boolq",
|
| 167 |
+
"tag": [
|
| 168 |
+
"super-glue-lm-eval-v1"
|
| 169 |
+
],
|
| 170 |
+
"dataset_path": "super_glue",
|
| 171 |
+
"dataset_name": "boolq",
|
| 172 |
+
"training_split": "train",
|
| 173 |
+
"validation_split": "validation",
|
| 174 |
+
"doc_to_text": "{{passage}}\nQuestion: {{question}}?\nAnswer:",
|
| 175 |
+
"doc_to_target": "label",
|
| 176 |
+
"unsafe_code": false,
|
| 177 |
+
"doc_to_choice": [
|
| 178 |
+
"no",
|
| 179 |
+
"yes"
|
| 180 |
+
],
|
| 181 |
+
"description": "",
|
| 182 |
+
"target_delimiter": " ",
|
| 183 |
+
"fewshot_delimiter": "\n\n",
|
| 184 |
+
"num_fewshot": 0,
|
| 185 |
+
"metric_list": [
|
| 186 |
+
{
|
| 187 |
+
"metric": "acc"
|
| 188 |
+
}
|
| 189 |
+
],
|
| 190 |
+
"output_type": "multiple_choice",
|
| 191 |
+
"repeats": 1,
|
| 192 |
+
"should_decontaminate": true,
|
| 193 |
+
"doc_to_decontamination_query": "passage",
|
| 194 |
+
"metadata": {
|
| 195 |
+
"version": 2.0,
|
| 196 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 197 |
+
"dtype": "bfloat16",
|
| 198 |
+
"trust_remote_code": true
|
| 199 |
+
}
|
| 200 |
+
},
|
| 201 |
+
"hellaswag": {
|
| 202 |
+
"task": "hellaswag",
|
| 203 |
+
"tag": [
|
| 204 |
+
"multiple_choice"
|
| 205 |
+
],
|
| 206 |
+
"dataset_path": "Rowan/hellaswag",
|
| 207 |
+
"training_split": "train",
|
| 208 |
+
"validation_split": "validation",
|
| 209 |
+
"process_docs": "def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:\n def _process_doc(doc):\n ctx = doc[\"ctx_a\"] + \" \" + doc[\"ctx_b\"].capitalize()\n out_doc = {\n \"query\": preprocess(doc[\"activity_label\"] + \": \" + ctx),\n \"choices\": [preprocess(ending) for ending in doc[\"endings\"]],\n \"gold\": int(doc[\"label\"]),\n }\n return out_doc\n\n return dataset.map(_process_doc)\n",
|
| 210 |
+
"doc_to_text": "{{query}}",
|
| 211 |
+
"doc_to_target": "{{label}}",
|
| 212 |
+
"unsafe_code": false,
|
| 213 |
+
"doc_to_choice": "choices",
|
| 214 |
+
"description": "",
|
| 215 |
+
"target_delimiter": " ",
|
| 216 |
+
"fewshot_delimiter": "\n\n",
|
| 217 |
+
"num_fewshot": 0,
|
| 218 |
+
"metric_list": [
|
| 219 |
+
{
|
| 220 |
+
"metric": "acc",
|
| 221 |
+
"aggregation": "mean",
|
| 222 |
+
"higher_is_better": true
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"metric": "acc_norm",
|
| 226 |
+
"aggregation": "mean",
|
| 227 |
+
"higher_is_better": true
|
| 228 |
+
}
|
| 229 |
+
],
|
| 230 |
+
"output_type": "multiple_choice",
|
| 231 |
+
"repeats": 1,
|
| 232 |
+
"should_decontaminate": false,
|
| 233 |
+
"metadata": {
|
| 234 |
+
"version": 1.0,
|
| 235 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 236 |
+
"dtype": "bfloat16",
|
| 237 |
+
"trust_remote_code": true
|
| 238 |
+
}
|
| 239 |
+
},
|
| 240 |
+
"lambada_openai": {
|
| 241 |
+
"task": "lambada_openai",
|
| 242 |
+
"tag": [
|
| 243 |
+
"lambada"
|
| 244 |
+
],
|
| 245 |
+
"dataset_path": "EleutherAI/lambada_openai",
|
| 246 |
+
"dataset_name": "default",
|
| 247 |
+
"test_split": "test",
|
| 248 |
+
"doc_to_text": "{{text.split(' ')[:-1]|join(' ')}}",
|
| 249 |
+
"doc_to_target": "{{' '+text.split(' ')[-1]}}",
|
| 250 |
+
"unsafe_code": false,
|
| 251 |
+
"description": "",
|
| 252 |
+
"target_delimiter": " ",
|
| 253 |
+
"fewshot_delimiter": "\n\n",
|
| 254 |
+
"num_fewshot": 0,
|
| 255 |
+
"metric_list": [
|
| 256 |
+
{
|
| 257 |
+
"metric": "perplexity",
|
| 258 |
+
"aggregation": "perplexity",
|
| 259 |
+
"higher_is_better": false
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"metric": "acc",
|
| 263 |
+
"aggregation": "mean",
|
| 264 |
+
"higher_is_better": true
|
| 265 |
+
}
|
| 266 |
+
],
|
| 267 |
+
"output_type": "loglikelihood",
|
| 268 |
+
"repeats": 1,
|
| 269 |
+
"should_decontaminate": true,
|
| 270 |
+
"doc_to_decontamination_query": "{{text}}",
|
| 271 |
+
"metadata": {
|
| 272 |
+
"version": 1.0,
|
| 273 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 274 |
+
"dtype": "bfloat16",
|
| 275 |
+
"trust_remote_code": true
|
| 276 |
+
}
|
| 277 |
+
},
|
| 278 |
+
"lambada_standard": {
|
| 279 |
+
"task": "lambada_standard",
|
| 280 |
+
"tag": [
|
| 281 |
+
"lambada"
|
| 282 |
+
],
|
| 283 |
+
"dataset_path": "lambada",
|
| 284 |
+
"validation_split": "validation",
|
| 285 |
+
"test_split": "test",
|
| 286 |
+
"doc_to_text": "{{text.split(' ')[:-1]|join(' ')}}",
|
| 287 |
+
"doc_to_target": "{{' '+text.split(' ')[-1]}}",
|
| 288 |
+
"unsafe_code": false,
|
| 289 |
+
"description": "",
|
| 290 |
+
"target_delimiter": " ",
|
| 291 |
+
"fewshot_delimiter": "\n\n",
|
| 292 |
+
"num_fewshot": 0,
|
| 293 |
+
"metric_list": [
|
| 294 |
+
{
|
| 295 |
+
"metric": "perplexity",
|
| 296 |
+
"aggregation": "perplexity",
|
| 297 |
+
"higher_is_better": false
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"metric": "acc",
|
| 301 |
+
"aggregation": "mean",
|
| 302 |
+
"higher_is_better": true
|
| 303 |
+
}
|
| 304 |
+
],
|
| 305 |
+
"output_type": "loglikelihood",
|
| 306 |
+
"repeats": 1,
|
| 307 |
+
"should_decontaminate": true,
|
| 308 |
+
"doc_to_decontamination_query": "{{text}}",
|
| 309 |
+
"metadata": {
|
| 310 |
+
"version": 1.0,
|
| 311 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 312 |
+
"dtype": "bfloat16",
|
| 313 |
+
"trust_remote_code": true
|
| 314 |
+
}
|
| 315 |
+
},
|
| 316 |
+
"piqa": {
|
| 317 |
+
"task": "piqa",
|
| 318 |
+
"dataset_path": "baber/piqa",
|
| 319 |
+
"training_split": "train",
|
| 320 |
+
"validation_split": "validation",
|
| 321 |
+
"doc_to_text": "Question: {{goal}}\nAnswer:",
|
| 322 |
+
"doc_to_target": "label",
|
| 323 |
+
"unsafe_code": false,
|
| 324 |
+
"doc_to_choice": "{{[sol1, sol2]}}",
|
| 325 |
+
"description": "",
|
| 326 |
+
"target_delimiter": " ",
|
| 327 |
+
"fewshot_delimiter": "\n\n",
|
| 328 |
+
"num_fewshot": 0,
|
| 329 |
+
"metric_list": [
|
| 330 |
+
{
|
| 331 |
+
"metric": "acc",
|
| 332 |
+
"aggregation": "mean",
|
| 333 |
+
"higher_is_better": true
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"metric": "acc_norm",
|
| 337 |
+
"aggregation": "mean",
|
| 338 |
+
"higher_is_better": true
|
| 339 |
+
}
|
| 340 |
+
],
|
| 341 |
+
"output_type": "multiple_choice",
|
| 342 |
+
"repeats": 1,
|
| 343 |
+
"should_decontaminate": true,
|
| 344 |
+
"doc_to_decontamination_query": "goal",
|
| 345 |
+
"metadata": {
|
| 346 |
+
"version": 1.0,
|
| 347 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 348 |
+
"dtype": "bfloat16",
|
| 349 |
+
"trust_remote_code": true
|
| 350 |
+
}
|
| 351 |
+
},
|
| 352 |
+
"social_iqa": {
|
| 353 |
+
"task": "social_iqa",
|
| 354 |
+
"dataset_path": "social_i_qa",
|
| 355 |
+
"training_split": "train",
|
| 356 |
+
"validation_split": "validation",
|
| 357 |
+
"doc_to_text": "Q: {{context}} {{question}}\nA:",
|
| 358 |
+
"doc_to_target": "{{ (label|int) - 1 }}",
|
| 359 |
+
"unsafe_code": false,
|
| 360 |
+
"doc_to_choice": "{{[answerA, answerB, answerC]}}",
|
| 361 |
+
"description": "",
|
| 362 |
+
"target_delimiter": " ",
|
| 363 |
+
"fewshot_delimiter": "\n\n",
|
| 364 |
+
"num_fewshot": 0,
|
| 365 |
+
"metric_list": [
|
| 366 |
+
{
|
| 367 |
+
"metric": "acc",
|
| 368 |
+
"aggregation": "mean",
|
| 369 |
+
"higher_is_better": true
|
| 370 |
+
}
|
| 371 |
+
],
|
| 372 |
+
"output_type": "multiple_choice",
|
| 373 |
+
"repeats": 1,
|
| 374 |
+
"should_decontaminate": false,
|
| 375 |
+
"metadata": {
|
| 376 |
+
"version": 0.0,
|
| 377 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 378 |
+
"dtype": "bfloat16",
|
| 379 |
+
"trust_remote_code": true
|
| 380 |
+
}
|
| 381 |
+
},
|
| 382 |
+
"wikitext": {
|
| 383 |
+
"task": "wikitext",
|
| 384 |
+
"dataset_path": "EleutherAI/wikitext_document_level",
|
| 385 |
+
"dataset_name": "wikitext-2-raw-v1",
|
| 386 |
+
"training_split": "train",
|
| 387 |
+
"validation_split": "validation",
|
| 388 |
+
"test_split": "test",
|
| 389 |
+
"doc_to_text": "",
|
| 390 |
+
"doc_to_target": "def wikitext_detokenizer(doc):\n string = doc[\"page\"]\n # contractions\n string = string.replace(\"s '\", \"s'\")\n string = re.sub(r\"/' [0-9]/\", r\"/'[0-9]/\", string)\n # number separators\n string = string.replace(\" @-@ \", \"-\")\n string = string.replace(\" @,@ \", \",\")\n string = string.replace(\" @.@ \", \".\")\n # punctuation\n string = string.replace(\" : \", \": \")\n string = string.replace(\" ; \", \"; \")\n string = string.replace(\" . \", \". \")\n string = string.replace(\" ! \", \"! \")\n string = string.replace(\" ? \", \"? \")\n string = string.replace(\" , \", \", \")\n # double brackets\n string = re.sub(r\"\\(\\s*([^\\)]*?)\\s*\\)\", r\"(\\1)\", string)\n string = re.sub(r\"\\[\\s*([^\\]]*?)\\s*\\]\", r\"[\\1]\", string)\n string = re.sub(r\"{\\s*([^}]*?)\\s*}\", r\"{\\1}\", string)\n string = re.sub(r\"\\\"\\s*([^\\\"]*?)\\s*\\\"\", r'\"\\1\"', string)\n string = re.sub(r\"'\\s*([^']*?)\\s*'\", r\"'\\1'\", string)\n # miscellaneous\n string = string.replace(\"= = = =\", \"====\")\n string = string.replace(\"= = =\", \"===\")\n string = string.replace(\"= =\", \"==\")\n string = string.replace(\" \" + chr(176) + \" \", chr(176))\n string = string.replace(\" \\n\", \"\\n\")\n string = string.replace(\"\\n \", \"\\n\")\n string = string.replace(\" N \", \" 1 \")\n string = string.replace(\" 's\", \"'s\")\n\n return string\n",
|
| 391 |
+
"unsafe_code": false,
|
| 392 |
+
"process_results": "def process_results(doc, results):\n (loglikelihood,) = results\n # IMPORTANT: wikitext counts number of words in *original doc before detokenization*\n _words = len(re.split(r\"\\s+\", doc[\"page\"]))\n _bytes = len(doc[\"page\"].encode(\"utf-8\"))\n return {\n \"word_perplexity\": (loglikelihood, _words),\n \"byte_perplexity\": (loglikelihood, _bytes),\n \"bits_per_byte\": (loglikelihood, _bytes),\n }\n",
|
| 393 |
+
"description": "",
|
| 394 |
+
"target_delimiter": " ",
|
| 395 |
+
"fewshot_delimiter": "\n\n",
|
| 396 |
+
"num_fewshot": 0,
|
| 397 |
+
"metric_list": [
|
| 398 |
+
{
|
| 399 |
+
"metric": "word_perplexity"
|
| 400 |
+
},
|
| 401 |
+
{
|
| 402 |
+
"metric": "byte_perplexity"
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"metric": "bits_per_byte"
|
| 406 |
+
}
|
| 407 |
+
],
|
| 408 |
+
"output_type": "loglikelihood_rolling",
|
| 409 |
+
"repeats": 1,
|
| 410 |
+
"should_decontaminate": true,
|
| 411 |
+
"doc_to_decontamination_query": "{{page}}",
|
| 412 |
+
"metadata": {
|
| 413 |
+
"version": 2.0,
|
| 414 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 415 |
+
"dtype": "bfloat16",
|
| 416 |
+
"trust_remote_code": true
|
| 417 |
+
}
|
| 418 |
+
},
|
| 419 |
+
"winogrande": {
|
| 420 |
+
"task": "winogrande",
|
| 421 |
+
"dataset_path": "winogrande",
|
| 422 |
+
"dataset_name": "winogrande_xl",
|
| 423 |
+
"training_split": "train",
|
| 424 |
+
"validation_split": "validation",
|
| 425 |
+
"doc_to_text": "def doc_to_text(doc):\n answer_to_num = {\"1\": 0, \"2\": 1}\n return answer_to_num[doc[\"answer\"]]\n",
|
| 426 |
+
"doc_to_target": "def doc_to_target(doc):\n idx = doc[\"sentence\"].index(\"_\") + 1\n return doc[\"sentence\"][idx:].strip()\n",
|
| 427 |
+
"unsafe_code": false,
|
| 428 |
+
"doc_to_choice": "def doc_to_choice(doc):\n idx = doc[\"sentence\"].index(\"_\")\n options = [doc[\"option1\"], doc[\"option2\"]]\n return [doc[\"sentence\"][:idx] + opt for opt in options]\n",
|
| 429 |
+
"description": "",
|
| 430 |
+
"target_delimiter": " ",
|
| 431 |
+
"fewshot_delimiter": "\n\n",
|
| 432 |
+
"num_fewshot": 0,
|
| 433 |
+
"metric_list": [
|
| 434 |
+
{
|
| 435 |
+
"metric": "acc",
|
| 436 |
+
"aggregation": "mean",
|
| 437 |
+
"higher_is_better": true
|
| 438 |
+
}
|
| 439 |
+
],
|
| 440 |
+
"output_type": "multiple_choice",
|
| 441 |
+
"repeats": 1,
|
| 442 |
+
"should_decontaminate": true,
|
| 443 |
+
"doc_to_decontamination_query": "sentence",
|
| 444 |
+
"metadata": {
|
| 445 |
+
"version": 1.0,
|
| 446 |
+
"pretrained": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 447 |
+
"dtype": "bfloat16",
|
| 448 |
+
"trust_remote_code": true
|
| 449 |
+
}
|
| 450 |
+
}
|
| 451 |
+
},
|
| 452 |
+
"versions": {
|
| 453 |
+
"arc_challenge": 1.0,
|
| 454 |
+
"arc_easy": 1.0,
|
| 455 |
+
"boolq": 2.0,
|
| 456 |
+
"hellaswag": 1.0,
|
| 457 |
+
"lambada_openai": 1.0,
|
| 458 |
+
"lambada_standard": 1.0,
|
| 459 |
+
"piqa": 1.0,
|
| 460 |
+
"social_iqa": 0.0,
|
| 461 |
+
"wikitext": 2.0,
|
| 462 |
+
"winogrande": 1.0
|
| 463 |
+
},
|
| 464 |
+
"n-shot": {
|
| 465 |
+
"arc_challenge": 0,
|
| 466 |
+
"arc_easy": 0,
|
| 467 |
+
"boolq": 0,
|
| 468 |
+
"hellaswag": 0,
|
| 469 |
+
"lambada_openai": 0,
|
| 470 |
+
"lambada_standard": 0,
|
| 471 |
+
"piqa": 0,
|
| 472 |
+
"social_iqa": 0,
|
| 473 |
+
"wikitext": 0,
|
| 474 |
+
"winogrande": 0
|
| 475 |
+
},
|
| 476 |
+
"higher_is_better": {
|
| 477 |
+
"arc_challenge": {
|
| 478 |
+
"acc": true,
|
| 479 |
+
"acc_norm": true
|
| 480 |
+
},
|
| 481 |
+
"arc_easy": {
|
| 482 |
+
"acc": true,
|
| 483 |
+
"acc_norm": true
|
| 484 |
+
},
|
| 485 |
+
"boolq": {
|
| 486 |
+
"acc": true
|
| 487 |
+
},
|
| 488 |
+
"hellaswag": {
|
| 489 |
+
"acc": true,
|
| 490 |
+
"acc_norm": true
|
| 491 |
+
},
|
| 492 |
+
"lambada_openai": {
|
| 493 |
+
"perplexity": false,
|
| 494 |
+
"acc": true
|
| 495 |
+
},
|
| 496 |
+
"lambada_standard": {
|
| 497 |
+
"perplexity": false,
|
| 498 |
+
"acc": true
|
| 499 |
+
},
|
| 500 |
+
"piqa": {
|
| 501 |
+
"acc": true,
|
| 502 |
+
"acc_norm": true
|
| 503 |
+
},
|
| 504 |
+
"social_iqa": {
|
| 505 |
+
"acc": true
|
| 506 |
+
},
|
| 507 |
+
"wikitext": {
|
| 508 |
+
"word_perplexity": false,
|
| 509 |
+
"byte_perplexity": false,
|
| 510 |
+
"bits_per_byte": false
|
| 511 |
+
},
|
| 512 |
+
"winogrande": {
|
| 513 |
+
"acc": true
|
| 514 |
+
}
|
| 515 |
+
},
|
| 516 |
+
"n-samples": {
|
| 517 |
+
"winogrande": {
|
| 518 |
+
"original": 1267,
|
| 519 |
+
"effective": 1267
|
| 520 |
+
},
|
| 521 |
+
"wikitext": {
|
| 522 |
+
"original": 62,
|
| 523 |
+
"effective": 62
|
| 524 |
+
},
|
| 525 |
+
"social_iqa": {
|
| 526 |
+
"original": 1954,
|
| 527 |
+
"effective": 1954
|
| 528 |
+
},
|
| 529 |
+
"piqa": {
|
| 530 |
+
"original": 1838,
|
| 531 |
+
"effective": 1838
|
| 532 |
+
},
|
| 533 |
+
"lambada_standard": {
|
| 534 |
+
"original": 5153,
|
| 535 |
+
"effective": 5153
|
| 536 |
+
},
|
| 537 |
+
"lambada_openai": {
|
| 538 |
+
"original": 5153,
|
| 539 |
+
"effective": 5153
|
| 540 |
+
},
|
| 541 |
+
"hellaswag": {
|
| 542 |
+
"original": 10042,
|
| 543 |
+
"effective": 10042
|
| 544 |
+
},
|
| 545 |
+
"boolq": {
|
| 546 |
+
"original": 3270,
|
| 547 |
+
"effective": 3270
|
| 548 |
+
},
|
| 549 |
+
"arc_easy": {
|
| 550 |
+
"original": 2376,
|
| 551 |
+
"effective": 2376
|
| 552 |
+
},
|
| 553 |
+
"arc_challenge": {
|
| 554 |
+
"original": 1172,
|
| 555 |
+
"effective": 1172
|
| 556 |
+
}
|
| 557 |
+
},
|
| 558 |
+
"config": {
|
| 559 |
+
"model": "hf",
|
| 560 |
+
"model_args": "pretrained=results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/,dtype=bfloat16,trust_remote_code=True,trust_remote_code=True",
|
| 561 |
+
"model_num_parameters": 392747259,
|
| 562 |
+
"model_dtype": "torch.bfloat16",
|
| 563 |
+
"model_revision": "main",
|
| 564 |
+
"model_sha": "",
|
| 565 |
+
"batch_size": "8",
|
| 566 |
+
"batch_sizes": [],
|
| 567 |
+
"device": "cuda:0",
|
| 568 |
+
"use_cache": null,
|
| 569 |
+
"limit": null,
|
| 570 |
+
"bootstrap_iters": 100000,
|
| 571 |
+
"gen_kwargs": null,
|
| 572 |
+
"random_seed": 0,
|
| 573 |
+
"numpy_seed": 1234,
|
| 574 |
+
"torch_seed": 1234,
|
| 575 |
+
"fewshot_seed": 1234
|
| 576 |
+
},
|
| 577 |
+
"git_hash": "core_v0.12.0-147-g5c103f4",
|
| 578 |
+
"date": 1774874949.7520695,
|
| 579 |
+
"pretty_env_info": "PyTorch version: 2.6.0+cu124\nIs debug build: False\nCUDA used to build PyTorch: 12.4\nROCM used to build PyTorch: N/A\n\nOS: CentOS Linux 7 (Core) (x86_64)\nGCC version: (conda-forge gcc 9.5.0-19) 9.5.0\nClang version: Could not collect\nCMake version: version 3.30.1\nLibc version: glibc-2.17\n\nPython version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)\nPython platform: Linux-3.10.0-1160.el7.x86_64-x86_64-with-glibc2.17\nIs CUDA available: True\nCUDA runtime version: 12.4.131\nCUDA_MODULE_LOADING set to: LAZY\nGPU models and configuration: \nGPU 0: NVIDIA A800-SXM4-80GB\nGPU 1: NVIDIA A800-SXM4-80GB\nGPU 2: NVIDIA A800-SXM4-80GB\nGPU 3: NVIDIA A800-SXM4-80GB\nGPU 4: NVIDIA A800-SXM4-80GB\nGPU 5: NVIDIA A800-SXM4-80GB\nGPU 6: NVIDIA A800-SXM4-80GB\nGPU 7: NVIDIA A800-SXM4-80GB\n\nNvidia driver version: 550.163.01\ncuDNN version: Could not collect\nHIP runtime version: N/A\nMIOpen runtime version: N/A\nIs XNNPACK available: True\n\nCPU:\nArchitecture: x86_64\nCPU op-mode(s): 32-bit, 64-bit\nByte Order: Little Endian\nCPU(s): 104\nOn-line CPU(s) list: 0-103\nThread(s) per core: 1\nCore(s) per socket: 52\n座: 2\nNUMA 节点: 2\n厂商 ID: GenuineIntel\nCPU 系列: 6\n型号: 143\n型号名称: Intel(R) Xeon(R) Platinum 8470\n步进: 8\nCPU MHz: 799.926\nCPU max MHz: 3800.0000\nCPU min MHz: 800.0000\nBogoMIPS: 4000.00\n虚拟化: VT-x\nL1d 缓存: 48K\nL1i 缓存: 32K\nL2 缓存: 2048K\nL3 缓存: 107520K\nNUMA 节点0 CPU: 0-51\nNUMA 节点1 CPU: 52-103\nFlags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_pt cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq cldemote movdiri movdir64b md_clear pconfig spec_ctrl intel_stibp flush_l1d arch_capabilities\n\nVersions of relevant libraries:\n[pip3] numpy==1.26.4\n[pip3] nvidia-cublas-cu12==12.4.5.8\n[pip3] nvidia-cuda-cupti-cu12==12.4.127\n[pip3] nvidia-cuda-nvrtc-cu12==12.4.127\n[pip3] nvidia-cuda-runtime-cu12==12.4.127\n[pip3] nvidia-cudnn-cu12==9.1.0.70\n[pip3] nvidia-cufft-cu12==11.2.1.3\n[pip3] nvidia-curand-cu12==10.3.5.147\n[pip3] nvidia-cusolver-cu12==11.6.1.9\n[pip3] nvidia-cusparse-cu12==12.3.1.170\n[pip3] nvidia-cusparselt-cu12==0.6.2\n[pip3] nvidia-nccl-cu11==2.21.5\n[pip3] nvidia-nccl-cu12==2.21.5\n[pip3] nvidia-nvjitlink-cu12==12.4.127\n[pip3] nvidia-nvtx-cu12==12.4.127\n[pip3] torch==2.6.0\n[pip3] torchaudio==2.6.0\n[pip3] torchdata==0.11.0\n[pip3] torchvision==0.21.0\n[pip3] triton==3.2.0\n[conda] cuda-cudart 12.4.99 hd3aeb46_0 conda-forge\n[conda] cuda-cudart_linux-64 12.4.99 h59595ed_0 conda-forge\n[conda] cuda-cupti 12.4.127 he02047a_2 conda-forge\n[conda] cuda-libraries 12.4.0 ha770c72_0 conda-forge\n[conda] cuda-nvrtc 12.4.99 hd3aeb46_0 conda-forge\n[conda] cuda-nvtx 12.4.127 he02047a_2 conda-forge\n[conda] cuda-opencl 12.4.99 h59595ed_0 conda-forge\n[conda] cuda-runtime 12.4.0 ha804496_0 conda-forge\n[conda] ffmpeg 4.3 hf484d3e_0 pytorch\n[conda] libcublas 12.4.2.65 hd3aeb46_0 conda-forge\n[conda] libcufft 11.2.0.44 hd3aeb46_0 conda-forge\n[conda] libcurand 10.3.5.119 hd3aeb46_0 conda-forge\n[conda] libcusolver 11.6.0.99 hd3aeb46_0 conda-forge\n[conda] libcusparse 12.3.0.142 hd3aeb46_0 conda-forge\n[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch\n[conda] libnvjitlink 12.4.99 hd3aeb46_0 conda-forge\n[conda] mkl 2023.1.0 h213fc3f_46344 defaults\n[conda] numpy 1.26.4 pypi_0 pypi\n[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi\n[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi\n[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi\n[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi\n[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi\n[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi\n[conda] nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi\n[conda] nvidia-nccl-cu11 2.21.5 pypi_0 pypi\n[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi\n[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi\n[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi\n[conda] pytorch-cuda 12.4 hc786d27_6 pytorch\n[conda] pytorch-mutex 1.0 cuda pytorch\n[conda] torch 2.6.0 pypi_0 pypi\n[conda] torchaudio 2.6.0 pypi_0 pypi\n[conda] torchdata 0.11.0 pypi_0 pypi\n[conda] torchvision 0.21.0 pypi_0 pypi\n[conda] triton 3.2.0 pypi_0 pypi",
|
| 580 |
+
"transformers_version": "4.55.2",
|
| 581 |
+
"lm_eval_version": "0.4.9.1",
|
| 582 |
+
"upper_git_hash": null,
|
| 583 |
+
"tokenizer_pad_token": [
|
| 584 |
+
"<unk>",
|
| 585 |
+
"0"
|
| 586 |
+
],
|
| 587 |
+
"tokenizer_eos_token": [
|
| 588 |
+
"<|im_end|>",
|
| 589 |
+
"73440"
|
| 590 |
+
],
|
| 591 |
+
"tokenizer_bos_token": [
|
| 592 |
+
"<s>",
|
| 593 |
+
"1"
|
| 594 |
+
],
|
| 595 |
+
"eot_token_id": 73440,
|
| 596 |
+
"max_length": 4096,
|
| 597 |
+
"task_hashes": {},
|
| 598 |
+
"model_source": "hf",
|
| 599 |
+
"model_name": "results/hf_ckpts/blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128/",
|
| 600 |
+
"model_name_sanitized": "results__hf_ckpts__blockffn_02b_mul1002_withmean_d64_s128_lr93e4_b128__",
|
| 601 |
+
"system_instruction": null,
|
| 602 |
+
"system_instruction_sha": null,
|
| 603 |
+
"fewshot_as_multiturn": false,
|
| 604 |
+
"chat_template": null,
|
| 605 |
+
"chat_template_sha": null,
|
| 606 |
+
"start_time": 1822696.504315611,
|
| 607 |
+
"end_time": 1823059.519498931,
|
| 608 |
+
"total_evaluation_time_seconds": "363.01518332003616"
|
| 609 |
+
}
|
evaluation2.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generation_config.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"do_sample": true,
|
| 3 |
+
"top_p": 0.8,
|
| 4 |
+
"temperature": 0.8,
|
| 5 |
+
"bos_token_id": 1,
|
| 6 |
+
"eos_token_id": [2,73440],
|
| 7 |
+
"pad_token_id": 2
|
| 8 |
+
}
|
modeling_blockffn.py
ADDED
|
@@ -0,0 +1,1014 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
from typing import Callable, Optional, Union
|
| 21 |
+
|
| 22 |
+
import math
|
| 23 |
+
import torch
|
| 24 |
+
from torch import nn
|
| 25 |
+
|
| 26 |
+
import tree
|
| 27 |
+
from abc import ABC, abstractmethod
|
| 28 |
+
from fmoe.linear import MOELinear
|
| 29 |
+
from fmoe.functions import prepare_forward, MOEScatter, MOEGather
|
| 30 |
+
|
| 31 |
+
from transformers.activations import ACT2FN
|
| 32 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 33 |
+
from transformers.generation import GenerationMixin
|
| 34 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 35 |
+
from transformers.masking_utils import create_causal_mask
|
| 36 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 37 |
+
from transformers.modeling_outputs import (
|
| 38 |
+
BaseModelOutputWithPast,
|
| 39 |
+
CausalLMOutputWithPast,
|
| 40 |
+
)
|
| 41 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 42 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 43 |
+
from transformers.processing_utils import Unpack
|
| 44 |
+
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
| 45 |
+
from transformers.utils.generic import check_model_inputs
|
| 46 |
+
from .configuration_blockffn import BlockFFNConfig
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 53 |
+
class BlockFFNRMSNorm(nn.Module):
|
| 54 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 57 |
+
self.variance_epsilon = eps
|
| 58 |
+
|
| 59 |
+
def forward(self, hidden_states):
|
| 60 |
+
input_dtype = hidden_states.dtype
|
| 61 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 62 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 63 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 64 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 65 |
+
|
| 66 |
+
def extra_repr(self):
|
| 67 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class BlockFFNRotaryEmbedding(nn.Module):
|
| 71 |
+
def __init__(self, config: BlockFFNConfig, device=None):
|
| 72 |
+
super().__init__()
|
| 73 |
+
# BC: "rope_type" was originally "type"
|
| 74 |
+
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 75 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 76 |
+
else:
|
| 77 |
+
self.rope_type = "default"
|
| 78 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 79 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 80 |
+
|
| 81 |
+
self.config = config
|
| 82 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 83 |
+
|
| 84 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 85 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 86 |
+
self.original_inv_freq = self.inv_freq
|
| 87 |
+
|
| 88 |
+
@torch.no_grad()
|
| 89 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 90 |
+
def forward(self, x, position_ids):
|
| 91 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 92 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 93 |
+
|
| 94 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 95 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 96 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 97 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 98 |
+
cos = emb.cos() * self.attention_scaling
|
| 99 |
+
sin = emb.sin() * self.attention_scaling
|
| 100 |
+
|
| 101 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def rotate_half(x):
|
| 105 |
+
"""Rotates half the hidden dims of the input."""
|
| 106 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 107 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 108 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 112 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
q (`torch.Tensor`): The query tensor.
|
| 116 |
+
k (`torch.Tensor`): The key tensor.
|
| 117 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 118 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 119 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 120 |
+
Deprecated and unused.
|
| 121 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 122 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 123 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 124 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 125 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 126 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 127 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 128 |
+
Returns:
|
| 129 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 130 |
+
"""
|
| 131 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 132 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 133 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 134 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 135 |
+
return q_embed, k_embed
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class SimpleLayerNorm(nn.Module):
|
| 139 |
+
def __init__(self, dim_norm: int):
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.dim_norm = dim_norm
|
| 142 |
+
self.weight = torch.nn.Parameter(torch.empty(self.dim_norm))
|
| 143 |
+
|
| 144 |
+
@torch.compile
|
| 145 |
+
def forward(self, x: torch.Tensor):
|
| 146 |
+
return x * self.weight
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class BlockFFNMLP(nn.Module):
|
| 150 |
+
def __init__(self, config: BlockFFNConfig, intermediate_size: int = None):
|
| 151 |
+
super().__init__()
|
| 152 |
+
self.config = config
|
| 153 |
+
self.hidden_size = config.hidden_size
|
| 154 |
+
self.intermediate_size = config.ffn_hidden_size if intermediate_size is None else intermediate_size
|
| 155 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
| 156 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
| 157 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
| 158 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 159 |
+
|
| 160 |
+
def forward(self, x):
|
| 161 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 162 |
+
return down_proj
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class BlockFFNRouter(nn.Module):
|
| 166 |
+
def __init__(self, config: BlockFFNConfig):
|
| 167 |
+
super().__init__()
|
| 168 |
+
self.config = config
|
| 169 |
+
self.num_experts = self.config.num_experts
|
| 170 |
+
|
| 171 |
+
if self.config.moe_router_dtype == "fp32":
|
| 172 |
+
self.router_dtype = torch.float32
|
| 173 |
+
elif self.config.moe_router_dtype == "fp64":
|
| 174 |
+
self.router_dtype = torch.float64
|
| 175 |
+
elif self.config.moe_router_dtype == "bf16":
|
| 176 |
+
self.router_dtype = torch.bfloat16
|
| 177 |
+
else:
|
| 178 |
+
raise NotImplementedError(f"{self.config.moe_router_dtype} is not supported.")
|
| 179 |
+
|
| 180 |
+
self.weight = torch.nn.Parameter(
|
| 181 |
+
torch.empty((self.config.num_experts, self.config.hidden_size), dtype=self.router_dtype)
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def forward(self, x: torch.Tensor):
|
| 185 |
+
return nn.functional.linear(x.to(self.router_dtype), self.weight)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class NormSiLU(nn.Module):
|
| 189 |
+
def __init__(self, config: BlockFFNConfig):
|
| 190 |
+
super().__init__()
|
| 191 |
+
self.num_blocks, self.block_size = config.num_experts, config.moe_ffn_hidden_size
|
| 192 |
+
self.activate_fn_type = config.expert_act_func
|
| 193 |
+
assert self.activate_fn_type in ["norm_silu", "norm_silu_norms", "norm_silu_nomean"]
|
| 194 |
+
|
| 195 |
+
self.rms_norm = None
|
| 196 |
+
if self.activate_fn_type != "norm_silu_norms":
|
| 197 |
+
self.rms_norm = BlockFFNRMSNorm(config.moe_ffn_hidden_size, eps=config.norm_epsilon)
|
| 198 |
+
self.silu = torch.nn.SiLU()
|
| 199 |
+
|
| 200 |
+
@torch.compile
|
| 201 |
+
def forward(self, hidden: torch.Tensor) -> torch.Tensor:
|
| 202 |
+
assert hidden.ndim == 2
|
| 203 |
+
if self.activate_fn_type != "norm_silu_nomean":
|
| 204 |
+
hidden = hidden - torch.mean(hidden, dim=-1, keepdim=True)
|
| 205 |
+
if self.activate_fn_type != "norm_silu_norms":
|
| 206 |
+
return self.silu(self.rms_norm(hidden.view(hidden.shape[0], self.num_blocks, self.block_size)))
|
| 207 |
+
else:
|
| 208 |
+
return self.silu(hidden)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class BlockFFNLayer(nn.Module):
|
| 212 |
+
def __init__(self, config: BlockFFNConfig):
|
| 213 |
+
super(BlockFFNLayer, self).__init__()
|
| 214 |
+
self.config = config
|
| 215 |
+
self.num_experts, self.dim_expert, self.hidden_size = \
|
| 216 |
+
config.num_experts, config.moe_ffn_hidden_size, config.hidden_size
|
| 217 |
+
self.dim_shared_expert = config.moe_shared_expert_intermediate_size
|
| 218 |
+
self.router_norm_type = config.router_norm_type
|
| 219 |
+
|
| 220 |
+
self.moe_router = BlockFFNRouter(self.config)
|
| 221 |
+
assert config.router_act_func == "relu"
|
| 222 |
+
self.router_act = nn.ReLU()
|
| 223 |
+
if config.router_norm_type == "simple":
|
| 224 |
+
self.router_norm = SimpleLayerNorm(self.config.num_experts)
|
| 225 |
+
elif config.router_norm_type == "rms":
|
| 226 |
+
self.router_norm = BlockFFNRMSNorm(self.config.num_experts, eps=config.norm_epsilon)
|
| 227 |
+
else:
|
| 228 |
+
raise NotImplementedError
|
| 229 |
+
|
| 230 |
+
self.expert_gated = not config.expert_not_gated
|
| 231 |
+
if self.expert_gated:
|
| 232 |
+
self.expert_gate_proj = nn.Linear(self.hidden_size, self.num_experts * self.dim_expert, bias=config.mlp_bias)
|
| 233 |
+
|
| 234 |
+
self.expert_up_proj = nn.Linear(self.hidden_size, self.num_experts * self.dim_expert, bias=config.mlp_bias)
|
| 235 |
+
assert config.expert_act_norm_type == "normal"
|
| 236 |
+
if config.expert_act_func == "norm_silu":
|
| 237 |
+
self.expert_act = NormSiLU(self.config)
|
| 238 |
+
elif config.expert_act_func == "silu":
|
| 239 |
+
self.expert_act = nn.SiLU()
|
| 240 |
+
else:
|
| 241 |
+
raise NotImplementedError
|
| 242 |
+
self.expert_down_proj = nn.Linear(self.num_experts * self.dim_expert, self.hidden_size, bias=config.mlp_bias)
|
| 243 |
+
|
| 244 |
+
self.use_shared_expert = self.dim_shared_expert is not None and self.dim_shared_expert > 0
|
| 245 |
+
if self.use_shared_expert:
|
| 246 |
+
self.shared_experts = BlockFFNMLP(self.config, intermediate_size=self.dim_shared_expert)
|
| 247 |
+
|
| 248 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 249 |
+
ori_shape = hidden_states.shape
|
| 250 |
+
hidden_states = hidden_states.view(-1, self.hidden_size)
|
| 251 |
+
seq_len = hidden_states.shape[0]
|
| 252 |
+
|
| 253 |
+
# router module forward
|
| 254 |
+
raw_router_score = self.moe_router(hidden_states) # [seq_len, num_experts]
|
| 255 |
+
router_score = self.router_act(raw_router_score)
|
| 256 |
+
router_score = self.router_norm(router_score)
|
| 257 |
+
|
| 258 |
+
# expert module forward
|
| 259 |
+
x_in = self.expert_up_proj(hidden_states) # [seq_len, num_experts * dim_expert]
|
| 260 |
+
if self.expert_gated:
|
| 261 |
+
x_gate = self.expert_gate_proj(hidden_states)
|
| 262 |
+
x_in = x_in * self.expert_act(x_gate)
|
| 263 |
+
else:
|
| 264 |
+
x_in = self.expert_act(x_in)
|
| 265 |
+
if x_in.ndim == 3:
|
| 266 |
+
scored_x_in = x_in * router_score.type_as(hidden_states).unsqueeze(-1)
|
| 267 |
+
else:
|
| 268 |
+
scored_x_in = x_in.view(seq_len, self.num_experts, self.dim_expert) * router_score.type_as(hidden_states).unsqueeze(-1)
|
| 269 |
+
output = self.expert_down_proj(scored_x_in.view(seq_len, self.num_experts * self.dim_expert))
|
| 270 |
+
|
| 271 |
+
if self.use_shared_expert:
|
| 272 |
+
output = output + self.shared_experts(hidden_states)
|
| 273 |
+
return output.view(*ori_shape)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class BaseRouter(ABC, nn.Module):
|
| 277 |
+
"""Base Router class"""
|
| 278 |
+
def __init__(self, config: BlockFFNConfig) -> None:
|
| 279 |
+
super().__init__()
|
| 280 |
+
self.config = config
|
| 281 |
+
self.num_experts = self.config.num_experts
|
| 282 |
+
|
| 283 |
+
if self.config.moe_router_dtype == "fp32":
|
| 284 |
+
self.router_dtype = torch.float32
|
| 285 |
+
elif self.config.moe_router_dtype == "fp64":
|
| 286 |
+
self.router_dtype = torch.float64
|
| 287 |
+
elif self.config.moe_router_dtype == "bf16":
|
| 288 |
+
self.router_dtype = torch.bfloat16
|
| 289 |
+
else:
|
| 290 |
+
raise NotImplementedError(f"{self.config.moe_router_dtype} is not supported.")
|
| 291 |
+
|
| 292 |
+
self.weight = torch.nn.Parameter(
|
| 293 |
+
torch.empty((self.num_experts, self.config.hidden_size), dtype=self.router_dtype)
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
def gating(self, input: torch.Tensor):
|
| 297 |
+
return torch.nn.functional.linear(input.to(self.router_dtype), self.weight.to(self.router_dtype))
|
| 298 |
+
|
| 299 |
+
@abstractmethod
|
| 300 |
+
def routing(self, logits: torch.Tensor):
|
| 301 |
+
"""Routing function.
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
logits (torch.Tensor): Logits tensor.
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment
|
| 308 |
+
probabilities and mapping.
|
| 309 |
+
"""
|
| 310 |
+
raise NotImplementedError("Routing function not implemented.")
|
| 311 |
+
|
| 312 |
+
@abstractmethod
|
| 313 |
+
def forward(self, input: torch.Tensor):
|
| 314 |
+
"""
|
| 315 |
+
Forward pass of the router.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
input (torch.Tensor): Input tensor.
|
| 319 |
+
"""
|
| 320 |
+
raise NotImplementedError("Forward function not implemented.")
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class TopKRouter(BaseRouter):
|
| 324 |
+
"""Route each token to the top-k experts."""
|
| 325 |
+
|
| 326 |
+
def __init__(self, config: BlockFFNConfig) -> None:
|
| 327 |
+
super().__init__(config)
|
| 328 |
+
self.config = config
|
| 329 |
+
self.topk = self.config.moe_router_topk
|
| 330 |
+
self.score_function = self.config.moe_router_score_function
|
| 331 |
+
self.use_pre_softmax = self.config.moe_router_pre_softmax
|
| 332 |
+
self.scaling_factor = self.config.moe_router_topk_scaling_factor
|
| 333 |
+
|
| 334 |
+
self.enable_expert_bias = self.config.moe_router_enable_expert_bias
|
| 335 |
+
if self.enable_expert_bias:
|
| 336 |
+
self.expert_bias = torch.nn.Parameter(torch.zeros(self.num_experts, dtype=torch.float32))
|
| 337 |
+
else:
|
| 338 |
+
self.expert_bias = None
|
| 339 |
+
|
| 340 |
+
def _maintain_float32_expert_bias(self):
|
| 341 |
+
"""
|
| 342 |
+
Maintain the expert bias in float32.
|
| 343 |
+
|
| 344 |
+
When using bf16/fp16, the expert bias gets converted to lower precision in Float16Module.
|
| 345 |
+
We keep it in float32 to avoid routing errors when updating the expert_bias.
|
| 346 |
+
"""
|
| 347 |
+
if hasattr(self, 'expert_bias') and self.expert_bias is not None:
|
| 348 |
+
if self.expert_bias.dtype != torch.float32:
|
| 349 |
+
self.expert_bias.data = self.expert_bias.data.to(torch.float32)
|
| 350 |
+
|
| 351 |
+
def routing(self, logits: torch.Tensor):
|
| 352 |
+
"""Top-k routing function
|
| 353 |
+
|
| 354 |
+
Args:
|
| 355 |
+
logits (torch.Tensor): Logits tensor after gating.
|
| 356 |
+
|
| 357 |
+
Returns:
|
| 358 |
+
probs (torch.Tensor): The probabilities of token to experts assignment.
|
| 359 |
+
routing_map (torch.Tensor): The mapping of token to experts assignment,
|
| 360 |
+
with shape [num_tokens, num_experts].
|
| 361 |
+
"""
|
| 362 |
+
logits = logits.view(-1, self.num_experts)
|
| 363 |
+
|
| 364 |
+
if self.score_function == "softmax":
|
| 365 |
+
if self.use_pre_softmax:
|
| 366 |
+
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
|
| 367 |
+
probs, top_indices = torch.topk(scores, k=self.topk, dim=1)
|
| 368 |
+
else:
|
| 369 |
+
scores, top_indices = torch.topk(logits, k=self.topk, dim=1)
|
| 370 |
+
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
|
| 371 |
+
elif self.score_function == "sigmoid":
|
| 372 |
+
scores = torch.sigmoid(logits.float()).type_as(logits)
|
| 373 |
+
if self.expert_bias is not None:
|
| 374 |
+
scores_for_routing = scores + self.expert_bias
|
| 375 |
+
_, top_indices = torch.topk(scores_for_routing, k=self.topk, dim=1)
|
| 376 |
+
scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
|
| 377 |
+
else:
|
| 378 |
+
scores, top_indices = torch.topk(scores, k=self.topk, dim=1)
|
| 379 |
+
probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.topk > 1 else scores
|
| 380 |
+
else:
|
| 381 |
+
raise ValueError(f"Invalid score_function: {self.score_function}")
|
| 382 |
+
|
| 383 |
+
if self.scaling_factor:
|
| 384 |
+
probs = probs * self.scaling_factor
|
| 385 |
+
|
| 386 |
+
return probs, top_indices
|
| 387 |
+
|
| 388 |
+
def forward(self, input: torch.Tensor):
|
| 389 |
+
"""
|
| 390 |
+
Forward pass of the router.
|
| 391 |
+
|
| 392 |
+
Args:
|
| 393 |
+
input (torch.Tensor): Input tensor.
|
| 394 |
+
"""
|
| 395 |
+
self._maintain_float32_expert_bias()
|
| 396 |
+
logits = self.gating(input)
|
| 397 |
+
top_scores, top_indices = self.routing(logits)
|
| 398 |
+
return top_scores, top_indices
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class ReMoERouter(BaseRouter):
|
| 402 |
+
def __init__(self, config: BlockFFNConfig) -> None:
|
| 403 |
+
super().__init__(config)
|
| 404 |
+
self.config = config
|
| 405 |
+
self.router_act = torch.nn.ReLU()
|
| 406 |
+
|
| 407 |
+
def routing(self, logits: torch.Tensor):
|
| 408 |
+
"""Top-k routing function
|
| 409 |
+
|
| 410 |
+
Args:
|
| 411 |
+
logits (torch.Tensor): Logits tensor after gating.
|
| 412 |
+
|
| 413 |
+
Returns:
|
| 414 |
+
probs (torch.Tensor): The probabilities of token to experts assignment.
|
| 415 |
+
routing_map (torch.Tensor): The mapping of token to experts assignment,
|
| 416 |
+
with shape [num_tokens, num_experts].
|
| 417 |
+
"""
|
| 418 |
+
logits = logits.view(-1, self.num_experts)
|
| 419 |
+
|
| 420 |
+
router_score = self.router_act(logits)
|
| 421 |
+
routing_map = router_score > 0
|
| 422 |
+
|
| 423 |
+
sorted_probs, sorted_indices = torch.sort(router_score, descending=True, dim=-1)
|
| 424 |
+
sorted_map = sorted_probs <= 0
|
| 425 |
+
sorted_indices = torch.where(sorted_map, -1, sorted_indices)
|
| 426 |
+
max_valid_num = max(sorted_probs.size(-1) - torch.min(torch.sum(sorted_map, dim=-1)).item(), 1)
|
| 427 |
+
assert torch.all(sorted_map[:, max_valid_num:])
|
| 428 |
+
sorted_probs = sorted_probs[:, :max_valid_num]
|
| 429 |
+
sorted_indices = sorted_indices[:, :max_valid_num]
|
| 430 |
+
assert torch.sum(routing_map) == torch.sum(sorted_indices != -1)
|
| 431 |
+
return sorted_probs, sorted_indices
|
| 432 |
+
|
| 433 |
+
def forward(self, input: torch.Tensor):
|
| 434 |
+
"""
|
| 435 |
+
Forward pass of the router.
|
| 436 |
+
|
| 437 |
+
Args:
|
| 438 |
+
input (torch.Tensor): Input tensor.
|
| 439 |
+
"""
|
| 440 |
+
logits = self.gating(input)
|
| 441 |
+
top_scores, top_indices = self.routing(logits)
|
| 442 |
+
return top_scores, top_indices
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class TopPRouter(BaseRouter):
|
| 446 |
+
def __init__(self, config: BlockFFNConfig) -> None:
|
| 447 |
+
super().__init__(config)
|
| 448 |
+
self.config = config
|
| 449 |
+
self.top_p = config.moe_router_topp
|
| 450 |
+
|
| 451 |
+
def routing(self, logits: torch.Tensor):
|
| 452 |
+
"""Top-k routing function
|
| 453 |
+
|
| 454 |
+
Args:
|
| 455 |
+
logits (torch.Tensor): Logits tensor after gating.
|
| 456 |
+
|
| 457 |
+
Returns:
|
| 458 |
+
probs (torch.Tensor): The probabilities of token to experts assignment.
|
| 459 |
+
routing_map (torch.Tensor): The mapping of token to experts assignment,
|
| 460 |
+
with shape [num_tokens, num_experts].
|
| 461 |
+
"""
|
| 462 |
+
logits = logits.view(-1, self.num_experts)
|
| 463 |
+
|
| 464 |
+
router_score = torch.abs(logits)
|
| 465 |
+
router_score = router_score / (router_score.sum(dim=-1, keepdim=True) + 1e-20)
|
| 466 |
+
|
| 467 |
+
sorted_probs, sorted_indices = torch.sort(router_score, descending=True, dim=-1)
|
| 468 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
|
| 469 |
+
mask = cumulative_probs > self.top_p
|
| 470 |
+
|
| 471 |
+
threshold_indices = mask.long().argmax(dim=-1)
|
| 472 |
+
threshold_mask = torch.nn.functional.one_hot(threshold_indices, num_classes=sorted_indices.size(-1)).bool()
|
| 473 |
+
|
| 474 |
+
mask = mask & ~threshold_mask
|
| 475 |
+
sorted_indices = torch.where(mask, -1, sorted_indices)
|
| 476 |
+
sorted_probs = torch.where(mask, 0.0, sorted_probs)
|
| 477 |
+
|
| 478 |
+
max_valid_num = max(mask.size(-1) - torch.min(torch.sum(mask, dim=-1)).item(), 1)
|
| 479 |
+
assert torch.all(mask[:, max_valid_num:])
|
| 480 |
+
|
| 481 |
+
sorted_indices = sorted_indices[:, :max_valid_num]
|
| 482 |
+
sorted_probs = sorted_probs[:, :max_valid_num]
|
| 483 |
+
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
|
| 484 |
+
return sorted_probs, sorted_indices
|
| 485 |
+
|
| 486 |
+
def forward(self, input: torch.Tensor):
|
| 487 |
+
"""
|
| 488 |
+
Forward pass of the router.
|
| 489 |
+
|
| 490 |
+
Args:
|
| 491 |
+
input (torch.Tensor): Input tensor.
|
| 492 |
+
"""
|
| 493 |
+
logits = self.gating(input)
|
| 494 |
+
top_scores, top_indices = self.routing(logits)
|
| 495 |
+
return top_scores, top_indices
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
class FastTopKCalculator:
|
| 499 |
+
def __init__(self, num_experts: int):
|
| 500 |
+
self.num_experts = num_experts
|
| 501 |
+
|
| 502 |
+
def fmoe_sparse_topk_forward(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, experts: torch.nn.Module):
|
| 503 |
+
(
|
| 504 |
+
pos,
|
| 505 |
+
local_expert_count,
|
| 506 |
+
global_expert_count,
|
| 507 |
+
fwd_expert_count,
|
| 508 |
+
fwd_batch_size,
|
| 509 |
+
) = prepare_forward(topk_indices, self.num_experts, 1)
|
| 510 |
+
topk = 1
|
| 511 |
+
if len(topk_indices.shape) == 2:
|
| 512 |
+
topk = topk_indices.shape[1]
|
| 513 |
+
|
| 514 |
+
def scatter_func(tensor):
|
| 515 |
+
return MOEScatter.apply(
|
| 516 |
+
tensor,
|
| 517 |
+
torch.div(pos, topk, rounding_mode='floor'),
|
| 518 |
+
local_expert_count,
|
| 519 |
+
global_expert_count,
|
| 520 |
+
fwd_batch_size,
|
| 521 |
+
1,
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
x = tree.map_structure(scatter_func, hidden_states)
|
| 525 |
+
x = experts(x, fwd_expert_count, topk_indices=topk_indices)
|
| 526 |
+
|
| 527 |
+
out_batch_size = tree.flatten(hidden_states)[0].shape[0]
|
| 528 |
+
if len(topk_indices.shape) == 2:
|
| 529 |
+
out_batch_size *= topk_indices.shape[1]
|
| 530 |
+
|
| 531 |
+
def gather_func(tensor):
|
| 532 |
+
return MOEGather.apply(
|
| 533 |
+
tensor,
|
| 534 |
+
pos,
|
| 535 |
+
local_expert_count,
|
| 536 |
+
global_expert_count,
|
| 537 |
+
out_batch_size,
|
| 538 |
+
1,
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
outp = tree.map_structure(gather_func, x)
|
| 542 |
+
return outp
|
| 543 |
+
|
| 544 |
+
def forward(self, hidden_states, topk_indices, topk_weights, experts):
|
| 545 |
+
assert topk_indices.shape == topk_weights.shape
|
| 546 |
+
top_k = topk_indices.shape[-1]
|
| 547 |
+
dim3 = hidden_states.ndim == 3
|
| 548 |
+
if dim3:
|
| 549 |
+
batch_size, seq_len, dim = hidden_states.shape
|
| 550 |
+
hidden_states = hidden_states.view(batch_size * seq_len, dim)
|
| 551 |
+
else:
|
| 552 |
+
assert hidden_states.ndim == 2
|
| 553 |
+
batch_size, (seq_len, dim) = -1, hidden_states.shape
|
| 554 |
+
fwd = self.fmoe_sparse_topk_forward(hidden_states, topk_indices, experts)
|
| 555 |
+
|
| 556 |
+
def view_func(tensor):
|
| 557 |
+
n_dim = tensor.shape[-1]
|
| 558 |
+
tensor = tensor.view(-1, top_k, n_dim)
|
| 559 |
+
return tensor
|
| 560 |
+
|
| 561 |
+
moe_output = tree.map_structure(view_func, fwd)
|
| 562 |
+
topk_weights = topk_weights.unsqueeze(1)
|
| 563 |
+
|
| 564 |
+
def bmm_func(tensor):
|
| 565 |
+
n_dim = tensor.shape[-1]
|
| 566 |
+
tensor = torch.bmm(topk_weights, tensor).reshape(-1, n_dim)
|
| 567 |
+
return tensor
|
| 568 |
+
|
| 569 |
+
moe_output = tree.map_structure(bmm_func, moe_output)
|
| 570 |
+
if dim3:
|
| 571 |
+
moe_output = moe_output.view(batch_size, seq_len, -1)
|
| 572 |
+
return moe_output
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
class MoELinearExperts(nn.Module):
|
| 576 |
+
def __init__(
|
| 577 |
+
self,
|
| 578 |
+
dim_in: int,
|
| 579 |
+
dim_out: int,
|
| 580 |
+
num_experts: int,
|
| 581 |
+
ffn_bias: bool,
|
| 582 |
+
):
|
| 583 |
+
super().__init__()
|
| 584 |
+
self.dim_in = self.in_features = dim_in
|
| 585 |
+
self.dim_out = self.out_features = dim_out
|
| 586 |
+
self.weight = torch.nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
|
| 587 |
+
self.bias = None
|
| 588 |
+
if ffn_bias:
|
| 589 |
+
self.bias = torch.nn.Parameter(torch.empty(num_experts, dim_out))
|
| 590 |
+
|
| 591 |
+
def forward(self, x: torch.Tensor, fwd_expert_count: torch.Tensor):
|
| 592 |
+
x = MOELinear.apply(x, fwd_expert_count, self.weight, self.bias)
|
| 593 |
+
return x
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
class MoEGatedExperts(nn.Module):
|
| 597 |
+
def __init__(
|
| 598 |
+
self,
|
| 599 |
+
dim_in: int,
|
| 600 |
+
dim_ff: int,
|
| 601 |
+
is_gated: bool,
|
| 602 |
+
act_name: str,
|
| 603 |
+
num_experts: int,
|
| 604 |
+
ffn_bias: bool = False,
|
| 605 |
+
):
|
| 606 |
+
super().__init__()
|
| 607 |
+
self.is_gated = is_gated
|
| 608 |
+
self.dim_in, self.dim_ff, self.num_experts = dim_in, dim_ff, num_experts
|
| 609 |
+
if self.is_gated:
|
| 610 |
+
self.gate_proj = MoELinearExperts(dim_in, dim_ff, num_experts, ffn_bias)
|
| 611 |
+
self.up_proj = MoELinearExperts(dim_in, dim_ff, num_experts, ffn_bias)
|
| 612 |
+
self.down_proj = MoELinearExperts(dim_ff, dim_in, num_experts, ffn_bias)
|
| 613 |
+
|
| 614 |
+
self.act_fn = ACT2FN[act_name]
|
| 615 |
+
|
| 616 |
+
def forward(self, x: torch.Tensor, fwd_expert_count: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 617 |
+
if self.is_gated:
|
| 618 |
+
gate_score = self.gate_proj(x, fwd_expert_count)
|
| 619 |
+
up_proj = self.up_proj(x, fwd_expert_count)
|
| 620 |
+
x = up_proj * self.act_fn(gate_score)
|
| 621 |
+
else:
|
| 622 |
+
up_score = self.up_proj(x, fwd_expert_count)
|
| 623 |
+
x = self.act_fn(up_score)
|
| 624 |
+
x = self.down_proj(x, fwd_expert_count)
|
| 625 |
+
return x
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
class VanillaMoELayer(nn.Module):
|
| 629 |
+
def __init__(self, config: BlockFFNConfig):
|
| 630 |
+
super(VanillaMoELayer, self).__init__()
|
| 631 |
+
self.config = config
|
| 632 |
+
|
| 633 |
+
# Initialize router
|
| 634 |
+
if config.router_type == "topk":
|
| 635 |
+
self.router = TopKRouter(config=self.config)
|
| 636 |
+
elif config.router_type == "remoe":
|
| 637 |
+
self.router = ReMoERouter(config=self.config)
|
| 638 |
+
elif config.router_type == "topp":
|
| 639 |
+
self.router = TopPRouter(config=self.config)
|
| 640 |
+
else:
|
| 641 |
+
raise NotImplementedError(f"Router type {config.router_type} not implemented.")
|
| 642 |
+
|
| 643 |
+
self.mix_calculator = FastTopKCalculator(num_experts=self.config.num_experts)
|
| 644 |
+
|
| 645 |
+
# Initialize experts
|
| 646 |
+
self.experts = MoEGatedExperts(
|
| 647 |
+
dim_in=self.config.hidden_size,
|
| 648 |
+
dim_ff=self.config.moe_ffn_hidden_size,
|
| 649 |
+
is_gated=not self.config.expert_not_gated,
|
| 650 |
+
act_name="silu",
|
| 651 |
+
num_experts=self.config.num_experts,
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
self.dim_shared_expert = self.config.moe_shared_expert_intermediate_size
|
| 655 |
+
self.use_shared_expert = self.dim_shared_expert is not None and self.dim_shared_expert > 0
|
| 656 |
+
if self.use_shared_expert:
|
| 657 |
+
self.shared_experts = BlockFFNMLP(self.config, intermediate_size=self.dim_shared_expert)
|
| 658 |
+
|
| 659 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 660 |
+
top_scores, top_indices = self.router(hidden_states)
|
| 661 |
+
y = self.mix_calculator.forward(
|
| 662 |
+
hidden_states=hidden_states,
|
| 663 |
+
topk_indices=top_indices.contiguous(),
|
| 664 |
+
topk_weights=top_scores.type_as(hidden_states),
|
| 665 |
+
experts=self.experts,
|
| 666 |
+
)
|
| 667 |
+
if self.shared_experts is not None:
|
| 668 |
+
y = y + self.shared_experts(hidden_states)
|
| 669 |
+
return y
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 673 |
+
"""
|
| 674 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 675 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 676 |
+
"""
|
| 677 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 678 |
+
if n_rep == 1:
|
| 679 |
+
return hidden_states
|
| 680 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 681 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
def eager_attention_forward(
|
| 685 |
+
module: nn.Module,
|
| 686 |
+
query: torch.Tensor,
|
| 687 |
+
key: torch.Tensor,
|
| 688 |
+
value: torch.Tensor,
|
| 689 |
+
attention_mask: Optional[torch.Tensor],
|
| 690 |
+
scaling: float,
|
| 691 |
+
dropout: float = 0.0,
|
| 692 |
+
):
|
| 693 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 694 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 695 |
+
|
| 696 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 697 |
+
if attention_mask is not None:
|
| 698 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 699 |
+
attn_weights = attn_weights + causal_mask
|
| 700 |
+
|
| 701 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 702 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 703 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 704 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 705 |
+
|
| 706 |
+
return attn_output, attn_weights
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
class BlockFFNAttention(nn.Module):
|
| 710 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 711 |
+
|
| 712 |
+
def __init__(self, config: BlockFFNConfig, layer_idx: int):
|
| 713 |
+
super().__init__()
|
| 714 |
+
self.config = config
|
| 715 |
+
self.layer_idx = layer_idx
|
| 716 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 717 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_query_groups
|
| 718 |
+
self.scaling = self.head_dim**-0.5
|
| 719 |
+
self.attention_dropout = config.attention_dropout
|
| 720 |
+
self.is_causal = True
|
| 721 |
+
|
| 722 |
+
self.q_proj = nn.Linear(
|
| 723 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 724 |
+
)
|
| 725 |
+
self.k_proj = nn.Linear(
|
| 726 |
+
config.hidden_size, config.num_query_groups * self.head_dim, bias=config.attention_bias
|
| 727 |
+
)
|
| 728 |
+
self.v_proj = nn.Linear(
|
| 729 |
+
config.hidden_size, config.num_query_groups * self.head_dim, bias=config.attention_bias
|
| 730 |
+
)
|
| 731 |
+
self.o_proj = nn.Linear(
|
| 732 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
def forward(
|
| 736 |
+
self,
|
| 737 |
+
hidden_states: torch.Tensor,
|
| 738 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 739 |
+
attention_mask: Optional[torch.Tensor],
|
| 740 |
+
past_key_value: Optional[Cache] = None,
|
| 741 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 742 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 743 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 744 |
+
input_shape = hidden_states.shape[:-1]
|
| 745 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 746 |
+
|
| 747 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 748 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 749 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 750 |
+
|
| 751 |
+
cos, sin = position_embeddings
|
| 752 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 753 |
+
|
| 754 |
+
if past_key_value is not None:
|
| 755 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 756 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 757 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 758 |
+
|
| 759 |
+
attention_interface: Callable = eager_attention_forward
|
| 760 |
+
if self.config._attn_implementation != "eager":
|
| 761 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 762 |
+
|
| 763 |
+
attn_output, attn_weights = attention_interface(
|
| 764 |
+
self,
|
| 765 |
+
query_states,
|
| 766 |
+
key_states,
|
| 767 |
+
value_states,
|
| 768 |
+
attention_mask,
|
| 769 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 770 |
+
scaling=self.scaling,
|
| 771 |
+
**kwargs,
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 775 |
+
attn_output = self.o_proj(attn_output)
|
| 776 |
+
return attn_output, attn_weights
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
class BlockFFNDecoderLayer(GradientCheckpointingLayer):
|
| 780 |
+
def __init__(self, config: BlockFFNConfig, layer_idx: int, is_moe_layer: bool):
|
| 781 |
+
super().__init__()
|
| 782 |
+
self.config = config
|
| 783 |
+
self.hidden_size = config.hidden_size
|
| 784 |
+
|
| 785 |
+
self.self_attn = BlockFFNAttention(config=config, layer_idx=layer_idx)
|
| 786 |
+
|
| 787 |
+
if is_moe_layer:
|
| 788 |
+
if config.use_blockffn:
|
| 789 |
+
self.mlp = BlockFFNLayer(config)
|
| 790 |
+
elif config.router_type in ["topk", "remoe", "topp"]:
|
| 791 |
+
self.mlp = VanillaMoELayer(config)
|
| 792 |
+
else:
|
| 793 |
+
raise NotImplementedError
|
| 794 |
+
else:
|
| 795 |
+
self.mlp = BlockFFNMLP(config)
|
| 796 |
+
self.input_layernorm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
|
| 797 |
+
self.post_attention_layernorm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
|
| 798 |
+
|
| 799 |
+
def forward(
|
| 800 |
+
self,
|
| 801 |
+
hidden_states: torch.Tensor,
|
| 802 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 803 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 804 |
+
past_key_value: Optional[Cache] = None,
|
| 805 |
+
use_cache: Optional[bool] = False,
|
| 806 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 807 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 808 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 809 |
+
) -> tuple[torch.Tensor]:
|
| 810 |
+
residual = hidden_states
|
| 811 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 812 |
+
# Self Attention
|
| 813 |
+
hidden_states, _ = self.self_attn(
|
| 814 |
+
hidden_states=hidden_states,
|
| 815 |
+
attention_mask=attention_mask,
|
| 816 |
+
position_ids=position_ids,
|
| 817 |
+
past_key_value=past_key_value,
|
| 818 |
+
use_cache=use_cache,
|
| 819 |
+
cache_position=cache_position,
|
| 820 |
+
position_embeddings=position_embeddings,
|
| 821 |
+
**kwargs,
|
| 822 |
+
)
|
| 823 |
+
if self.config.use_mup:
|
| 824 |
+
hidden_states = residual + hidden_states * (self.config.mup_depth_scale / math.sqrt(self.config.num_layers))
|
| 825 |
+
else:
|
| 826 |
+
hidden_states = residual + hidden_states
|
| 827 |
+
|
| 828 |
+
# Fully Connected
|
| 829 |
+
residual = hidden_states
|
| 830 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 831 |
+
hidden_states = self.mlp(hidden_states)
|
| 832 |
+
if self.config.use_mup:
|
| 833 |
+
hidden_states = residual + hidden_states * (self.config.mup_depth_scale / math.sqrt(self.config.num_layers))
|
| 834 |
+
else:
|
| 835 |
+
hidden_states = residual + hidden_states
|
| 836 |
+
return hidden_states
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
@auto_docstring
|
| 840 |
+
class BlockFFNPreTrainedModel(PreTrainedModel):
|
| 841 |
+
config: BlockFFNConfig
|
| 842 |
+
base_model_prefix = "model"
|
| 843 |
+
supports_gradient_checkpointing = True
|
| 844 |
+
_no_split_modules = ["BlockFFNDecoderLayer"]
|
| 845 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 846 |
+
_supports_flash_attn = True
|
| 847 |
+
_supports_sdpa = True
|
| 848 |
+
_supports_flex_attn = True
|
| 849 |
+
|
| 850 |
+
_can_compile_fullgraph = True
|
| 851 |
+
_supports_attention_backend = True
|
| 852 |
+
_can_record_outputs = {
|
| 853 |
+
"hidden_states": BlockFFNDecoderLayer,
|
| 854 |
+
"attentions": BlockFFNAttention,
|
| 855 |
+
}
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
@auto_docstring
|
| 859 |
+
class BlockFFNModel(BlockFFNPreTrainedModel):
|
| 860 |
+
def __init__(self, config: BlockFFNConfig):
|
| 861 |
+
super().__init__(config)
|
| 862 |
+
self.config = config
|
| 863 |
+
self.padding_idx = config.pad_token_id
|
| 864 |
+
self.vocab_size = config.vocab_size
|
| 865 |
+
|
| 866 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 867 |
+
self.moe_layer_freq = eval(config.moe_layer_freq) if isinstance(config.moe_layer_freq, str) else config.moe_layer_freq
|
| 868 |
+
assert len(self.moe_layer_freq) == config.num_layers
|
| 869 |
+
self.layers = nn.ModuleList(
|
| 870 |
+
[BlockFFNDecoderLayer(config, layer_idx, bool(self.moe_layer_freq[layer_idx])) for layer_idx in range(config.num_layers)]
|
| 871 |
+
)
|
| 872 |
+
self.norm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
|
| 873 |
+
self.rotary_emb = BlockFFNRotaryEmbedding(config=config)
|
| 874 |
+
self.gradient_checkpointing = False
|
| 875 |
+
|
| 876 |
+
# Initialize weights and apply final processing
|
| 877 |
+
self.post_init()
|
| 878 |
+
|
| 879 |
+
@check_model_inputs
|
| 880 |
+
@auto_docstring
|
| 881 |
+
def forward(
|
| 882 |
+
self,
|
| 883 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 884 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 885 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 886 |
+
past_key_values: Optional[Cache] = None,
|
| 887 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 888 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 889 |
+
use_cache: Optional[bool] = None,
|
| 890 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 891 |
+
) -> BaseModelOutputWithPast:
|
| 892 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 893 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 894 |
+
|
| 895 |
+
if inputs_embeds is None:
|
| 896 |
+
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
|
| 897 |
+
if self.config.use_mup:
|
| 898 |
+
inputs_embeds = inputs_embeds * self.config.mup_emb_scale
|
| 899 |
+
|
| 900 |
+
if use_cache and past_key_values is None:
|
| 901 |
+
past_key_values = DynamicCache()
|
| 902 |
+
|
| 903 |
+
if cache_position is None:
|
| 904 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 905 |
+
cache_position: torch.Tensor = torch.arange(
|
| 906 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 907 |
+
)
|
| 908 |
+
|
| 909 |
+
if position_ids is None:
|
| 910 |
+
position_ids = cache_position.unsqueeze(0)
|
| 911 |
+
|
| 912 |
+
causal_mask = create_causal_mask(
|
| 913 |
+
config=self.config,
|
| 914 |
+
input_embeds=inputs_embeds,
|
| 915 |
+
attention_mask=attention_mask,
|
| 916 |
+
cache_position=cache_position,
|
| 917 |
+
past_key_values=past_key_values,
|
| 918 |
+
position_ids=position_ids,
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
hidden_states = inputs_embeds
|
| 922 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 923 |
+
|
| 924 |
+
for decoder_layer in self.layers[: self.config.num_layers]:
|
| 925 |
+
hidden_states = decoder_layer(
|
| 926 |
+
hidden_states,
|
| 927 |
+
attention_mask=causal_mask,
|
| 928 |
+
position_ids=position_ids,
|
| 929 |
+
past_key_value=past_key_values,
|
| 930 |
+
cache_position=cache_position,
|
| 931 |
+
position_embeddings=position_embeddings,
|
| 932 |
+
**kwargs,
|
| 933 |
+
)
|
| 934 |
+
|
| 935 |
+
hidden_states = self.norm(hidden_states)
|
| 936 |
+
return BaseModelOutputWithPast(
|
| 937 |
+
last_hidden_state=hidden_states,
|
| 938 |
+
past_key_values=past_key_values,
|
| 939 |
+
)
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
@auto_docstring
|
| 943 |
+
class BlockFFNForCausalLM(BlockFFNPreTrainedModel, GenerationMixin):
|
| 944 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 945 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 946 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 947 |
+
|
| 948 |
+
def __init__(self, config: BlockFFNConfig):
|
| 949 |
+
super().__init__(config)
|
| 950 |
+
self.config = config
|
| 951 |
+
self.model = BlockFFNModel(config)
|
| 952 |
+
self.vocab_size = config.vocab_size
|
| 953 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 954 |
+
|
| 955 |
+
# Initialize weights and apply final processing
|
| 956 |
+
self.post_init()
|
| 957 |
+
|
| 958 |
+
def set_decoder(self, decoder):
|
| 959 |
+
self.model = decoder
|
| 960 |
+
|
| 961 |
+
def get_decoder(self):
|
| 962 |
+
return self.model
|
| 963 |
+
|
| 964 |
+
@can_return_tuple
|
| 965 |
+
@auto_docstring
|
| 966 |
+
def forward(
|
| 967 |
+
self,
|
| 968 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 969 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 970 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 971 |
+
past_key_values: Optional[Cache] = None,
|
| 972 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 973 |
+
labels: Optional[torch.LongTensor] = None,
|
| 974 |
+
use_cache: Optional[bool] = None,
|
| 975 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 976 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 977 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 978 |
+
) -> CausalLMOutputWithPast:
|
| 979 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 980 |
+
input_ids=input_ids,
|
| 981 |
+
attention_mask=attention_mask,
|
| 982 |
+
position_ids=position_ids,
|
| 983 |
+
past_key_values=past_key_values,
|
| 984 |
+
inputs_embeds=inputs_embeds,
|
| 985 |
+
use_cache=use_cache,
|
| 986 |
+
cache_position=cache_position,
|
| 987 |
+
**kwargs,
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
hidden_states = outputs.last_hidden_state
|
| 991 |
+
if self.config.use_mup:
|
| 992 |
+
hidden_states = hidden_states / self.config.mup_width_scale
|
| 993 |
+
|
| 994 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 995 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 996 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 997 |
+
|
| 998 |
+
loss = None
|
| 999 |
+
if labels is not None:
|
| 1000 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 1001 |
+
|
| 1002 |
+
return CausalLMOutputWithPast(
|
| 1003 |
+
loss=loss,
|
| 1004 |
+
logits=logits,
|
| 1005 |
+
past_key_values=outputs.past_key_values,
|
| 1006 |
+
hidden_states=outputs.hidden_states,
|
| 1007 |
+
attentions=outputs.attentions,
|
| 1008 |
+
)
|
| 1009 |
+
|
| 1010 |
+
__all__ = [
|
| 1011 |
+
"BlockFFNForCausalLM",
|
| 1012 |
+
"BlockFFNModel",
|
| 1013 |
+
"BlockFFNPreTrainedModel",
|
| 1014 |
+
]
|
modeling_blockffn.py.bak
ADDED
|
@@ -0,0 +1,1024 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
from typing import Callable, Optional, Union
|
| 21 |
+
|
| 22 |
+
import math
|
| 23 |
+
import torch
|
| 24 |
+
from torch import nn
|
| 25 |
+
|
| 26 |
+
import tree
|
| 27 |
+
from abc import ABC, abstractmethod
|
| 28 |
+
from fmoe.linear import MOELinear
|
| 29 |
+
from fmoe.functions import prepare_forward, MOEScatter, MOEGather
|
| 30 |
+
|
| 31 |
+
from transformers.activations import ACT2FN
|
| 32 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 33 |
+
from transformers.generation import GenerationMixin
|
| 34 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 35 |
+
from transformers.masking_utils import create_causal_mask
|
| 36 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 37 |
+
from transformers.modeling_outputs import (
|
| 38 |
+
BaseModelOutputWithPast,
|
| 39 |
+
CausalLMOutputWithPast,
|
| 40 |
+
)
|
| 41 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 42 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 43 |
+
from transformers.processing_utils import Unpack
|
| 44 |
+
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
| 45 |
+
from transformers.utils.generic import check_model_inputs
|
| 46 |
+
from .configuration_blockffn import BlockFFNConfig
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 53 |
+
class BlockFFNRMSNorm(nn.Module):
|
| 54 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 57 |
+
self.variance_epsilon = eps
|
| 58 |
+
|
| 59 |
+
def forward(self, hidden_states):
|
| 60 |
+
input_dtype = hidden_states.dtype
|
| 61 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 62 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 63 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 64 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 65 |
+
|
| 66 |
+
def extra_repr(self):
|
| 67 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class BlockFFNRotaryEmbedding(nn.Module):
|
| 71 |
+
def __init__(self, config: BlockFFNConfig, device=None):
|
| 72 |
+
super().__init__()
|
| 73 |
+
# BC: "rope_type" was originally "type"
|
| 74 |
+
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 75 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 76 |
+
else:
|
| 77 |
+
self.rope_type = "default"
|
| 78 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 79 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 80 |
+
|
| 81 |
+
self.config = config
|
| 82 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 83 |
+
|
| 84 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 85 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 86 |
+
self.original_inv_freq = self.inv_freq
|
| 87 |
+
|
| 88 |
+
@torch.no_grad()
|
| 89 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 90 |
+
def forward(self, x, position_ids):
|
| 91 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 92 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 93 |
+
|
| 94 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 95 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 96 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 97 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 98 |
+
cos = emb.cos() * self.attention_scaling
|
| 99 |
+
sin = emb.sin() * self.attention_scaling
|
| 100 |
+
|
| 101 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def rotate_half(x):
|
| 105 |
+
"""Rotates half the hidden dims of the input."""
|
| 106 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 107 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 108 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 112 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
q (`torch.Tensor`): The query tensor.
|
| 116 |
+
k (`torch.Tensor`): The key tensor.
|
| 117 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 118 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 119 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 120 |
+
Deprecated and unused.
|
| 121 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 122 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 123 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 124 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 125 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 126 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 127 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 128 |
+
Returns:
|
| 129 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 130 |
+
"""
|
| 131 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 132 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 133 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 134 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 135 |
+
return q_embed, k_embed
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class SimpleLayerNorm(nn.Module):
|
| 139 |
+
def __init__(self, dim_norm: int):
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.dim_norm = dim_norm
|
| 142 |
+
self.weight = torch.nn.Parameter(torch.empty(self.dim_norm))
|
| 143 |
+
|
| 144 |
+
@torch.compile
|
| 145 |
+
def forward(self, x: torch.Tensor):
|
| 146 |
+
return x * self.weight
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class BlockFFNMLP(nn.Module):
|
| 150 |
+
def __init__(self, config: BlockFFNConfig, intermediate_size: int = None):
|
| 151 |
+
super().__init__()
|
| 152 |
+
self.config = config
|
| 153 |
+
self.hidden_size = config.hidden_size
|
| 154 |
+
self.intermediate_size = config.ffn_hidden_size if intermediate_size is None else intermediate_size
|
| 155 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
| 156 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
| 157 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
| 158 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 159 |
+
|
| 160 |
+
def forward(self, x):
|
| 161 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 162 |
+
return down_proj
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class BlockFFNRouter(nn.Module):
|
| 166 |
+
def __init__(self, config: BlockFFNConfig):
|
| 167 |
+
super().__init__()
|
| 168 |
+
self.config = config
|
| 169 |
+
self.num_experts = self.config.num_experts
|
| 170 |
+
|
| 171 |
+
if self.config.moe_router_dtype == "fp32":
|
| 172 |
+
self.router_dtype = torch.float32
|
| 173 |
+
elif self.config.moe_router_dtype == "fp64":
|
| 174 |
+
self.router_dtype = torch.float64
|
| 175 |
+
elif self.config.moe_router_dtype == "bf16":
|
| 176 |
+
self.router_dtype = torch.bfloat16
|
| 177 |
+
else:
|
| 178 |
+
raise NotImplementedError(f"{self.config.moe_router_dtype} is not supported.")
|
| 179 |
+
|
| 180 |
+
self.weight = torch.nn.Parameter(
|
| 181 |
+
torch.empty((self.config.num_experts, self.config.hidden_size), dtype=self.router_dtype)
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def forward(self, x: torch.Tensor):
|
| 185 |
+
return nn.functional.linear(x.to(self.router_dtype), self.weight)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class NormSiLU(nn.Module):
|
| 189 |
+
def __init__(self, config: BlockFFNConfig):
|
| 190 |
+
super().__init__()
|
| 191 |
+
self.num_blocks, self.block_size = config.num_experts, config.moe_ffn_hidden_size
|
| 192 |
+
self.activate_fn_type = config.expert_act_func
|
| 193 |
+
assert self.activate_fn_type in ["norm_silu", "norm_silu_norms", "norm_silu_nomean"]
|
| 194 |
+
|
| 195 |
+
self.rms_norm = None
|
| 196 |
+
if self.activate_fn_type != "norm_silu_norms":
|
| 197 |
+
self.rms_norm = BlockFFNRMSNorm(config.moe_ffn_hidden_size, eps=config.norm_epsilon)
|
| 198 |
+
self.silu = torch.nn.SiLU()
|
| 199 |
+
|
| 200 |
+
@torch.compile
|
| 201 |
+
def forward(self, hidden: torch.Tensor) -> torch.Tensor:
|
| 202 |
+
assert hidden.ndim == 2
|
| 203 |
+
if self.activate_fn_type != "norm_silu_nomean":
|
| 204 |
+
hidden = hidden - torch.mean(hidden, dim=-1, keepdim=True)
|
| 205 |
+
if self.activate_fn_type != "norm_silu_norms":
|
| 206 |
+
return self.silu(self.rms_norm(hidden.view(hidden.shape[0], self.num_blocks, self.block_size)))
|
| 207 |
+
else:
|
| 208 |
+
return self.silu(hidden)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class BlockFFNLayer(nn.Module):
|
| 212 |
+
def __init__(self, config: BlockFFNConfig):
|
| 213 |
+
super(BlockFFNLayer, self).__init__()
|
| 214 |
+
self.config = config
|
| 215 |
+
self.num_experts, self.dim_expert, self.hidden_size = \
|
| 216 |
+
config.num_experts, config.moe_ffn_hidden_size, config.hidden_size
|
| 217 |
+
self.dim_shared_expert = config.moe_shared_expert_intermediate_size
|
| 218 |
+
self.router_norm_type = config.router_norm_type
|
| 219 |
+
|
| 220 |
+
self.moe_router = BlockFFNRouter(self.config)
|
| 221 |
+
assert config.router_act_func == "relu"
|
| 222 |
+
self.router_act = nn.ReLU()
|
| 223 |
+
if config.router_norm_type == "simple":
|
| 224 |
+
self.router_norm = SimpleLayerNorm(self.config.num_experts)
|
| 225 |
+
elif config.router_norm_type == "rms":
|
| 226 |
+
self.router_norm = BlockFFNRMSNorm(self.config.num_experts, eps=config.norm_epsilon)
|
| 227 |
+
else:
|
| 228 |
+
raise NotImplementedError
|
| 229 |
+
|
| 230 |
+
self.expert_gated = not config.expert_not_gated
|
| 231 |
+
if self.expert_gated:
|
| 232 |
+
self.expert_gate_proj = nn.Linear(self.hidden_size, self.num_experts * self.dim_expert, bias=config.mlp_bias)
|
| 233 |
+
|
| 234 |
+
self.expert_up_proj = nn.Linear(self.hidden_size, self.num_experts * self.dim_expert, bias=config.mlp_bias)
|
| 235 |
+
assert config.expert_act_norm_type == "normal"
|
| 236 |
+
if config.expert_act_func == "norm_silu":
|
| 237 |
+
self.expert_act = NormSiLU(self.config)
|
| 238 |
+
elif config.expert_act_func == "silu":
|
| 239 |
+
self.expert_act = nn.SiLU()
|
| 240 |
+
else:
|
| 241 |
+
raise NotImplementedError
|
| 242 |
+
self.expert_down_proj = nn.Linear(self.num_experts * self.dim_expert, self.hidden_size, bias=config.mlp_bias)
|
| 243 |
+
|
| 244 |
+
self.use_shared_expert = self.dim_shared_expert is not None and self.dim_shared_expert > 0
|
| 245 |
+
if self.use_shared_expert:
|
| 246 |
+
self.shared_experts = BlockFFNMLP(self.config, intermediate_size=self.dim_shared_expert)
|
| 247 |
+
|
| 248 |
+
self.expert_wise_scales = []
|
| 249 |
+
|
| 250 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 251 |
+
ori_shape = hidden_states.shape
|
| 252 |
+
hidden_states = hidden_states.view(-1, self.hidden_size)
|
| 253 |
+
seq_len = hidden_states.shape[0]
|
| 254 |
+
|
| 255 |
+
# router module forward
|
| 256 |
+
raw_router_score = self.moe_router(hidden_states) # [seq_len, num_experts]
|
| 257 |
+
router_score = self.router_act(raw_router_score)
|
| 258 |
+
router_score = self.router_norm(router_score)
|
| 259 |
+
|
| 260 |
+
# expert module forward
|
| 261 |
+
x_in = self.expert_up_proj(hidden_states) # [seq_len, num_experts * dim_expert]
|
| 262 |
+
ori_x_in = x_in
|
| 263 |
+
if self.expert_gated:
|
| 264 |
+
x_gate = self.expert_gate_proj(hidden_states)
|
| 265 |
+
x_in = x_in * self.expert_act(x_gate)
|
| 266 |
+
else:
|
| 267 |
+
x_in = self.expert_act(x_in)
|
| 268 |
+
if x_in.ndim == 3:
|
| 269 |
+
scored_x_in = x_in * router_score.type_as(hidden_states).unsqueeze(-1)
|
| 270 |
+
else:
|
| 271 |
+
scored_x_in = x_in.view(seq_len, self.num_experts, self.dim_expert) * router_score.type_as(hidden_states).unsqueeze(-1)
|
| 272 |
+
output = self.expert_down_proj(scored_x_in.view(seq_len, self.num_experts * self.dim_expert))
|
| 273 |
+
|
| 274 |
+
with torch.no_grad():
|
| 275 |
+
ori_x_in = ori_x_in.view(seq_len, self.num_experts, self.dim_expert)
|
| 276 |
+
down_proj_weight = self.expert_down_proj.weight.view(self.hidden_size, self.num_experts, self.dim_expert)
|
| 277 |
+
expert_wise_outputs = torch.einsum("sed,hed->seh", ori_x_in, down_proj_weight).transpose(0, 1).reshape(self.num_experts, seq_len * self.hidden_size)
|
| 278 |
+
expert_wise_scale = torch.norm(expert_wise_outputs, p=2, dim=1) / seq_len
|
| 279 |
+
self.expert_wise_scales.append(expert_wise_scale.tolist())
|
| 280 |
+
|
| 281 |
+
if self.use_shared_expert:
|
| 282 |
+
output = output + self.shared_experts(hidden_states)
|
| 283 |
+
return output.view(*ori_shape)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class BaseRouter(ABC, nn.Module):
|
| 287 |
+
"""Base Router class"""
|
| 288 |
+
def __init__(self, config: BlockFFNConfig) -> None:
|
| 289 |
+
super().__init__()
|
| 290 |
+
self.config = config
|
| 291 |
+
self.num_experts = self.config.num_experts
|
| 292 |
+
|
| 293 |
+
if self.config.moe_router_dtype == "fp32":
|
| 294 |
+
self.router_dtype = torch.float32
|
| 295 |
+
elif self.config.moe_router_dtype == "fp64":
|
| 296 |
+
self.router_dtype = torch.float64
|
| 297 |
+
elif self.config.moe_router_dtype == "bf16":
|
| 298 |
+
self.router_dtype = torch.bfloat16
|
| 299 |
+
else:
|
| 300 |
+
raise NotImplementedError(f"{self.config.moe_router_dtype} is not supported.")
|
| 301 |
+
|
| 302 |
+
self.weight = torch.nn.Parameter(
|
| 303 |
+
torch.empty((self.num_experts, self.config.hidden_size), dtype=self.router_dtype)
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
def gating(self, input: torch.Tensor):
|
| 307 |
+
return torch.nn.functional.linear(input.to(self.router_dtype), self.weight.to(self.router_dtype))
|
| 308 |
+
|
| 309 |
+
@abstractmethod
|
| 310 |
+
def routing(self, logits: torch.Tensor):
|
| 311 |
+
"""Routing function.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
logits (torch.Tensor): Logits tensor.
|
| 315 |
+
|
| 316 |
+
Returns:
|
| 317 |
+
Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment
|
| 318 |
+
probabilities and mapping.
|
| 319 |
+
"""
|
| 320 |
+
raise NotImplementedError("Routing function not implemented.")
|
| 321 |
+
|
| 322 |
+
@abstractmethod
|
| 323 |
+
def forward(self, input: torch.Tensor):
|
| 324 |
+
"""
|
| 325 |
+
Forward pass of the router.
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
input (torch.Tensor): Input tensor.
|
| 329 |
+
"""
|
| 330 |
+
raise NotImplementedError("Forward function not implemented.")
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class TopKRouter(BaseRouter):
|
| 334 |
+
"""Route each token to the top-k experts."""
|
| 335 |
+
|
| 336 |
+
def __init__(self, config: BlockFFNConfig) -> None:
|
| 337 |
+
super().__init__(config)
|
| 338 |
+
self.config = config
|
| 339 |
+
self.topk = self.config.moe_router_topk
|
| 340 |
+
self.score_function = self.config.moe_router_score_function
|
| 341 |
+
self.use_pre_softmax = self.config.moe_router_pre_softmax
|
| 342 |
+
self.scaling_factor = self.config.moe_router_topk_scaling_factor
|
| 343 |
+
|
| 344 |
+
self.enable_expert_bias = self.config.moe_router_enable_expert_bias
|
| 345 |
+
if self.enable_expert_bias:
|
| 346 |
+
self.expert_bias = torch.nn.Parameter(torch.zeros(self.num_experts, dtype=torch.float32))
|
| 347 |
+
else:
|
| 348 |
+
self.expert_bias = None
|
| 349 |
+
|
| 350 |
+
def _maintain_float32_expert_bias(self):
|
| 351 |
+
"""
|
| 352 |
+
Maintain the expert bias in float32.
|
| 353 |
+
|
| 354 |
+
When using bf16/fp16, the expert bias gets converted to lower precision in Float16Module.
|
| 355 |
+
We keep it in float32 to avoid routing errors when updating the expert_bias.
|
| 356 |
+
"""
|
| 357 |
+
if hasattr(self, 'expert_bias') and self.expert_bias is not None:
|
| 358 |
+
if self.expert_bias.dtype != torch.float32:
|
| 359 |
+
self.expert_bias.data = self.expert_bias.data.to(torch.float32)
|
| 360 |
+
|
| 361 |
+
def routing(self, logits: torch.Tensor):
|
| 362 |
+
"""Top-k routing function
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
logits (torch.Tensor): Logits tensor after gating.
|
| 366 |
+
|
| 367 |
+
Returns:
|
| 368 |
+
probs (torch.Tensor): The probabilities of token to experts assignment.
|
| 369 |
+
routing_map (torch.Tensor): The mapping of token to experts assignment,
|
| 370 |
+
with shape [num_tokens, num_experts].
|
| 371 |
+
"""
|
| 372 |
+
logits = logits.view(-1, self.num_experts)
|
| 373 |
+
|
| 374 |
+
if self.score_function == "softmax":
|
| 375 |
+
if self.use_pre_softmax:
|
| 376 |
+
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
|
| 377 |
+
probs, top_indices = torch.topk(scores, k=self.topk, dim=1)
|
| 378 |
+
else:
|
| 379 |
+
scores, top_indices = torch.topk(logits, k=self.topk, dim=1)
|
| 380 |
+
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
|
| 381 |
+
elif self.score_function == "sigmoid":
|
| 382 |
+
scores = torch.sigmoid(logits.float()).type_as(logits)
|
| 383 |
+
if self.expert_bias is not None:
|
| 384 |
+
scores_for_routing = scores + self.expert_bias
|
| 385 |
+
_, top_indices = torch.topk(scores_for_routing, k=self.topk, dim=1)
|
| 386 |
+
scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
|
| 387 |
+
else:
|
| 388 |
+
scores, top_indices = torch.topk(scores, k=self.topk, dim=1)
|
| 389 |
+
probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.topk > 1 else scores
|
| 390 |
+
else:
|
| 391 |
+
raise ValueError(f"Invalid score_function: {self.score_function}")
|
| 392 |
+
|
| 393 |
+
if self.scaling_factor:
|
| 394 |
+
probs = probs * self.scaling_factor
|
| 395 |
+
|
| 396 |
+
return probs, top_indices
|
| 397 |
+
|
| 398 |
+
def forward(self, input: torch.Tensor):
|
| 399 |
+
"""
|
| 400 |
+
Forward pass of the router.
|
| 401 |
+
|
| 402 |
+
Args:
|
| 403 |
+
input (torch.Tensor): Input tensor.
|
| 404 |
+
"""
|
| 405 |
+
self._maintain_float32_expert_bias()
|
| 406 |
+
logits = self.gating(input)
|
| 407 |
+
top_scores, top_indices = self.routing(logits)
|
| 408 |
+
return top_scores, top_indices
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
class ReMoERouter(BaseRouter):
|
| 412 |
+
def __init__(self, config: BlockFFNConfig) -> None:
|
| 413 |
+
super().__init__(config)
|
| 414 |
+
self.config = config
|
| 415 |
+
self.router_act = torch.nn.ReLU()
|
| 416 |
+
|
| 417 |
+
def routing(self, logits: torch.Tensor):
|
| 418 |
+
"""Top-k routing function
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
logits (torch.Tensor): Logits tensor after gating.
|
| 422 |
+
|
| 423 |
+
Returns:
|
| 424 |
+
probs (torch.Tensor): The probabilities of token to experts assignment.
|
| 425 |
+
routing_map (torch.Tensor): The mapping of token to experts assignment,
|
| 426 |
+
with shape [num_tokens, num_experts].
|
| 427 |
+
"""
|
| 428 |
+
logits = logits.view(-1, self.num_experts)
|
| 429 |
+
|
| 430 |
+
router_score = self.router_act(logits)
|
| 431 |
+
routing_map = router_score > 0
|
| 432 |
+
|
| 433 |
+
sorted_probs, sorted_indices = torch.sort(router_score, descending=True, dim=-1)
|
| 434 |
+
sorted_map = sorted_probs <= 0
|
| 435 |
+
sorted_indices = torch.where(sorted_map, -1, sorted_indices)
|
| 436 |
+
max_valid_num = max(sorted_probs.size(-1) - torch.min(torch.sum(sorted_map, dim=-1)).item(), 1)
|
| 437 |
+
assert torch.all(sorted_map[:, max_valid_num:])
|
| 438 |
+
sorted_probs = sorted_probs[:, :max_valid_num]
|
| 439 |
+
sorted_indices = sorted_indices[:, :max_valid_num]
|
| 440 |
+
assert torch.sum(routing_map) == torch.sum(sorted_indices != -1)
|
| 441 |
+
return sorted_probs, sorted_indices
|
| 442 |
+
|
| 443 |
+
def forward(self, input: torch.Tensor):
|
| 444 |
+
"""
|
| 445 |
+
Forward pass of the router.
|
| 446 |
+
|
| 447 |
+
Args:
|
| 448 |
+
input (torch.Tensor): Input tensor.
|
| 449 |
+
"""
|
| 450 |
+
logits = self.gating(input)
|
| 451 |
+
top_scores, top_indices = self.routing(logits)
|
| 452 |
+
return top_scores, top_indices
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
class TopPRouter(BaseRouter):
|
| 456 |
+
def __init__(self, config: BlockFFNConfig) -> None:
|
| 457 |
+
super().__init__(config)
|
| 458 |
+
self.config = config
|
| 459 |
+
self.top_p = config.moe_router_topp
|
| 460 |
+
|
| 461 |
+
def routing(self, logits: torch.Tensor):
|
| 462 |
+
"""Top-k routing function
|
| 463 |
+
|
| 464 |
+
Args:
|
| 465 |
+
logits (torch.Tensor): Logits tensor after gating.
|
| 466 |
+
|
| 467 |
+
Returns:
|
| 468 |
+
probs (torch.Tensor): The probabilities of token to experts assignment.
|
| 469 |
+
routing_map (torch.Tensor): The mapping of token to experts assignment,
|
| 470 |
+
with shape [num_tokens, num_experts].
|
| 471 |
+
"""
|
| 472 |
+
logits = logits.view(-1, self.num_experts)
|
| 473 |
+
|
| 474 |
+
router_score = torch.abs(logits)
|
| 475 |
+
router_score = router_score / (router_score.sum(dim=-1, keepdim=True) + 1e-20)
|
| 476 |
+
|
| 477 |
+
sorted_probs, sorted_indices = torch.sort(router_score, descending=True, dim=-1)
|
| 478 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
|
| 479 |
+
mask = cumulative_probs > self.top_p
|
| 480 |
+
|
| 481 |
+
threshold_indices = mask.long().argmax(dim=-1)
|
| 482 |
+
threshold_mask = torch.nn.functional.one_hot(threshold_indices, num_classes=sorted_indices.size(-1)).bool()
|
| 483 |
+
|
| 484 |
+
mask = mask & ~threshold_mask
|
| 485 |
+
sorted_indices = torch.where(mask, -1, sorted_indices)
|
| 486 |
+
sorted_probs = torch.where(mask, 0.0, sorted_probs)
|
| 487 |
+
|
| 488 |
+
max_valid_num = max(mask.size(-1) - torch.min(torch.sum(mask, dim=-1)).item(), 1)
|
| 489 |
+
assert torch.all(mask[:, max_valid_num:])
|
| 490 |
+
|
| 491 |
+
sorted_indices = sorted_indices[:, :max_valid_num]
|
| 492 |
+
sorted_probs = sorted_probs[:, :max_valid_num]
|
| 493 |
+
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
|
| 494 |
+
return sorted_probs, sorted_indices
|
| 495 |
+
|
| 496 |
+
def forward(self, input: torch.Tensor):
|
| 497 |
+
"""
|
| 498 |
+
Forward pass of the router.
|
| 499 |
+
|
| 500 |
+
Args:
|
| 501 |
+
input (torch.Tensor): Input tensor.
|
| 502 |
+
"""
|
| 503 |
+
logits = self.gating(input)
|
| 504 |
+
top_scores, top_indices = self.routing(logits)
|
| 505 |
+
return top_scores, top_indices
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
class FastTopKCalculator:
|
| 509 |
+
def __init__(self, num_experts: int):
|
| 510 |
+
self.num_experts = num_experts
|
| 511 |
+
|
| 512 |
+
def fmoe_sparse_topk_forward(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, experts: torch.nn.Module):
|
| 513 |
+
(
|
| 514 |
+
pos,
|
| 515 |
+
local_expert_count,
|
| 516 |
+
global_expert_count,
|
| 517 |
+
fwd_expert_count,
|
| 518 |
+
fwd_batch_size,
|
| 519 |
+
) = prepare_forward(topk_indices, self.num_experts, 1)
|
| 520 |
+
topk = 1
|
| 521 |
+
if len(topk_indices.shape) == 2:
|
| 522 |
+
topk = topk_indices.shape[1]
|
| 523 |
+
|
| 524 |
+
def scatter_func(tensor):
|
| 525 |
+
return MOEScatter.apply(
|
| 526 |
+
tensor,
|
| 527 |
+
torch.div(pos, topk, rounding_mode='floor'),
|
| 528 |
+
local_expert_count,
|
| 529 |
+
global_expert_count,
|
| 530 |
+
fwd_batch_size,
|
| 531 |
+
1,
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
x = tree.map_structure(scatter_func, hidden_states)
|
| 535 |
+
x = experts(x, fwd_expert_count, topk_indices=topk_indices)
|
| 536 |
+
|
| 537 |
+
out_batch_size = tree.flatten(hidden_states)[0].shape[0]
|
| 538 |
+
if len(topk_indices.shape) == 2:
|
| 539 |
+
out_batch_size *= topk_indices.shape[1]
|
| 540 |
+
|
| 541 |
+
def gather_func(tensor):
|
| 542 |
+
return MOEGather.apply(
|
| 543 |
+
tensor,
|
| 544 |
+
pos,
|
| 545 |
+
local_expert_count,
|
| 546 |
+
global_expert_count,
|
| 547 |
+
out_batch_size,
|
| 548 |
+
1,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
outp = tree.map_structure(gather_func, x)
|
| 552 |
+
return outp
|
| 553 |
+
|
| 554 |
+
def forward(self, hidden_states, topk_indices, topk_weights, experts):
|
| 555 |
+
assert topk_indices.shape == topk_weights.shape
|
| 556 |
+
top_k = topk_indices.shape[-1]
|
| 557 |
+
dim3 = hidden_states.ndim == 3
|
| 558 |
+
if dim3:
|
| 559 |
+
batch_size, seq_len, dim = hidden_states.shape
|
| 560 |
+
hidden_states = hidden_states.view(batch_size * seq_len, dim)
|
| 561 |
+
else:
|
| 562 |
+
assert hidden_states.ndim == 2
|
| 563 |
+
batch_size, (seq_len, dim) = -1, hidden_states.shape
|
| 564 |
+
fwd = self.fmoe_sparse_topk_forward(hidden_states, topk_indices, experts)
|
| 565 |
+
|
| 566 |
+
def view_func(tensor):
|
| 567 |
+
n_dim = tensor.shape[-1]
|
| 568 |
+
tensor = tensor.view(-1, top_k, n_dim)
|
| 569 |
+
return tensor
|
| 570 |
+
|
| 571 |
+
moe_output = tree.map_structure(view_func, fwd)
|
| 572 |
+
topk_weights = topk_weights.unsqueeze(1)
|
| 573 |
+
|
| 574 |
+
def bmm_func(tensor):
|
| 575 |
+
n_dim = tensor.shape[-1]
|
| 576 |
+
tensor = torch.bmm(topk_weights, tensor).reshape(-1, n_dim)
|
| 577 |
+
return tensor
|
| 578 |
+
|
| 579 |
+
moe_output = tree.map_structure(bmm_func, moe_output)
|
| 580 |
+
if dim3:
|
| 581 |
+
moe_output = moe_output.view(batch_size, seq_len, -1)
|
| 582 |
+
return moe_output
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
class MoELinearExperts(nn.Module):
|
| 586 |
+
def __init__(
|
| 587 |
+
self,
|
| 588 |
+
dim_in: int,
|
| 589 |
+
dim_out: int,
|
| 590 |
+
num_experts: int,
|
| 591 |
+
ffn_bias: bool,
|
| 592 |
+
):
|
| 593 |
+
super().__init__()
|
| 594 |
+
self.dim_in = self.in_features = dim_in
|
| 595 |
+
self.dim_out = self.out_features = dim_out
|
| 596 |
+
self.weight = torch.nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
|
| 597 |
+
self.bias = None
|
| 598 |
+
if ffn_bias:
|
| 599 |
+
self.bias = torch.nn.Parameter(torch.empty(num_experts, dim_out))
|
| 600 |
+
|
| 601 |
+
def forward(self, x: torch.Tensor, fwd_expert_count: torch.Tensor):
|
| 602 |
+
x = MOELinear.apply(x, fwd_expert_count, self.weight, self.bias)
|
| 603 |
+
return x
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
class MoEGatedExperts(nn.Module):
|
| 607 |
+
def __init__(
|
| 608 |
+
self,
|
| 609 |
+
dim_in: int,
|
| 610 |
+
dim_ff: int,
|
| 611 |
+
is_gated: bool,
|
| 612 |
+
act_name: str,
|
| 613 |
+
num_experts: int,
|
| 614 |
+
ffn_bias: bool = False,
|
| 615 |
+
):
|
| 616 |
+
super().__init__()
|
| 617 |
+
self.is_gated = is_gated
|
| 618 |
+
self.dim_in, self.dim_ff, self.num_experts = dim_in, dim_ff, num_experts
|
| 619 |
+
if self.is_gated:
|
| 620 |
+
self.gate_proj = MoELinearExperts(dim_in, dim_ff, num_experts, ffn_bias)
|
| 621 |
+
self.up_proj = MoELinearExperts(dim_in, dim_ff, num_experts, ffn_bias)
|
| 622 |
+
self.down_proj = MoELinearExperts(dim_ff, dim_in, num_experts, ffn_bias)
|
| 623 |
+
|
| 624 |
+
self.act_fn = ACT2FN[act_name]
|
| 625 |
+
|
| 626 |
+
def forward(self, x: torch.Tensor, fwd_expert_count: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 627 |
+
if self.is_gated:
|
| 628 |
+
gate_score = self.gate_proj(x, fwd_expert_count)
|
| 629 |
+
up_proj = self.up_proj(x, fwd_expert_count)
|
| 630 |
+
x = up_proj * self.act_fn(gate_score)
|
| 631 |
+
else:
|
| 632 |
+
up_score = self.up_proj(x, fwd_expert_count)
|
| 633 |
+
x = self.act_fn(up_score)
|
| 634 |
+
x = self.down_proj(x, fwd_expert_count)
|
| 635 |
+
return x
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
class VanillaMoELayer(nn.Module):
|
| 639 |
+
def __init__(self, config: BlockFFNConfig):
|
| 640 |
+
super(VanillaMoELayer, self).__init__()
|
| 641 |
+
self.config = config
|
| 642 |
+
|
| 643 |
+
# Initialize router
|
| 644 |
+
if config.router_type == "topk":
|
| 645 |
+
self.router = TopKRouter(config=self.config)
|
| 646 |
+
elif config.router_type == "remoe":
|
| 647 |
+
self.router = ReMoERouter(config=self.config)
|
| 648 |
+
elif config.router_type == "topp":
|
| 649 |
+
self.router = TopPRouter(config=self.config)
|
| 650 |
+
else:
|
| 651 |
+
raise NotImplementedError(f"Router type {config.router_type} not implemented.")
|
| 652 |
+
|
| 653 |
+
self.mix_calculator = FastTopKCalculator(num_experts=self.config.num_experts)
|
| 654 |
+
|
| 655 |
+
# Initialize experts
|
| 656 |
+
self.experts = MoEGatedExperts(
|
| 657 |
+
dim_in=self.config.hidden_size,
|
| 658 |
+
dim_ff=self.config.moe_ffn_hidden_size,
|
| 659 |
+
is_gated=not self.config.expert_not_gated,
|
| 660 |
+
act_name="silu",
|
| 661 |
+
num_experts=self.config.num_experts,
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
self.dim_shared_expert = self.config.moe_shared_expert_intermediate_size
|
| 665 |
+
self.use_shared_expert = self.dim_shared_expert is not None and self.dim_shared_expert > 0
|
| 666 |
+
if self.use_shared_expert:
|
| 667 |
+
self.shared_experts = BlockFFNMLP(self.config, intermediate_size=self.dim_shared_expert)
|
| 668 |
+
|
| 669 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 670 |
+
top_scores, top_indices = self.router(hidden_states)
|
| 671 |
+
y = self.mix_calculator.forward(
|
| 672 |
+
hidden_states=hidden_states,
|
| 673 |
+
topk_indices=top_indices.contiguous(),
|
| 674 |
+
topk_weights=top_scores.type_as(hidden_states),
|
| 675 |
+
experts=self.experts,
|
| 676 |
+
)
|
| 677 |
+
if self.shared_experts is not None:
|
| 678 |
+
y = y + self.shared_experts(hidden_states)
|
| 679 |
+
return y
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 683 |
+
"""
|
| 684 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 685 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 686 |
+
"""
|
| 687 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 688 |
+
if n_rep == 1:
|
| 689 |
+
return hidden_states
|
| 690 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 691 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
def eager_attention_forward(
|
| 695 |
+
module: nn.Module,
|
| 696 |
+
query: torch.Tensor,
|
| 697 |
+
key: torch.Tensor,
|
| 698 |
+
value: torch.Tensor,
|
| 699 |
+
attention_mask: Optional[torch.Tensor],
|
| 700 |
+
scaling: float,
|
| 701 |
+
dropout: float = 0.0,
|
| 702 |
+
):
|
| 703 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 704 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 705 |
+
|
| 706 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 707 |
+
if attention_mask is not None:
|
| 708 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 709 |
+
attn_weights = attn_weights + causal_mask
|
| 710 |
+
|
| 711 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 712 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 713 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 714 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 715 |
+
|
| 716 |
+
return attn_output, attn_weights
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
class BlockFFNAttention(nn.Module):
|
| 720 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 721 |
+
|
| 722 |
+
def __init__(self, config: BlockFFNConfig, layer_idx: int):
|
| 723 |
+
super().__init__()
|
| 724 |
+
self.config = config
|
| 725 |
+
self.layer_idx = layer_idx
|
| 726 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 727 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_query_groups
|
| 728 |
+
self.scaling = self.head_dim**-0.5
|
| 729 |
+
self.attention_dropout = config.attention_dropout
|
| 730 |
+
self.is_causal = True
|
| 731 |
+
|
| 732 |
+
self.q_proj = nn.Linear(
|
| 733 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 734 |
+
)
|
| 735 |
+
self.k_proj = nn.Linear(
|
| 736 |
+
config.hidden_size, config.num_query_groups * self.head_dim, bias=config.attention_bias
|
| 737 |
+
)
|
| 738 |
+
self.v_proj = nn.Linear(
|
| 739 |
+
config.hidden_size, config.num_query_groups * self.head_dim, bias=config.attention_bias
|
| 740 |
+
)
|
| 741 |
+
self.o_proj = nn.Linear(
|
| 742 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
def forward(
|
| 746 |
+
self,
|
| 747 |
+
hidden_states: torch.Tensor,
|
| 748 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 749 |
+
attention_mask: Optional[torch.Tensor],
|
| 750 |
+
past_key_value: Optional[Cache] = None,
|
| 751 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 752 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 753 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 754 |
+
input_shape = hidden_states.shape[:-1]
|
| 755 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 756 |
+
|
| 757 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 758 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 759 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 760 |
+
|
| 761 |
+
cos, sin = position_embeddings
|
| 762 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 763 |
+
|
| 764 |
+
if past_key_value is not None:
|
| 765 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 766 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 767 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 768 |
+
|
| 769 |
+
attention_interface: Callable = eager_attention_forward
|
| 770 |
+
if self.config._attn_implementation != "eager":
|
| 771 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 772 |
+
|
| 773 |
+
attn_output, attn_weights = attention_interface(
|
| 774 |
+
self,
|
| 775 |
+
query_states,
|
| 776 |
+
key_states,
|
| 777 |
+
value_states,
|
| 778 |
+
attention_mask,
|
| 779 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 780 |
+
scaling=self.scaling,
|
| 781 |
+
**kwargs,
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 785 |
+
attn_output = self.o_proj(attn_output)
|
| 786 |
+
return attn_output, attn_weights
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
class BlockFFNDecoderLayer(GradientCheckpointingLayer):
|
| 790 |
+
def __init__(self, config: BlockFFNConfig, layer_idx: int, is_moe_layer: bool):
|
| 791 |
+
super().__init__()
|
| 792 |
+
self.config = config
|
| 793 |
+
self.hidden_size = config.hidden_size
|
| 794 |
+
|
| 795 |
+
self.self_attn = BlockFFNAttention(config=config, layer_idx=layer_idx)
|
| 796 |
+
|
| 797 |
+
if is_moe_layer:
|
| 798 |
+
if config.use_blockffn:
|
| 799 |
+
self.mlp = BlockFFNLayer(config)
|
| 800 |
+
elif config.router_type in ["topk", "remoe", "topp"]:
|
| 801 |
+
self.mlp = VanillaMoELayer(config)
|
| 802 |
+
else:
|
| 803 |
+
raise NotImplementedError
|
| 804 |
+
else:
|
| 805 |
+
self.mlp = BlockFFNMLP(config)
|
| 806 |
+
self.input_layernorm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
|
| 807 |
+
self.post_attention_layernorm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
|
| 808 |
+
|
| 809 |
+
def forward(
|
| 810 |
+
self,
|
| 811 |
+
hidden_states: torch.Tensor,
|
| 812 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 813 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 814 |
+
past_key_value: Optional[Cache] = None,
|
| 815 |
+
use_cache: Optional[bool] = False,
|
| 816 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 817 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 818 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 819 |
+
) -> tuple[torch.Tensor]:
|
| 820 |
+
residual = hidden_states
|
| 821 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 822 |
+
# Self Attention
|
| 823 |
+
hidden_states, _ = self.self_attn(
|
| 824 |
+
hidden_states=hidden_states,
|
| 825 |
+
attention_mask=attention_mask,
|
| 826 |
+
position_ids=position_ids,
|
| 827 |
+
past_key_value=past_key_value,
|
| 828 |
+
use_cache=use_cache,
|
| 829 |
+
cache_position=cache_position,
|
| 830 |
+
position_embeddings=position_embeddings,
|
| 831 |
+
**kwargs,
|
| 832 |
+
)
|
| 833 |
+
if self.config.use_mup:
|
| 834 |
+
hidden_states = residual + hidden_states * (self.config.mup_depth_scale / math.sqrt(self.config.num_layers))
|
| 835 |
+
else:
|
| 836 |
+
hidden_states = residual + hidden_states
|
| 837 |
+
|
| 838 |
+
# Fully Connected
|
| 839 |
+
residual = hidden_states
|
| 840 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 841 |
+
hidden_states = self.mlp(hidden_states)
|
| 842 |
+
if self.config.use_mup:
|
| 843 |
+
hidden_states = residual + hidden_states * (self.config.mup_depth_scale / math.sqrt(self.config.num_layers))
|
| 844 |
+
else:
|
| 845 |
+
hidden_states = residual + hidden_states
|
| 846 |
+
return hidden_states
|
| 847 |
+
|
| 848 |
+
|
| 849 |
+
@auto_docstring
|
| 850 |
+
class BlockFFNPreTrainedModel(PreTrainedModel):
|
| 851 |
+
config: BlockFFNConfig
|
| 852 |
+
base_model_prefix = "model"
|
| 853 |
+
supports_gradient_checkpointing = True
|
| 854 |
+
_no_split_modules = ["BlockFFNDecoderLayer"]
|
| 855 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 856 |
+
_supports_flash_attn = True
|
| 857 |
+
_supports_sdpa = True
|
| 858 |
+
_supports_flex_attn = True
|
| 859 |
+
|
| 860 |
+
_can_compile_fullgraph = True
|
| 861 |
+
_supports_attention_backend = True
|
| 862 |
+
_can_record_outputs = {
|
| 863 |
+
"hidden_states": BlockFFNDecoderLayer,
|
| 864 |
+
"attentions": BlockFFNAttention,
|
| 865 |
+
}
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
@auto_docstring
|
| 869 |
+
class BlockFFNModel(BlockFFNPreTrainedModel):
|
| 870 |
+
def __init__(self, config: BlockFFNConfig):
|
| 871 |
+
super().__init__(config)
|
| 872 |
+
self.config = config
|
| 873 |
+
self.padding_idx = config.pad_token_id
|
| 874 |
+
self.vocab_size = config.vocab_size
|
| 875 |
+
|
| 876 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 877 |
+
self.moe_layer_freq = eval(config.moe_layer_freq) if isinstance(config.moe_layer_freq, str) else config.moe_layer_freq
|
| 878 |
+
assert len(self.moe_layer_freq) == config.num_layers
|
| 879 |
+
self.layers = nn.ModuleList(
|
| 880 |
+
[BlockFFNDecoderLayer(config, layer_idx, bool(self.moe_layer_freq[layer_idx])) for layer_idx in range(config.num_layers)]
|
| 881 |
+
)
|
| 882 |
+
self.norm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
|
| 883 |
+
self.rotary_emb = BlockFFNRotaryEmbedding(config=config)
|
| 884 |
+
self.gradient_checkpointing = False
|
| 885 |
+
|
| 886 |
+
# Initialize weights and apply final processing
|
| 887 |
+
self.post_init()
|
| 888 |
+
|
| 889 |
+
@check_model_inputs
|
| 890 |
+
@auto_docstring
|
| 891 |
+
def forward(
|
| 892 |
+
self,
|
| 893 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 894 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 895 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 896 |
+
past_key_values: Optional[Cache] = None,
|
| 897 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 898 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 899 |
+
use_cache: Optional[bool] = None,
|
| 900 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 901 |
+
) -> BaseModelOutputWithPast:
|
| 902 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 903 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 904 |
+
|
| 905 |
+
if inputs_embeds is None:
|
| 906 |
+
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
|
| 907 |
+
if self.config.use_mup:
|
| 908 |
+
inputs_embeds = inputs_embeds * self.config.mup_emb_scale
|
| 909 |
+
|
| 910 |
+
if use_cache and past_key_values is None:
|
| 911 |
+
past_key_values = DynamicCache()
|
| 912 |
+
|
| 913 |
+
if cache_position is None:
|
| 914 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 915 |
+
cache_position: torch.Tensor = torch.arange(
|
| 916 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
if position_ids is None:
|
| 920 |
+
position_ids = cache_position.unsqueeze(0)
|
| 921 |
+
|
| 922 |
+
causal_mask = create_causal_mask(
|
| 923 |
+
config=self.config,
|
| 924 |
+
input_embeds=inputs_embeds,
|
| 925 |
+
attention_mask=attention_mask,
|
| 926 |
+
cache_position=cache_position,
|
| 927 |
+
past_key_values=past_key_values,
|
| 928 |
+
position_ids=position_ids,
|
| 929 |
+
)
|
| 930 |
+
|
| 931 |
+
hidden_states = inputs_embeds
|
| 932 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 933 |
+
|
| 934 |
+
for decoder_layer in self.layers[: self.config.num_layers]:
|
| 935 |
+
hidden_states = decoder_layer(
|
| 936 |
+
hidden_states,
|
| 937 |
+
attention_mask=causal_mask,
|
| 938 |
+
position_ids=position_ids,
|
| 939 |
+
past_key_value=past_key_values,
|
| 940 |
+
cache_position=cache_position,
|
| 941 |
+
position_embeddings=position_embeddings,
|
| 942 |
+
**kwargs,
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
hidden_states = self.norm(hidden_states)
|
| 946 |
+
return BaseModelOutputWithPast(
|
| 947 |
+
last_hidden_state=hidden_states,
|
| 948 |
+
past_key_values=past_key_values,
|
| 949 |
+
)
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
@auto_docstring
|
| 953 |
+
class BlockFFNForCausalLM(BlockFFNPreTrainedModel, GenerationMixin):
|
| 954 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 955 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 956 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 957 |
+
|
| 958 |
+
def __init__(self, config: BlockFFNConfig):
|
| 959 |
+
super().__init__(config)
|
| 960 |
+
self.config = config
|
| 961 |
+
self.model = BlockFFNModel(config)
|
| 962 |
+
self.vocab_size = config.vocab_size
|
| 963 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 964 |
+
|
| 965 |
+
# Initialize weights and apply final processing
|
| 966 |
+
self.post_init()
|
| 967 |
+
|
| 968 |
+
def set_decoder(self, decoder):
|
| 969 |
+
self.model = decoder
|
| 970 |
+
|
| 971 |
+
def get_decoder(self):
|
| 972 |
+
return self.model
|
| 973 |
+
|
| 974 |
+
@can_return_tuple
|
| 975 |
+
@auto_docstring
|
| 976 |
+
def forward(
|
| 977 |
+
self,
|
| 978 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 979 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 980 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 981 |
+
past_key_values: Optional[Cache] = None,
|
| 982 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 983 |
+
labels: Optional[torch.LongTensor] = None,
|
| 984 |
+
use_cache: Optional[bool] = None,
|
| 985 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 986 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 987 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 988 |
+
) -> CausalLMOutputWithPast:
|
| 989 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 990 |
+
input_ids=input_ids,
|
| 991 |
+
attention_mask=attention_mask,
|
| 992 |
+
position_ids=position_ids,
|
| 993 |
+
past_key_values=past_key_values,
|
| 994 |
+
inputs_embeds=inputs_embeds,
|
| 995 |
+
use_cache=use_cache,
|
| 996 |
+
cache_position=cache_position,
|
| 997 |
+
**kwargs,
|
| 998 |
+
)
|
| 999 |
+
|
| 1000 |
+
hidden_states = outputs.last_hidden_state
|
| 1001 |
+
if self.config.use_mup:
|
| 1002 |
+
hidden_states = hidden_states / self.config.mup_width_scale
|
| 1003 |
+
|
| 1004 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1005 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1006 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1007 |
+
|
| 1008 |
+
loss = None
|
| 1009 |
+
if labels is not None:
|
| 1010 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 1011 |
+
|
| 1012 |
+
return CausalLMOutputWithPast(
|
| 1013 |
+
loss=loss,
|
| 1014 |
+
logits=logits,
|
| 1015 |
+
past_key_values=outputs.past_key_values,
|
| 1016 |
+
hidden_states=outputs.hidden_states,
|
| 1017 |
+
attentions=outputs.attentions,
|
| 1018 |
+
)
|
| 1019 |
+
|
| 1020 |
+
__all__ = [
|
| 1021 |
+
"BlockFFNForCausalLM",
|
| 1022 |
+
"BlockFFNModel",
|
| 1023 |
+
"BlockFFNPreTrainedModel",
|
| 1024 |
+
]
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:74b0e11eedf5a22c1bf2c1a55805f9fbbc4a859404d3530cddd6ed16f8292166
|
| 3 |
+
size 785588525
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
{
|
| 4 |
+
"content": "<|im_end|>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"content": "<|im_start|>",
|
| 12 |
+
"lstrip": false,
|
| 13 |
+
"normalized": false,
|
| 14 |
+
"rstrip": false,
|
| 15 |
+
"single_word": false
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"content": "<|tool_call|>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": false,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"content": "<|execute_start|>",
|
| 26 |
+
"lstrip": false,
|
| 27 |
+
"normalized": false,
|
| 28 |
+
"rstrip": false,
|
| 29 |
+
"single_word": false
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"content": "<|execute_end|>",
|
| 33 |
+
"lstrip": false,
|
| 34 |
+
"normalized": false,
|
| 35 |
+
"rstrip": false,
|
| 36 |
+
"single_word": false
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"content": "<|fim_prefix|>",
|
| 40 |
+
"lstrip": false,
|
| 41 |
+
"normalized": false,
|
| 42 |
+
"rstrip": false,
|
| 43 |
+
"single_word": false
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"content": "<|fim_middle|>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"content": "<|fim_suffix|>",
|
| 54 |
+
"lstrip": false,
|
| 55 |
+
"normalized": false,
|
| 56 |
+
"rstrip": false,
|
| 57 |
+
"single_word": false
|
| 58 |
+
}
|
| 59 |
+
],
|
| 60 |
+
"bos_token": {
|
| 61 |
+
"content": "<s>",
|
| 62 |
+
"lstrip": false,
|
| 63 |
+
"normalized": false,
|
| 64 |
+
"rstrip": false,
|
| 65 |
+
"single_word": false
|
| 66 |
+
},
|
| 67 |
+
"eos_token": {
|
| 68 |
+
"content": "</s>",
|
| 69 |
+
"lstrip": false,
|
| 70 |
+
"normalized": false,
|
| 71 |
+
"rstrip": false,
|
| 72 |
+
"single_word": false
|
| 73 |
+
},
|
| 74 |
+
"unk_token": {
|
| 75 |
+
"content": "<unk>",
|
| 76 |
+
"lstrip": false,
|
| 77 |
+
"normalized": false,
|
| 78 |
+
"rstrip": false,
|
| 79 |
+
"single_word": false
|
| 80 |
+
}
|
| 81 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb74d51116831c3bf65db812c553f94ab0c88dcf97a5bbb37e3504f6d359c530
|
| 3 |
+
size 1181204
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"added_tokens_decoder": {
|
| 5 |
+
"0": {
|
| 6 |
+
"content": "<unk>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false,
|
| 11 |
+
"special": true
|
| 12 |
+
},
|
| 13 |
+
"1": {
|
| 14 |
+
"content": "<s>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": false,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"special": true
|
| 20 |
+
},
|
| 21 |
+
"2": {
|
| 22 |
+
"content": "</s>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": false,
|
| 27 |
+
"special": true
|
| 28 |
+
},
|
| 29 |
+
"73440": {
|
| 30 |
+
"content": "<|im_end|>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false,
|
| 35 |
+
"special": true
|
| 36 |
+
},
|
| 37 |
+
"73441": {
|
| 38 |
+
"content": "<|im_start|>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false,
|
| 43 |
+
"special": true
|
| 44 |
+
},
|
| 45 |
+
"73442": {
|
| 46 |
+
"content": "<|tool_call|>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false,
|
| 51 |
+
"special": true
|
| 52 |
+
},
|
| 53 |
+
"73443": {
|
| 54 |
+
"content": "<|execute_start|>",
|
| 55 |
+
"lstrip": false,
|
| 56 |
+
"normalized": false,
|
| 57 |
+
"rstrip": false,
|
| 58 |
+
"single_word": false,
|
| 59 |
+
"special": true
|
| 60 |
+
},
|
| 61 |
+
"73444": {
|
| 62 |
+
"content": "<|execute_end|>",
|
| 63 |
+
"lstrip": false,
|
| 64 |
+
"normalized": false,
|
| 65 |
+
"rstrip": false,
|
| 66 |
+
"single_word": false,
|
| 67 |
+
"special": true
|
| 68 |
+
},
|
| 69 |
+
"73445": {
|
| 70 |
+
"content": "<|fim_prefix|>",
|
| 71 |
+
"lstrip": false,
|
| 72 |
+
"normalized": false,
|
| 73 |
+
"rstrip": false,
|
| 74 |
+
"single_word": false,
|
| 75 |
+
"special": true
|
| 76 |
+
},
|
| 77 |
+
"73446": {
|
| 78 |
+
"content": "<|fim_middle|>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": false,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false,
|
| 83 |
+
"special": true
|
| 84 |
+
},
|
| 85 |
+
"73447": {
|
| 86 |
+
"content": "<|fim_suffix|>",
|
| 87 |
+
"lstrip": false,
|
| 88 |
+
"normalized": false,
|
| 89 |
+
"rstrip": false,
|
| 90 |
+
"single_word": false,
|
| 91 |
+
"special": true
|
| 92 |
+
}
|
| 93 |
+
},
|
| 94 |
+
"additional_special_tokens": [
|
| 95 |
+
"<|im_end|>",
|
| 96 |
+
"<|im_start|>",
|
| 97 |
+
"<|tool_call|>",
|
| 98 |
+
"<|execute_start|>",
|
| 99 |
+
"<|execute_end|>",
|
| 100 |
+
"<|fim_prefix|>",
|
| 101 |
+
"<|fim_middle|>",
|
| 102 |
+
"<|fim_suffix|>"
|
| 103 |
+
],
|
| 104 |
+
"bos_token": "<s>",
|
| 105 |
+
"clean_up_tokenization_spaces": false,
|
| 106 |
+
"eos_token": "<|im_end|>",
|
| 107 |
+
"legacy": true,
|
| 108 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 109 |
+
"pad_token": null,
|
| 110 |
+
"sp_model_kwargs": {},
|
| 111 |
+
"spaces_between_special_tokens": false,
|
| 112 |
+
"tokenizer_class": "LlamaTokenizer",
|
| 113 |
+
"unk_token": "<unk>",
|
| 114 |
+
"use_default_system_prompt": false,
|
| 115 |
+
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
| 116 |
+
}
|