Syon-Li committed on
Commit
2c1bdff
·
verified ·
1 Parent(s): ca40891

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. config.json +3 -0
  2. modeling_qwen3cut.py +130 -0
config.json CHANGED
@@ -4,6 +4,9 @@
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
 
 
 
7
  "bos_token_id": 151643,
8
  "dtype": "bfloat16",
9
  "eos_token_id": 151645,
 
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoModelForCausalLM": "modeling_qwen3cut.Qwen3ForCut"
9
+ },
10
  "bos_token_id": 151643,
11
  "dtype": "bfloat16",
12
  "eos_token_id": 151645,
modeling_qwen3cut.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.models.qwen3.modeling_qwen3 import (
2
+ create_causal_mask,
3
+ create_sliding_window_causal_mask,
4
+ Qwen3ForCausalLM,
5
+ Qwen3PreTrainedModel,
6
+ Qwen3Model,
7
+ GenerationMixin,
8
+ Unpack,
9
+ TransformersKwargs,
10
+ )
11
+ from transformers.modeling_outputs import (
12
+ BaseModelOutputWithPast,
13
+ CausalLMOutputWithPast,
14
+ )
15
+ from transformers.cache_utils import Cache, DynamicCache
16
+ import torch
17
+ from typing import Optional, Union
18
+ from torch import nn
19
+
20
+
21
+
22
class Qwen3ForCut(Qwen3PreTrainedModel, GenerationMixin):
    """Qwen3 backbone topped with a small two-class per-token head (`cut_head`).

    The head is `Linear(hidden_size, 512) -> ReLU -> Linear(512, 2)`, all without
    bias. There is no `lm_head` in this model, so the tied-weight / tensor-parallel
    plans that `Qwen3ForCausalLM` declares for `lm_head` are intentionally absent.
    """

    def __init__(self, config, ):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.vocab_size = config.vocab_size
        # Kept as a ModuleList (applied in order by `_apply_cut_head`) so the
        # parameter names `cut_head.0.weight` / `cut_head.2.weight` do not change.
        self.cut_head = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, 512, bias=False),
                nn.ReLU(inplace=False),
                nn.Linear(512, 2, bias=False),
            ]
        )

        # Initialize weights and apply final processing
        self.post_init()

    def _apply_cut_head(self, features: torch.Tensor) -> torch.Tensor:
        """Run `features` through every layer of `cut_head`, in order."""
        for layer in self.cut_head:
            features = layer(features)
        return features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cls_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,  # [bsz, q_len]
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Run the Qwen3 backbone and score hidden states with the two-class `cut_head`.

        Args:
            cls_mask (`torch.Tensor`, *optional*):
                Accepted for interface compatibility; not used by this method.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Per-position class labels for the two-class head. Positions set to `-100`
                are ignored when computing the loss.
            logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
                Accepted for interface compatibility; not used — the head is applied
                without slicing the sequence.

        Returns:
            `CausalLMOutputWithPast` whose `logits` are the `cut_head` outputs:

            - in training mode with `labels`: shape `(num_selected, 2)`, one row per
              supervised position (the head is only run on those positions);
            - otherwise: shape `(batch_size, sequence_length, 2)`.

            `loss` is `None` when `labels` is not given. `past_key_values`,
            `hidden_states` and `attentions` are forwarded from the backbone.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        loss = None
        if self.training and (labels is not None):
            # NOTE(review): the original code called `get_shift_mask(...)` and
            # `LigerCrossEntropyLoss()`, neither of which is imported or defined in this
            # file, so this branch raised NameError. The alignment below pairs the hidden
            # state at position t with the label at position t + 1, which is presumably
            # the intended shift (it mirrors the label shift applied by the default
            # causal-LM loss used in the other branch) — confirm against the training code.
            labels = labels.to(hidden_states.device)
            next_labels = labels[:, 1:]
            supervised = next_labels != -100
            shift_hidden_states = hidden_states[:, :-1][supervised].contiguous()
            shift_labels = next_labels[supervised].contiguous()

            # Only the supervised positions go through the head.
            logits = self._apply_cut_head(shift_hidden_states)

            if shift_labels.numel() == 0:
                # No supervised position in this batch: a mean over zero targets would be
                # NaN, so return a zero loss that is still attached to the graph.
                loss = logits.sum() * 0.0
            else:
                # Upcast so the loss is not computed in reduced precision.
                loss = nn.functional.cross_entropy(logits.float(), shift_labels)
        else:
            logits = self._apply_cut_head(hidden_states)

            if labels is not None:
                # The head emits 2 classes, so the class count handed to the loss must be
                # the last dimension of `logits`, not the language-model vocabulary size.
                loss = self.loss_function(logits=logits, labels=labels, vocab_size=logits.shape[-1], **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )