from typing import Optional, Union

import torch
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.qwen3.modeling_qwen3 import (
    create_causal_mask,
    create_sliding_window_causal_mask,
    Qwen3ForCausalLM,
    Qwen3PreTrainedModel,
    Qwen3Model,
    GenerationMixin,
    Unpack,
    TransformersKwargs,
)

# NOTE(review): `get_shift_mask` and `LigerCrossEntropyLoss` are called inside
# `Qwen3ForCut.forward` but are not imported in the visible part of this file.
# They must be in module scope by the time `forward` runs -- confirm.


class Qwen3ForCut(Qwen3PreTrainedModel, GenerationMixin):
    """Qwen3 backbone followed by a small per-position MLP head with two outputs.

    The head replaces the usual vocabulary projection: every hidden state is
    mapped `hidden_size -> 512 -> 2` (Linear, ReLU, Linear; no biases).
    There is no `lm_head`, so the tied-weight / tensor-parallel / pipeline
    plans that refer to `lm_head` are not declared on this class.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.vocab_size = config.vocab_size
        # Kept as a ModuleList so parameter names stay `cut_head.0.weight` and
        # `cut_head.2.weight`; the layers are applied in order by
        # `_apply_cut_head`.
        self.cut_head = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, 512, bias=False),
                nn.ReLU(inplace=False),
                nn.Linear(512, 2, bias=False),
            ]
        )

        # Initialize weights and apply final processing
        self.post_init()

    def _apply_cut_head(self, features: torch.Tensor) -> torch.Tensor:
        """Run `features` through every layer of `cut_head`, in order."""
        for layer in self.cut_head:
            features = layer(features)
        return features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cls_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,  # [bsz, q_len]
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Run the Qwen3 backbone and score its hidden states with `cut_head`.

        cls_mask (`torch.Tensor`, *optional*):
            Accepted for interface compatibility; it is not read by the
            current implementation.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Targets for the two-way head. Positions equal to `-100` are
            treated as "no label". In training mode the positions that carry
            a label are turned into two masks by `get_shift_mask` (one with
            `side="right"` to pick hidden states, one with `side="left"` to
            pick labels) and only the selected positions are scored.
        logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
            Accepted for interface compatibility; it is not applied, so every
            position is scored.

        Returns:
            `CausalLMOutputWithPast` whose `logits` are the head outputs
            (last dimension 2). In training mode with labels they cover only
            the positions selected by the shift mask; otherwise they cover
            every position of the backbone output. `loss` is `None` when no
            labels are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        loss = None
        if self.training and labels is not None:
            # Score only the labelled positions. The exact alignment between
            # the selected hidden states and labels is defined by
            # `get_shift_mask`, which is not visible here -- presumably the
            # two masks select the same number of positions; TODO confirm.
            cut_point_mask = labels != -100
            r_shift_mask = get_shift_mask(cut_point_mask=cut_point_mask, side="right")
            l_shift_mask = get_shift_mask(cut_point_mask=cut_point_mask, side="left")
            shift_hidden_states = hidden_states[r_shift_mask].contiguous()
            shift_labels = labels[l_shift_mask].contiguous()

            logits = self._apply_cut_head(shift_hidden_states)
            loss_fct = LigerCrossEntropyLoss()
            loss = loss_fct(logits, shift_labels)
        else:
            logits = self._apply_cut_head(hidden_states)
            if labels is not None:
                # The head emits `logits.shape[-1]` classes (2), not
                # `config.vocab_size`. The loss function flattens logits to
                # `(-1, vocab_size)`, so passing the vocabulary size here
                # would fail or mis-shape the logits.
                # NOTE(review): the default causal-LM loss shifts labels by
                # one position; confirm that matches the alignment used by
                # the training branch.
                loss = self.loss_function(
                    logits=logits,
                    labels=labels,
                    vocab_size=logits.shape[-1],
                    **kwargs,
                )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )