# Source: uploaded by Syon-Li, commit ca40891 ("Create modeling.py", verified).
from typing import Optional, Union

import torch
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.qwen3.modeling_qwen3 import (
    create_causal_mask,
    create_sliding_window_causal_mask,
    Qwen3ForCausalLM,
    Qwen3PreTrainedModel,
    Qwen3Model,
    GenerationMixin,
    Unpack,
    TransformersKwargs,
)
class Qwen3ForCut(Qwen3PreTrainedModel, GenerationMixin):
    """Qwen3 backbone topped with a small MLP "cut" head instead of an LM head.

    `Qwen3Model` produces per-position hidden states. `cut_head` maps each one through
    Linear(hidden_size -> 512) -> ReLU -> Linear(512 -> 2), so the model emits two
    logits per scored position rather than vocabulary logits.
    """

    # No `lm_head` is defined on this model, so the `_tied_weights_keys`, `_tp_plan`
    # and `_pp_plan` entries that `Qwen3ForCausalLM` declares are deliberately not set.

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.vocab_size = config.vocab_size
        # Layers are applied in order by `_apply_cut_head`.
        self.cut_head = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, 512, bias=False),
                nn.ReLU(inplace=False),
                nn.Linear(512, 2, bias=False),
            ]
        )
        # Initialize weights and apply final processing
        self.post_init()

    def _apply_cut_head(self, features: torch.Tensor) -> torch.Tensor:
        """Run `features` through every layer of `cut_head` in order (acts on the last dim)."""
        for layer in self.cut_head:
            features = layer(features)
        return features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cls_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,  # [bsz, q_len]
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Run the Qwen3 backbone and score positions with the cut head.

        cls_mask (`torch.Tensor`, *optional*):
            Currently unused; accepted so existing callers that pass it keep working.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Per-position targets for the cut head; positions set to `-100` are ignored.
            - Training mode (`self.training`) with labels: only the positions selected via
              `get_shift_mask` from `labels != -100` are scored, using `LigerCrossEntropyLoss`.
            - Otherwise, when labels are given, the loss comes from `self.loss_function`
              applied to the cut-head logits.
        logits_to_keep (`int` or `torch.Tensor`, defaults to `0`):
            Used outside the training-with-labels path. An int keeps the last
            `logits_to_keep` positions (`0` keeps all of them); a tensor is used directly as
            sequence indices. This avoids running the head on positions nobody reads
            (e.g. during generation).

        Returns:
            `CausalLMOutputWithPast`. `logits` has a last dimension equal to the cut head's
            output size (2). In the training-with-labels path, `logits` holds only the selected
            positions, gathered by boolean masking. That gives a `(num_selected, 2)` tensor,
            assuming `get_shift_mask` returns masks shaped like `labels` (TODO confirm).
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        loss = None
        if self.training and labels is not None:
            # Score only the labelled cut points. Hidden states and labels are gathered with
            # two differently-shifted masks derived from the same cut-point mask.
            # NOTE(review): `get_shift_mask` and `LigerCrossEntropyLoss` are neither defined nor
            # imported in the visible source; presumably they come from elsewhere in the
            # project / `liger_kernel` — confirm they are in scope.
            cut_point_mask = labels != -100
            r_shift_mask = get_shift_mask(cut_point_mask=cut_point_mask, side="right")
            l_shift_mask = get_shift_mask(cut_point_mask=cut_point_mask, side="left")
            shift_hidden_states = hidden_states[r_shift_mask].contiguous()
            shift_labels = labels[l_shift_mask].contiguous()
            logits = self._apply_cut_head(shift_hidden_states)
            loss_fct = LigerCrossEntropyLoss()
            loss = loss_fct(logits, shift_labels)
        else:
            # Only compute the logits that are needed (all positions when logits_to_keep == 0).
            slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
            logits = self._apply_cut_head(hidden_states[:, slice_indices, :])
            if labels is not None:
                # The class name matches no transformers loss type, so `loss_function` falls back
                # to the causal-LM loss. That loss reshapes logits with `vocab_size`, so it must be
                # the cut head's output width, not the tokenizer vocabulary size.
                loss = self.loss_function(logits=logits, labels=labels, vocab_size=logits.size(-1), **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )