liuxz0801 committed
Commit 1138da1 · verified · 1 Parent(s): f814882

Update modeling_telechat3.py

Files changed (1)
  1. modeling_telechat3.py +33 -33
modeling_telechat3.py CHANGED
@@ -44,7 +44,7 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
 from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
 
-from .configuration_telechat3 import Telechat3Config
+from .configuration_telechat3 import TeleChat3Config
 
 logger = logging.get_logger(__name__)
 
@@ -152,10 +152,10 @@ ROPE_INIT_FUNCTIONS['telechat3-yarn'] = _compute_telechat_yarn_parameters
 
 
 @use_kernel_forward_from_hub("RMSNorm")
-class Telechat3RMSNorm(nn.Module):
+class TeleChat3RMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
-        Telechat3RMSNorm is equivalent to T5LayerNorm
+        TeleChat3RMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -172,8 +172,8 @@ class Telechat3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
-class Telechat3RotaryEmbedding(nn.Module):
-    def __init__(self, config: Telechat3Config, device=None):
+class TeleChat3RotaryEmbedding(nn.Module):
+    def __init__(self, config: TeleChat3Config, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@@ -240,7 +240,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-class Telechat3MLP(nn.Module):
+class TeleChat3MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -294,10 +294,10 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
-class Telechat3Attention(nn.Module):
+class TeleChat3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: Telechat3Config, layer_idx: int):
+    def __init__(self, config: TeleChat3Config, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -364,16 +364,16 @@ class Telechat3Attention(nn.Module):
         return attn_output, attn_weights
 
 
-class Telechat3DecoderLayer(GradientCheckpointingLayer):
-    def __init__(self, config: Telechat3Config, layer_idx: int):
+class TeleChat3DecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: TeleChat3Config, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
-        self.self_attn = Telechat3Attention(config=config, layer_idx=layer_idx)
+        self.self_attn = TeleChat3Attention(config=config, layer_idx=layer_idx)
 
-        self.mlp = Telechat3MLP(config)
-        self.input_layernorm = Telechat3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = Telechat3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.mlp = TeleChat3MLP(config)
+        self.input_layernorm = TeleChat3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = TeleChat3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -418,11 +418,11 @@ class Telechat3DecoderLayer(GradientCheckpointingLayer):
 
 
 @auto_docstring
-class Telechat3PreTrainedModel(PreTrainedModel):
-    config_class = Telechat3Config
+class TeleChat3PreTrainedModel(PreTrainedModel):
+    config_class = TeleChat3Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Telechat3DecoderLayer"]
+    _no_split_modules = ["TeleChat3DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
@@ -443,23 +443,23 @@ class Telechat3PreTrainedModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, Telechat3RMSNorm):
+        elif isinstance(module, TeleChat3RMSNorm):
             module.weight.data.fill_(1.0)
 
 
 @auto_docstring
-class Telechat3Model(Telechat3PreTrainedModel):
-    def __init__(self, config: Telechat3Config):
+class TeleChat3Model(TeleChat3PreTrainedModel):
+    def __init__(self, config: TeleChat3Config):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList(
-            [Telechat3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+            [TeleChat3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self.norm = Telechat3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.rotary_emb = Telechat3RotaryEmbedding(config=config)
+        self.norm = TeleChat3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = TeleChat3RotaryEmbedding(config=config)
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
@@ -577,14 +577,14 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
 
 
 @auto_docstring
-class Telechat3ForCausalLM(Telechat3PreTrainedModel, GenerationMixin):
+class TeleChat3ForCausalLM(TeleChat3PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
 
     def __init__(self, config):
         super().__init__(config)
-        self.model = Telechat3Model(config)
+        self.model = TeleChat3Model(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
@@ -666,9 +666,9 @@ class Telechat3ForCausalLM(Telechat3PreTrainedModel, GenerationMixin):
 
 @auto_docstring(
     custom_intro="""
-    The Telechat3 Model transformer with a sequence classification head on top (linear layer).
+    The TeleChat3 Model transformer with a sequence classification head on top (linear layer).
 
-    [`Telechat3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    [`TeleChat3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
     (e.g. GPT-2) do.
 
     Since it does classification on the last token, it requires to know the position of the last token. If a
@@ -678,11 +678,11 @@ class Telechat3ForCausalLM(Telechat3PreTrainedModel, GenerationMixin):
     each row of the batch).
     """
 )
-class Telechat3ForSequenceClassification(Telechat3PreTrainedModel):
+class TeleChat3ForSequenceClassification(TeleChat3PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = Telechat3Model(config)
+        self.model = TeleChat3Model(config)
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
         # Initialize weights and apply final processing
@@ -765,13 +765,13 @@ class Telechat3ForSequenceClassification(Telechat3PreTrainedModel):
 
 
 @auto_docstring
-class Telechat3ForQuestionAnswering(Telechat3PreTrainedModel):
+class TeleChat3ForQuestionAnswering(TeleChat3PreTrainedModel):
     base_model_prefix = "transformer"
 
     # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Telechat3
     def __init__(self, config):
         super().__init__(config)
-        self.transformer = Telechat3Model(config)
+        self.transformer = TeleChat3Model(config)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
 
         # Initialize weights and apply final processing
@@ -829,11 +829,11 @@ class Telechat3ForQuestionAnswering(Telechat3PreTrainedModel):
 
 
 @auto_docstring
-class Telechat3ForTokenClassification(Telechat3PreTrainedModel):
+class TeleChat3ForTokenClassification(TeleChat3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
-        self.model = Telechat3Model(config)
+        self.model = TeleChat3Model(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
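
The diff is a mechanical rename of the exported classes from the `Telechat3*` prefix to `TeleChat3*`; the module files and model weights are unchanged. Code that imports these classes directly from the remote-code module must switch to the new names, while loading through the `transformers` Auto classes with `trust_remote_code=True` keeps working as long as the repository's `config.json` `auto_map` entries point at the renamed classes. A minimal loading sketch follows; the repository id is a placeholder, not taken from this commit:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder repository id -- substitute the actual TeleChat3 checkpoint.
repo_id = "TeleAI/TeleChat3"

# The Auto classes resolve modeling_telechat3.TeleChat3ForCausalLM through the
# repository's auto_map, so caller code never names the renamed class directly.
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hello, TeleChat3!", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```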