Update modeling_telechat3.py
Browse files- modeling_telechat3.py +33 -33
modeling_telechat3.py
CHANGED
|
@@ -44,7 +44,7 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
| 44 |
from transformers.processing_utils import Unpack
|
| 45 |
from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
| 46 |
|
| 47 |
-
from .configuration_telechat3 import
|
| 48 |
|
| 49 |
logger = logging.get_logger(__name__)
|
| 50 |
|
|
@@ -152,10 +152,10 @@ ROPE_INIT_FUNCTIONS['telechat3-yarn'] = _compute_telechat_yarn_parameters
|
|
| 152 |
|
| 153 |
|
| 154 |
@use_kernel_forward_from_hub("RMSNorm")
|
| 155 |
-
class
|
| 156 |
def __init__(self, hidden_size, eps=1e-6):
|
| 157 |
"""
|
| 158 |
-
|
| 159 |
"""
|
| 160 |
super().__init__()
|
| 161 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
@@ -172,8 +172,8 @@ class Telechat3RMSNorm(nn.Module):
|
|
| 172 |
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 173 |
|
| 174 |
|
| 175 |
-
class
|
| 176 |
-
def __init__(self, config:
|
| 177 |
super().__init__()
|
| 178 |
# BC: "rope_type" was originally "type"
|
| 179 |
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
|
@@ -240,7 +240,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
| 240 |
return q_embed, k_embed
|
| 241 |
|
| 242 |
|
| 243 |
-
class
|
| 244 |
def __init__(self, config):
|
| 245 |
super().__init__()
|
| 246 |
self.config = config
|
|
@@ -294,10 +294,10 @@ def eager_attention_forward(
|
|
| 294 |
return attn_output, attn_weights
|
| 295 |
|
| 296 |
|
| 297 |
-
class
|
| 298 |
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 299 |
|
| 300 |
-
def __init__(self, config:
|
| 301 |
super().__init__()
|
| 302 |
self.config = config
|
| 303 |
self.layer_idx = layer_idx
|
|
@@ -364,16 +364,16 @@ class Telechat3Attention(nn.Module):
|
|
| 364 |
return attn_output, attn_weights
|
| 365 |
|
| 366 |
|
| 367 |
-
class
|
| 368 |
-
def __init__(self, config:
|
| 369 |
super().__init__()
|
| 370 |
self.hidden_size = config.hidden_size
|
| 371 |
|
| 372 |
-
self.self_attn =
|
| 373 |
|
| 374 |
-
self.mlp =
|
| 375 |
-
self.input_layernorm =
|
| 376 |
-
self.post_attention_layernorm =
|
| 377 |
|
| 378 |
def forward(
|
| 379 |
self,
|
|
@@ -418,11 +418,11 @@ class Telechat3DecoderLayer(GradientCheckpointingLayer):
|
|
| 418 |
|
| 419 |
|
| 420 |
@auto_docstring
|
| 421 |
-
class
|
| 422 |
-
config_class =
|
| 423 |
base_model_prefix = "model"
|
| 424 |
supports_gradient_checkpointing = True
|
| 425 |
-
_no_split_modules = ["
|
| 426 |
_skip_keys_device_placement = ["past_key_values"]
|
| 427 |
_supports_flash_attn_3 = True
|
| 428 |
_supports_flash_attn_2 = True
|
|
@@ -443,23 +443,23 @@ class Telechat3PreTrainedModel(PreTrainedModel):
|
|
| 443 |
module.weight.data.normal_(mean=0.0, std=std)
|
| 444 |
if module.padding_idx is not None:
|
| 445 |
module.weight.data[module.padding_idx].zero_()
|
| 446 |
-
elif isinstance(module,
|
| 447 |
module.weight.data.fill_(1.0)
|
| 448 |
|
| 449 |
|
| 450 |
@auto_docstring
|
| 451 |
-
class
|
| 452 |
-
def __init__(self, config:
|
| 453 |
super().__init__(config)
|
| 454 |
self.padding_idx = config.pad_token_id
|
| 455 |
self.vocab_size = config.vocab_size
|
| 456 |
|
| 457 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 458 |
self.layers = nn.ModuleList(
|
| 459 |
-
[
|
| 460 |
)
|
| 461 |
-
self.norm =
|
| 462 |
-
self.rotary_emb =
|
| 463 |
self.gradient_checkpointing = False
|
| 464 |
|
| 465 |
# Initialize weights and apply final processing
|
|
@@ -577,14 +577,14 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|
| 577 |
|
| 578 |
|
| 579 |
@auto_docstring
|
| 580 |
-
class
|
| 581 |
_tied_weights_keys = ["lm_head.weight"]
|
| 582 |
_tp_plan = {"lm_head": "colwise_rep"}
|
| 583 |
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 584 |
|
| 585 |
def __init__(self, config):
|
| 586 |
super().__init__(config)
|
| 587 |
-
self.model =
|
| 588 |
self.vocab_size = config.vocab_size
|
| 589 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 590 |
|
|
@@ -666,9 +666,9 @@ class Telechat3ForCausalLM(Telechat3PreTrainedModel, GenerationMixin):
|
|
| 666 |
|
| 667 |
@auto_docstring(
|
| 668 |
custom_intro="""
|
| 669 |
-
The
|
| 670 |
|
| 671 |
-
[`
|
| 672 |
(e.g. GPT-2) do.
|
| 673 |
|
| 674 |
Since it does classification on the last token, it requires to know the position of the last token. If a
|
|
@@ -678,11 +678,11 @@ class Telechat3ForCausalLM(Telechat3PreTrainedModel, GenerationMixin):
|
|
| 678 |
each row of the batch).
|
| 679 |
"""
|
| 680 |
)
|
| 681 |
-
class
|
| 682 |
def __init__(self, config):
|
| 683 |
super().__init__(config)
|
| 684 |
self.num_labels = config.num_labels
|
| 685 |
-
self.model =
|
| 686 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 687 |
|
| 688 |
# Initialize weights and apply final processing
|
|
@@ -765,13 +765,13 @@ class Telechat3ForSequenceClassification(Telechat3PreTrainedModel):
|
|
| 765 |
|
| 766 |
|
| 767 |
@auto_docstring
|
| 768 |
-
class
|
| 769 |
base_model_prefix = "transformer"
|
| 770 |
|
| 771 |
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Telechat3
|
| 772 |
def __init__(self, config):
|
| 773 |
super().__init__(config)
|
| 774 |
-
self.transformer =
|
| 775 |
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
| 776 |
|
| 777 |
# Initialize weights and apply final processing
|
|
@@ -829,11 +829,11 @@ class Telechat3ForQuestionAnswering(Telechat3PreTrainedModel):
|
|
| 829 |
|
| 830 |
|
| 831 |
@auto_docstring
|
| 832 |
-
class
|
| 833 |
def __init__(self, config):
|
| 834 |
super().__init__(config)
|
| 835 |
self.num_labels = config.num_labels
|
| 836 |
-
self.model =
|
| 837 |
if getattr(config, "classifier_dropout", None) is not None:
|
| 838 |
classifier_dropout = config.classifier_dropout
|
| 839 |
elif getattr(config, "hidden_dropout", None) is not None:
|
|
|
|
| 44 |
from transformers.processing_utils import Unpack
|
| 45 |
from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
| 46 |
|
| 47 |
+
from .configuration_telechat3 import TeleChat3Config
|
| 48 |
|
| 49 |
logger = logging.get_logger(__name__)
|
| 50 |
|
|
|
|
| 152 |
|
| 153 |
|
| 154 |
@use_kernel_forward_from_hub("RMSNorm")
|
| 155 |
+
class TeleChat3RMSNorm(nn.Module):
|
| 156 |
def __init__(self, hidden_size, eps=1e-6):
|
| 157 |
"""
|
| 158 |
+
TeleChat3RMSNorm is equivalent to T5LayerNorm
|
| 159 |
"""
|
| 160 |
super().__init__()
|
| 161 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
|
|
| 172 |
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 173 |
|
| 174 |
|
| 175 |
+
class TeleChat3RotaryEmbedding(nn.Module):
|
| 176 |
+
def __init__(self, config: TeleChat3Config, device=None):
|
| 177 |
super().__init__()
|
| 178 |
# BC: "rope_type" was originally "type"
|
| 179 |
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
|
|
|
| 240 |
return q_embed, k_embed
|
| 241 |
|
| 242 |
|
| 243 |
+
class TeleChat3MLP(nn.Module):
|
| 244 |
def __init__(self, config):
|
| 245 |
super().__init__()
|
| 246 |
self.config = config
|
|
|
|
| 294 |
return attn_output, attn_weights
|
| 295 |
|
| 296 |
|
| 297 |
+
class TeleChat3Attention(nn.Module):
|
| 298 |
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 299 |
|
| 300 |
+
def __init__(self, config: TeleChat3Config, layer_idx: int):
|
| 301 |
super().__init__()
|
| 302 |
self.config = config
|
| 303 |
self.layer_idx = layer_idx
|
|
|
|
| 364 |
return attn_output, attn_weights
|
| 365 |
|
| 366 |
|
| 367 |
+
class TeleChat3DecoderLayer(GradientCheckpointingLayer):
|
| 368 |
+
def __init__(self, config: TeleChat3Config, layer_idx: int):
|
| 369 |
super().__init__()
|
| 370 |
self.hidden_size = config.hidden_size
|
| 371 |
|
| 372 |
+
self.self_attn = TeleChat3Attention(config=config, layer_idx=layer_idx)
|
| 373 |
|
| 374 |
+
self.mlp = TeleChat3MLP(config)
|
| 375 |
+
self.input_layernorm = TeleChat3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 376 |
+
self.post_attention_layernorm = TeleChat3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 377 |
|
| 378 |
def forward(
|
| 379 |
self,
|
|
|
|
| 418 |
|
| 419 |
|
| 420 |
@auto_docstring
|
| 421 |
+
class TeleChat3PreTrainedModel(PreTrainedModel):
|
| 422 |
+
config_class = TeleChat3Config
|
| 423 |
base_model_prefix = "model"
|
| 424 |
supports_gradient_checkpointing = True
|
| 425 |
+
_no_split_modules = ["TeleChat3DecoderLayer"]
|
| 426 |
_skip_keys_device_placement = ["past_key_values"]
|
| 427 |
_supports_flash_attn_3 = True
|
| 428 |
_supports_flash_attn_2 = True
|
|
|
|
| 443 |
module.weight.data.normal_(mean=0.0, std=std)
|
| 444 |
if module.padding_idx is not None:
|
| 445 |
module.weight.data[module.padding_idx].zero_()
|
| 446 |
+
elif isinstance(module, TeleChat3RMSNorm):
|
| 447 |
module.weight.data.fill_(1.0)
|
| 448 |
|
| 449 |
|
| 450 |
@auto_docstring
|
| 451 |
+
class TeleChat3Model(TeleChat3PreTrainedModel):
|
| 452 |
+
def __init__(self, config: TeleChat3Config):
|
| 453 |
super().__init__(config)
|
| 454 |
self.padding_idx = config.pad_token_id
|
| 455 |
self.vocab_size = config.vocab_size
|
| 456 |
|
| 457 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 458 |
self.layers = nn.ModuleList(
|
| 459 |
+
[TeleChat3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 460 |
)
|
| 461 |
+
self.norm = TeleChat3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 462 |
+
self.rotary_emb = TeleChat3RotaryEmbedding(config=config)
|
| 463 |
self.gradient_checkpointing = False
|
| 464 |
|
| 465 |
# Initialize weights and apply final processing
|
|
|
|
| 577 |
|
| 578 |
|
| 579 |
@auto_docstring
|
| 580 |
+
class TeleChat3ForCausalLM(TeleChat3PreTrainedModel, GenerationMixin):
|
| 581 |
_tied_weights_keys = ["lm_head.weight"]
|
| 582 |
_tp_plan = {"lm_head": "colwise_rep"}
|
| 583 |
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 584 |
|
| 585 |
def __init__(self, config):
|
| 586 |
super().__init__(config)
|
| 587 |
+
self.model = TeleChat3Model(config)
|
| 588 |
self.vocab_size = config.vocab_size
|
| 589 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 590 |
|
|
|
|
| 666 |
|
| 667 |
@auto_docstring(
|
| 668 |
custom_intro="""
|
| 669 |
+
The TeleChat3 Model transformer with a sequence classification head on top (linear layer).
|
| 670 |
|
| 671 |
+
[`TeleChat3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
| 672 |
(e.g. GPT-2) do.
|
| 673 |
|
| 674 |
Since it does classification on the last token, it requires to know the position of the last token. If a
|
|
|
|
| 678 |
each row of the batch).
|
| 679 |
"""
|
| 680 |
)
|
| 681 |
+
class TeleChat3ForSequenceClassification(TeleChat3PreTrainedModel):
|
| 682 |
def __init__(self, config):
|
| 683 |
super().__init__(config)
|
| 684 |
self.num_labels = config.num_labels
|
| 685 |
+
self.model = TeleChat3Model(config)
|
| 686 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 687 |
|
| 688 |
# Initialize weights and apply final processing
|
|
|
|
| 765 |
|
| 766 |
|
| 767 |
@auto_docstring
|
| 768 |
+
class TeleChat3ForQuestionAnswering(TeleChat3PreTrainedModel):
|
| 769 |
base_model_prefix = "transformer"
|
| 770 |
|
| 771 |
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Telechat3
|
| 772 |
def __init__(self, config):
|
| 773 |
super().__init__(config)
|
| 774 |
+
self.transformer = TeleChat3Model(config)
|
| 775 |
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
| 776 |
|
| 777 |
# Initialize weights and apply final processing
|
|
|
|
| 829 |
|
| 830 |
|
| 831 |
@auto_docstring
|
| 832 |
+
class TeleChat3ForTokenClassification(TeleChat3PreTrainedModel):
|
| 833 |
def __init__(self, config):
|
| 834 |
super().__init__(config)
|
| 835 |
self.num_labels = config.num_labels
|
| 836 |
+
self.model = TeleChat3Model(config)
|
| 837 |
if getattr(config, "classifier_dropout", None) is not None:
|
| 838 |
classifier_dropout = config.classifier_dropout
|
| 839 |
elif getattr(config, "hidden_dropout", None) is not None:
|