| from transformers.models.qwen3.modeling_qwen3 import ( |
| create_causal_mask, |
| create_sliding_window_causal_mask, |
| Qwen3ForCausalLM, |
| Qwen3PreTrainedModel, |
| Qwen3Model, |
| GenerationMixin, |
| Unpack, |
| TransformersKwargs, |
| ) |
|
|
|
|
|
|
class Qwen3ForCut(Qwen3PreTrainedModel, GenerationMixin):
    """Qwen3 backbone topped with a small 2-way "cut" classification head.

    Instead of a vocabulary-sized LM head, hidden states are passed through
    ``cut_head`` (Linear(hidden_size, 512) -> ReLU -> Linear(512, 2)), so the
    returned logits always have a last dimension of 2.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        # Kept for parity with Qwen3ForCausalLM; the cut head itself does not use it.
        self.vocab_size = config.vocab_size
        # Applied layer by layer in ``_apply_cut_head``. Parameters are registered
        # as cut_head.0.* and cut_head.2.* (index 1 is the parameter-free ReLU).
        self.cut_head = nn.ModuleList(
            [nn.Linear(config.hidden_size, 512, bias=False), nn.ReLU(inplace=False), nn.Linear(512, 2, bias=False)]
        )

        self.post_init()

    def _apply_cut_head(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run ``hidden_states`` through every layer of ``cut_head`` in order."""
        for layer in self.cut_head:
            hidden_states = layer(hidden_states)
        return hidden_states

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cls_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Run the Qwen3 backbone and score hidden states with the 2-way cut head.

        Args:
            cls_mask (`torch.Tensor`, *optional*):
                Accepted for interface compatibility; it is not used by this method.
            labels (`torch.LongTensor`, *optional*):
                Per-token cut labels aligned with `input_ids`. Positions equal to `-100`
                are ignored. All other values are class indices for the 2-way head and
                should therefore be `0` or `1`.
            logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to `0`):
                Used only outside the "training with labels" path. An int `k > 0` scores
                the last `k` positions; `0` scores every position; a tensor is used
                directly as the sequence-dimension index (same convention as
                `Qwen3ForCausalLM`).

        Returns:
            `CausalLMOutputWithPast` with 2-way `logits`.
            - Training with `labels`: the head is applied only to hidden states selected
              by `get_shift_mask(..., side="right")`. The loss is `LigerCrossEntropyLoss`
              against labels selected by `get_shift_mask(..., side="left")`.
              NOTE(review): `logits` is therefore presumably a flat `(num_selected, 2)`
              tensor rather than `(batch, seq, 2)`; confirm against `get_shift_mask`.
            - Otherwise: `logits` is `(batch, kept_positions, 2)`. If `labels` is given,
              the loss comes from the inherited `self.loss_function`.
              NOTE(review): for this class name transformers presumably falls back to the
              causal-LM loss, which shifts labels by one position. Confirm that this
              matches the training-time alignment.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        loss = None
        if self.training and labels is not None:
            # Only labelled positions (labels != -100) take part in the training loss.
            cut_point_mask = labels != -100
            # NOTE(review): the pairing of hidden states with targets is defined by
            # get_shift_mask (outside this view). Presumably each selected hidden state
            # is the one that predicts the corresponding selected label.
            r_shift_mask = get_shift_mask(cut_point_mask=cut_point_mask, side="right")
            l_shift_mask = get_shift_mask(cut_point_mask=cut_point_mask, side="left")
            shift_hidden_states = hidden_states[r_shift_mask].contiguous()
            shift_labels = labels[l_shift_mask].contiguous()

            logits = self._apply_cut_head(shift_hidden_states)
            loss_fct = LigerCrossEntropyLoss()
            loss = loss_fct(logits, shift_labels)
        else:
            # Same convention as Qwen3ForCausalLM. logits_to_keep=0 yields slice(0, None),
            # i.e. every position; generation can pass 1 to score only the last token.
            slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
            logits = self._apply_cut_head(hidden_states[:, slice_indices, :])

            if labels is not None:
                # The head has 2 output classes. Passing config.vocab_size here would make
                # the loss reshape logits with the wrong class count.
                loss = self.loss_function(logits=logits, labels=labels, vocab_size=logits.size(-1), **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )