Buckets:
| """ | |
| 2026.5.1 | |
| 2026.5.2 | |
| 5.5.0 | |
| 0.24.0 | |
| __UNSLOTH_VERSIONING__ | |
| """ | |
| # Unsloth auto generated code | |
| # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. | |
| # | |
| # This program is free software: you can redistribute it and/or modify | |
| # it under the terms of the GNU Lesser General Public License as published by | |
| # the Free Software Foundation, either version 3 of the License, or | |
| # (at your option) any later version. | |
| # | |
| # This program is distributed in the hope that it will be useful, | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| # GNU General Public License for more details. | |
| # | |
| # You should have received a copy of the GNU Lesser General Public License | |
| # along with this program. If not, see <https://www.gnu.org/licenses/>. | |
| import os | |
| import sys | |
| import torch | |
| import importlib.util | |
| import math | |
# Studio support is switched on only when the `unsloth_studio` package can be
# located AND the UNSLOTH_STUDIO_DISABLED environment variable is unset or "0".
UNSLOTH_STUDIO_ENABLED = (
    importlib.util.find_spec("unsloth_studio") is not None
    and os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
)
| from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable | |
| import math | |
# Feature switches, all read once from the environment at import time.
UNSLOTH_ENABLE_LOGGING = os.getenv("UNSLOTH_ENABLE_LOGGING", "0") == "1"
UNSLOTH_ENABLE_CCE = os.getenv("UNSLOTH_ENABLE_CCE", "1") == "1"
UNSLOTH_COMPILE_DISABLE = os.getenv("UNSLOTH_COMPILE_DISABLE", "0") in ("1", "partial")

# Directory holding the generated modules; it is put at the front of sys.path
# (once) so that those modules can be imported by name.
UNSLOTH_COMPILE_LOCATION = os.getenv("UNSLOTH_COMPILE_LOCATION", "unsloth_compiled_cache")
if UNSLOTH_COMPILE_LOCATION not in sys.path:
    sys.path.insert(0, UNSLOTH_COMPILE_LOCATION)

import logging

logger_compiler = logging.getLogger(__name__)
if UNSLOTH_ENABLE_LOGGING:
    logger_compiler.setLevel(logging.DEBUG)

# Module-level counter; a `global` declaration is a no-op at module scope,
# so the name is simply bound here.
INFERENCE_RUNS = 0
# Probe for the dynamo "stance" API.  Both names are set together: either the
# eval_frame module exposes `_stance.stance` and `torch.compiler.set_stance`
# exists, or both names fall back to None.
try:
    import torch._dynamo.eval_frame as torch_dynamo_eval_frame

    torch_dynamo_eval_frame._stance.stance  # attribute probe only; value unused
    torch_compiler_set_stance = torch.compiler.set_stance
except Exception:
    # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
    # SystemExit raised during the import are not silently swallowed.
    torch_dynamo_eval_frame = None
    torch_compiler_set_stance = None
| from unsloth_zoo import DEVICE_TYPE_TORCH, DEVICE_COUNT | |
| from unsloth_zoo.loss_utils import ( | |
| fused_linear_cross_entropy, | |
| unsloth_fused_ce_loss, | |
| ) | |
# Module-level alias of PyTorch's fused attention primitive.
scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention


def disable_compile_scaled_dot_product_attention(*args, **kwargs):
    """Forward every argument unchanged to `scaled_dot_product_attention`.

    NOTE(review): the name suggests this wrapper is meant to be excluded from
    torch.compile, but no decorator is visible here - confirm against the
    generator.
    """
    attn_output = scaled_dot_product_attention(*args, **kwargs)
    return attn_output
from transformers.modeling_flash_attention_utils import is_flash_attn_available

# Optional flash-attention helpers.  Every name defaults to None and is only
# replaced when flash-attention is available AND the installed transformers
# version exports it.  Each import is attempted independently so that one
# missing symbol does not hide the others.
flash_attn_supports_top_left_mask = None
_flash_attention_forward = None
FlashAttentionKwargs = None
flash_attn_varlen_func = None

if is_flash_attn_available():
    # `except Exception` (not a bare `except:`) keeps KeyboardInterrupt and
    # SystemExit propagating while still tolerating any import-time failure.
    try:
        from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask
    except Exception:
        flash_attn_supports_top_left_mask = None
    try:
        from transformers.modeling_flash_attention_utils import _flash_attention_forward
    except Exception:
        _flash_attention_forward = None
    try:
        from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
    except Exception:
        FlashAttentionKwargs = None
    try:
        from transformers.modeling_flash_attention_utils import flash_attn_varlen_func
    except Exception:
        flash_attn_varlen_func = None
# Option dictionary kept at module level for torch.compile / inductor calls.
# Keys and values are exactly the generated ones, in the generated order.
torch_compile_options = {
    'epilogue_fusion': True,
    'max_autotune': False,
    'shape_padding': True,
    'trace.enabled': False,
    'triton.cudagraphs': False,
    'debug': False,
    'dce': True,
    'memory_planning': True,
    'coordinate_descent_tuning': False,
    'trace.graph_diagram': False,
    'compile_threads': 32,
    'group_fusion': True,
    'disable_progress': True,
    'verbose_progress': False,
    'triton.multi_kernel': 0,
    'triton.use_block_ptr': False,
    'triton.enable_persistent_tma_matmul': True,
    'triton.autotune_at_compile_time': False,
    'triton.cooperative_reductions': False,
    'cuda.compile_opt_level': '-O2',
    'cuda.enable_cuda_lto': True,
    'combo_kernels': False,
    'benchmark_combo_kernel': True,
    'combo_kernel_foreach_dynamic_shapes': True,
}
| from torch.nn import CrossEntropyLoss | |
def normal_cross_entropy_loss(self, hidden_states, labels):
    """Compute the next-token cross-entropy loss through `self.lm_head`.

    Args:
        self: object providing `lm_head` (callable) and `config.vocab_size`.
        hidden_states: input passed to `self.lm_head`.
        labels: target token ids aligned with `hidden_states`.

    Returns:
        `(loss, logits)` where `logits` are the full float32 logits
        (not shifted) and `loss` is the mean cross-entropy.
    """
    logits = self.lm_head(hidden_states).float()
    vocab_size = self.config.vocab_size

    # Position t predicts token t + 1: drop the last logit and the first label,
    # then flatten both to (tokens, vocab) / (tokens,).
    predictions = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    targets = labels[..., 1:].contiguous().view(-1)

    # Enable model parallelism: the labels must live where the logits live.
    targets = targets.to(predictions.device)

    criterion = CrossEntropyLoss()
    return criterion(predictions, targets), logits
# Logits are no longer returned by default.  The message below is what users
# see when they touch the placeholder object, and tells them to opt back in via
# os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
LOGITS_ERROR_STRING = (
    "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "
    'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'
    "```\nimport os\n"
    "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"
    "trainer.train()\n```\n"
    "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!"
)


def raise_logits_error(*args, **kwargs):
    """Ignore all arguments and raise NotImplementedError with the logits hint."""
    raise NotImplementedError(LOGITS_ERROR_STRING)


def return_none(*args, **kwargs):
    """Ignore all arguments and return None."""
    return None
class EmptyLogits:
    """Placeholder returned in place of real logits.

    Indexing raises NotImplementedError with `LOGITS_ERROR_STRING`.  Looking up
    an attribute that is not otherwise defined returns a callable: `.to` maps
    to `return_none` (so `.to(...)` yields None), every other name maps to
    `raise_logits_error` (so calling it raises).  `repr()` and `str()` return
    the explanatory message itself.
    """

    def __init__(self):
        return

    def raise_getattr_error(self, attr):
        """Pick the stand-in callable for a missing attribute named `attr`."""
        if attr == "to":
            return return_none
        return raise_logits_error

    # `__getattr__` is only consulted when normal attribute lookup fails.
    __getattr__ = raise_getattr_error
    # As a class attribute this plain function becomes a method, so
    # `obj[key]` calls it with (obj, key) and raises.
    __getitem__ = raise_logits_error

    def __repr__(self):
        return LOGITS_ERROR_STRING

    def __str__(self):
        return LOGITS_ERROR_STRING
EMPTY_LOGITS = EmptyLogits()


def _make_dunder_printer(dunder_name):
    """Return a callable that ignores its arguments and prints `dunder_name`."""
    def _printer(*args, **kwargs):
        print(dunder_name)
    return _printer


# For every double-underscore name that torch.Tensor exposes, attach a stub to
# the EMPTY_LOGITS instance.  The stubs only print the dunder name; they are
# built with a closure instead of exec()-ing generated source code.
functions = dir(torch.Tensor)
for j, function in enumerate(functions):
    if function.startswith("__") and function.endswith("__"):
        _stub = _make_dunder_printer(function)
        _stub.__name__ = f"raise_{j}"
        # Keep publishing the stub under its generated module-level name so
        # anything that looked up `raise_<j>` keeps working.
        globals()[f"raise_{j}"] = _stub
        try:
            setattr(EMPTY_LOGITS, function, _stub)
        except Exception:
            # Some dunders (e.g. `__class__`) reject assignment; skip those.
            continue
def mask_attention_mask_out(labels = None, attention_mask = None):
    """Set labels to -100 wherever `attention_mask` is 0.

    `labels` is modified in place and also returned.  When either argument is
    None nothing is changed and `labels` is returned as-is.
    """
    if labels is None or attention_mask is None:
        return labels
    masked_out = attention_mask.to(device = labels.device) == 0
    labels[masked_out] = -100
    return labels
| from torch import Tensor | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from unsloth_zoo.temporary_patches.common import torch_compile | |
| from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable | |
| from transformers.models.qwen3_5.modeling_qwen3_5 import (F, Callable, Any, Optional, torch, nn, init, ACT2FN, Cache, GenerationMixin, FlashAttentionKwargs, BaseModelOutputWithPast, ModelOutput, BaseModelOutputWithPooling, CausalLMOutputWithPast, ROPE_INIT_FUNCTIONS, dynamic_rope_update, ALL_ATTENTION_FUNCTIONS, PreTrainedModel, Unpack, TransformersKwargs, can_return_tuple, is_flash_attention_requested, maybe_autocast, Qwen3_5Config, Qwen3_5TextConfig, Qwen3_5VisionConfig, causal_conv1d_fn, causal_conv1d_update, FusedRMSNormGated, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule, logger, __name__, is_fast_path_available, Qwen3_5PreTrainedModel, Qwen3_5Model, Qwen3_5TextModel, Qwen3_5ForCausalLM, Qwen3_5CausalLMOutputWithPast, Qwen3_5ForConditionalGeneration, Qwen3_5GatedDeltaNet) | |
def Qwen3_5VisionRotaryEmbedding_forward(self, seqlen: int) -> torch.Tensor:
    """Return the outer product of positions `0..seqlen-1` with `self.inv_freq`.

    The position vector is created on the device and in the dtype of
    `self.inv_freq`; the result has one row per position.
    """
    inv_freq = self.inv_freq
    positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
    return torch.outer(positions, inv_freq)
class Qwen3_5VisionRotaryEmbedding(nn.Module):
    """Holds the inverse rotary frequencies `1 / theta ** (2i / dim)`.

    The frequencies are stored in a non-persistent buffer (`inv_freq`), so they
    are not part of the state dict.  `forward` delegates to the module-level
    `Qwen3_5VisionRotaryEmbedding_forward`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        # One exponent per even index in [0, dim).
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        return Qwen3_5VisionRotaryEmbedding_forward(self, seqlen=seqlen)
# power user: used with advanced RoPE types (e.g. dynamic rope)
# NOTE(review): this comment reads like the trailing remark of a decorator line
# (`dynamic_rope_update` is imported above but not applied here) - confirm
# against the generator whether a decorator was dropped.
def Qwen3_5TextRotaryEmbedding_forward(self, x, position_ids):
    """Build the rotary `(cos, sin)` tables for `position_ids`.

    Args:
        self: module providing `inv_freq`, `mrope_section`,
            `attention_scaling` and `apply_interleaved_mrope`.
        x: tensor used only for its device type and dtype.
        position_ids: 2-D ids are broadcast to three identical grids;
            otherwise the leading axis is indexed as the grid axis.

    Returns:
        `(cos, sin)`, both cast to `x.dtype`.
    """
    # In contrast to other models, Qwen3_5 has different position ids for the grids
    # So we expand the inv_freq to shape (3, ...)
    if position_ids.ndim == 2:
        position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
    inv_freq = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
    positions = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

    dev = x.device.type
    autocast_device = dev if isinstance(dev, str) and dev != "mps" else "cpu"
    with maybe_autocast(device_type=autocast_device, enabled=False):  # Force float32
        angles = (inv_freq.float() @ positions.float()).transpose(2, 3)
        angles = self.apply_interleaved_mrope(angles, self.mrope_section)
        emb = torch.cat((angles, angles), dim=-1)
        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling

    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Qwen3_5TextRotaryEmbedding(nn.Module):
    """Rotary embedding for the text model with interleaved multi-axis RoPE.

    The inverse frequencies come from `compute_default_rope_parameters` for the
    "default" rope type, or from `ROPE_INIT_FUNCTIONS[rope_type]` otherwise.
    Both `inv_freq` and a copy (`original_inv_freq`) are registered as
    non-persistent buffers.  `forward` delegates to the module-level
    `Qwen3_5TextRotaryEmbedding_forward`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3_5TextConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
        self.mrope_section = config.rope_parameters.get("mrope_section", [11, 11, 10])

    # Must be a staticmethod: the function takes no `self`, yet `__init__`
    # fetches it through `self.` and calls it as `rope_init_fn(config, device)`.
    # As a plain method the instance would be bound to `config`, shifting every
    # argument by one position.
    @staticmethod
    def compute_default_rope_parameters(
        config: Qwen3_5TextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        dim = int(head_dim * partial_rotary_factor)

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    def forward(self, x, position_ids):
        return Qwen3_5TextRotaryEmbedding_forward(self, x=x, position_ids=position_ids)

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary embeddings.
        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THWTHWTHW...TT], preserving frequency continuity.
        args:
            freqs: (3, bs, seq_len, head_dim // 2)
            mrope_section: (3,)
        returns:
            freqs_t: (bs, seq_len, head_dim // 2)

        Note: the result is `freqs[0]` itself, so `freqs` is modified in place.
        """
        freqs_t = freqs[0]  # just overwrite the first dimension T
        for dim, offset in enumerate((1, 2), start=1):  # H, W
            length = mrope_section[dim] * 3
            # Every third slot starting at `offset` is taken from grid `dim`.
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t
def Qwen3_5RMSNorm_forward(self, x):
    """Normalize `x` in float32, scale by `1 + weight`, cast back to `x`'s type.

    Llama does x.to(float16) * w whilst Qwen3_5 is (x * w).to(float16), i.e. the
    multiplication happens before the down-cast.
    See https://github.com/huggingface/transformers/pull/29402
    """
    normed = self._norm(x.float())
    scale = 1.0 + self.weight.float()
    return (normed * scale).type_as(x)
class Qwen3_5RMSNorm(nn.Module):
    """RMS normalization whose learnable weight starts at zero.

    `forward` delegates to the module-level `Qwen3_5RMSNorm_forward`.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide `x` by the root-mean-square of its last axis (plus `eps`)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        return Qwen3_5RMSNorm_forward(self, x=x)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.eps}"
def Qwen3_5RMSNormGated_forward(self, hidden_states, gate=None):
    """RMS-normalize `hidden_states`, apply `self.weight`, then gate with SiLU.

    The normalization runs in float32; the weight is applied after casting back
    to the input dtype; the SiLU gate is evaluated in float32 and the final
    product is returned in the input dtype.  Although `gate` defaults to None,
    it is dereferenced unconditionally, so a tensor must be supplied.
    """
    input_dtype = hidden_states.dtype
    states_fp32 = hidden_states.to(torch.float32)
    variance = states_fp32.pow(2).mean(-1, keepdim=True)
    # Norm before gate
    normed = states_fp32 * torch.rsqrt(variance + self.variance_epsilon)
    weighted = self.weight * normed.to(input_dtype)
    gated = weighted * F.silu(gate.to(torch.float32))
    return gated.to(input_dtype)
class Qwen3_5RMSNormGated(nn.Module):
    """Gated RMS normalization with a weight initialised to ones.

    Extra keyword arguments are accepted and ignored.  `forward` delegates to
    the module-level `Qwen3_5RMSNormGated_forward`.
    """

    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        super().__init__()
        initial_weight = torch.ones(hidden_size)
        self.weight = nn.Parameter(initial_weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states, gate=None):
        return Qwen3_5RMSNormGated_forward(self, hidden_states=hidden_states, gate=gate)
def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66

    The mask is applied only when it is given and both of its first two
    dimensions are larger than 1; otherwise `hidden_states` is returned
    untouched.  The result keeps the dtype of `hidden_states`.
    """
    # NOTE: attention mask is a 2D boolean tensor
    if attention_mask is None or attention_mask.shape[1] <= 1 or attention_mask.shape[0] <= 1:
        return hidden_states
    original_dtype = hidden_states.dtype
    masked = hidden_states * attention_mask[:, :, None]
    return masked.to(original_dtype)
def torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    """Pure-torch single-step update of a depthwise causal convolution.

    The new inputs are appended to `conv_state` along the last axis, the state
    is overwritten in place with the most recent `state_len` columns, and a
    grouped (one group per channel) convolution followed by SiLU is applied.
    Only the last `seq_len` outputs are returned, in the dtype of
    `hidden_states`.  `activation` is accepted but not used.
    """
    _, num_channels, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]

    window = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    # Slide the cached window forward (in place, so the caller's state updates).
    conv_state.copy_(window[:, :, -state_len:])

    conv_out = F.conv1d(window, weight.unsqueeze(1), bias, padding=0, groups=num_channels)
    activated = F.silu(conv_out[:, :, -seq_len:])
    return activated.to(hidden_states.dtype)
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """This function is intended to align with the l2norm implementation in the FLA library.

    Scales `x` by `1 / sqrt(sum(x * x, dim) + eps)`.
    """
    squared_sum = (x * x).sum(dim=dim, keepdim=True)
    return x * torch.rsqrt(squared_sum + eps)
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """Pure-torch, chunk-wise evaluation of the gated delta rule.

    All inputs have their axes 1 and 2 swapped and are computed in float32.
    The swapped axis 2 is treated as the sequence axis: it is zero-padded to a
    multiple of `chunk_size`, processed chunk by chunk while carrying a
    recurrent state, and truncated back to its original length.

    Args:
        query, key, value: 4-D tensors; `g`, `beta`: 3-D tensors.
        chunk_size: number of sequence positions handled per chunk.
        initial_state: optional starting recurrent state (zeros when None).
        output_final_state: when False, None is returned instead of the state.
        use_qk_l2norm_in_kernel: L2-normalize `query` and `key` first.

    Returns:
        `(core_attn_out, last_recurrent_state)`; the output is laid out like
        `value` and cast to the original dtype of `query`.
    """
    out_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)

    def _swap_and_upcast(t):
        return t.transpose(1, 2).contiguous().to(torch.float32)

    query = _swap_and_upcast(query)
    key = _swap_and_upcast(key)
    value = _swap_and_upcast(value)
    beta = _swap_and_upcast(beta)
    g = _swap_and_upcast(g)

    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]

    # Right-pad the sequence axis so that it splits evenly into chunks.
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    seq_padding = (0, 0, 0, pad_size)
    query = F.pad(query, seq_padding)
    key = F.pad(key, seq_padding)
    value = F.pad(value, seq_padding)
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    num_chunks = (sequence_length + pad_size) // chunk_size

    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)

    # reshape to chunks
    def _into_chunks(t):
        return t.reshape(t.shape[0], t.shape[1], -1, chunk_size, t.shape[-1])

    query = _into_chunks(query)
    key = _into_chunks(key)
    value = _into_chunks(value)
    k_beta = _into_chunks(k_beta)
    v_beta = _into_chunks(v_beta)
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)

    upper_with_diag = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0
    )

    # chunk decay
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(upper_with_diag, 0)
    # Row-by-row forward substitution over the strictly lower triangle.
    for row_idx in range(1, chunk_size):
        row = attn[..., row_idx, :row_idx].clone()
        block = attn[..., :row_idx, :row_idx].clone()
        attn[..., row_idx, :row_idx] = row + (row.unsqueeze(-1) * block).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)

    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))

    if initial_state is None:
        state = torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
    else:
        state = initial_state.to(value)

    core_attn_out = torch.zeros_like(value)
    upper_strict = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1
    )

    # for each chunk
    for chunk_idx in range(0, num_chunks):
        q_c = query[:, :, chunk_idx]
        k_c = key[:, :, chunk_idx]
        v_c = value[:, :, chunk_idx]
        g_c = g[:, :, chunk_idx]

        intra = (q_c @ k_c.transpose(-1, -2) * decay_mask[:, :, chunk_idx]).masked_fill_(upper_strict, 0)
        v_prime = (k_cumdecay[:, :, chunk_idx]) @ state
        v_new = v_c - v_prime
        inter = (q_c * g_c[:, :, :, None].exp()) @ state
        core_attn_out[:, :, chunk_idx] = inter + intra @ v_new
        # Carry the state across the chunk boundary.
        state = (
            state * g_c[:, :, -1, None, None].exp()
            + (k_c * (g_c[:, :, -1, None] - g_c).exp()[..., None]).transpose(-1, -2) @ v_new
        )

    last_recurrent_state = state if output_final_state else None

    # Merge the chunks back together and drop the padding.
    core_attn_out = core_attn_out.reshape(
        core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]
    )
    core_attn_out = core_attn_out[:, :, :sequence_length]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(out_dtype)
    return core_attn_out, last_recurrent_state
def torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
    """Pure-torch, step-by-step evaluation of the gated delta rule.

    All inputs have their axes 1 and 2 swapped and are computed in float32;
    the swapped axis 2 is then walked one position at a time while a recurrent
    state is decayed, corrected and read out.

    Returns:
        `(core_attn_out, last_recurrent_state)`; the output is laid out like
        `value` and cast to the original dtype of `query`.  The state is None
        unless `output_final_state` is truthy.
    """
    out_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)

    def _swap_and_upcast(t):
        return t.transpose(1, 2).contiguous().to(torch.float32)

    query = _swap_and_upcast(query)
    key = _swap_and_upcast(key)
    value = _swap_and_upcast(value)
    beta = _swap_and_upcast(beta)
    g = _swap_and_upcast(g)

    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
    if initial_state is None:
        state = torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
    else:
        state = initial_state.to(value)

    for step in range(sequence_length):
        q_t = query[:, :, step]
        k_t = key[:, :, step]
        v_t = value[:, :, step]
        decay_t = g[:, :, step].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, step].unsqueeze(-1)

        # Decay the memory, then write the gated prediction error for this key.
        state = state * decay_t
        recalled = (state * k_t.unsqueeze(-1)).sum(dim=-2)
        correction = (v_t - recalled) * beta_t
        state = state + k_t.unsqueeze(-1) * correction.unsqueeze(-2)
        core_attn_out[:, :, step] = (state * q_t.unsqueeze(-1)).sum(dim=-2)

    last_recurrent_state = state if output_final_state else None
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(out_dtype)
    return core_attn_out, last_recurrent_state
def Qwen3_5GatedDeltaNet_forward(
    self,
    hidden_states: torch.Tensor,
    cache_params: Cache | None = None,
    attention_mask: torch.Tensor | None = None,
):
    """Run the gated delta-net mixer on `hidden_states`.

    Two paths exist.  When a cache with a previous state for this layer is
    supplied and exactly one position is being processed, the cached
    convolution and recurrent states are reused and the recurrent kernel runs.
    Otherwise the full convolution and the chunked kernel run (starting from no
    initial state), and the cache - if any - is filled.

    Returns:
        The output of `self.out_proj`, with the leading two dimensions of
        `hidden_states`.
    """
    hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

    # Set up dimensions for reshapes later
    batch_size, seq_len, _ = hidden_states.shape
    layer_idx = self.layer_idx
    has_cache = cache_params is not None

    use_precomputed_states = has_cache and cache_params.has_previous_state(layer_idx) and seq_len == 1

    # getting projected states from cache if it exists
    if use_precomputed_states:
        layer_cache = cache_params.layers[layer_idx]
        conv_state = layer_cache.conv_states
        recurrent_state = layer_cache.recurrent_states

    mixed_qkv = self.in_proj_qkv(hidden_states).transpose(1, 2)
    z = self.in_proj_z(hidden_states).reshape(batch_size, seq_len, -1, self.head_v_dim)
    b = self.in_proj_b(hidden_states)
    a = self.in_proj_a(hidden_states)

    conv_weight = self.conv1d.weight.squeeze(1)
    if use_precomputed_states:
        # 2. Convolution sequence transformation
        # NOTE: the conv state is updated in `causal_conv1d_update`
        mixed_qkv = self.causal_conv1d_update(
            mixed_qkv,
            conv_state,
            conv_weight,
            self.conv1d.bias,
            self.activation,
        )
    else:
        if has_cache:
            # Left-pad (or crop) to the kernel width before storing the state.
            left_pad = self.conv_kernel_size - mixed_qkv.shape[-1]
            conv_state = F.pad(mixed_qkv, (left_pad, 0))
            conv_state = cache_params.update_conv_state(conv_state, layer_idx)
        if self.causal_conv1d_fn is None:
            mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
        else:
            mixed_qkv = self.causal_conv1d_fn(
                x=mixed_qkv,
                weight=conv_weight,
                bias=self.conv1d.bias,
                activation=self.activation,
                seq_idx=None,
            )

    mixed_qkv = mixed_qkv.transpose(1, 2)
    split_sizes = [self.key_dim, self.key_dim, self.value_dim]
    query, key, value = torch.split(mixed_qkv, split_sizes, dim=-1)

    query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
    key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
    value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

    beta = b.sigmoid()
    # If the model is loaded in fp16, without the .float() here, A might be -inf
    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

    # Repeat query/key heads so that they line up with the value heads.
    heads_ratio = self.num_v_heads // self.num_k_heads
    if heads_ratio > 1:
        query = query.repeat_interleave(heads_ratio, dim=2)
        key = key.repeat_interleave(heads_ratio, dim=2)

    if use_precomputed_states:
        delta_rule, start_state = self.recurrent_gated_delta_rule, recurrent_state
    else:
        delta_rule, start_state = self.chunk_gated_delta_rule, None
    core_attn_out, last_recurrent_state = delta_rule(
        query,
        key,
        value,
        g=g,
        beta=beta,
        initial_state=start_state,
        output_final_state=has_cache,
        use_qk_l2norm_in_kernel=True,
    )

    # Update cache
    if has_cache:
        cache_params.update_recurrent_state(last_recurrent_state, layer_idx)

    # reshape input data into 2D tensor
    core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
    z = z.reshape(-1, self.head_v_dim)
    core_attn_out = self.norm(core_attn_out, z)
    core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

    return self.out_proj(core_attn_out)
class Qwen3_5GatedDeltaNet(nn.Module):
    """Gated delta-net token mixer.

    Owns the input projections (`in_proj_qkv`, `in_proj_z`, `in_proj_b`,
    `in_proj_a`), a depthwise `conv1d`, the `dt_bias` / `A_log` parameters, a
    gated RMS norm and the output projection.  The actual computation lives in
    the module-level `Qwen3_5GatedDeltaNet_forward`.

    NOTE: the statement order in `__init__` determines both the parameter
    registration order and the order in which random numbers are drawn, so it
    should not be rearranged.
    """

    def __init__(self, config: Qwen3_5Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        # Total widths of the key (== query) and value projections.
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = layer_idx
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps

        # QKV
        # Query, key and value are convolved together: 2 * key_dim + value_dim channels.
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=self.conv_kernel_size - 1,
        )

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))

        # A is drawn uniformly from [0, 16) and stored as its logarithm.
        A = torch.empty(self.num_v_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))

        # Use the fused gated norm when it was importable, else the torch version.
        # NOTE(review): the fused branch calls torch.cuda.current_device(), so it
        # presumably requires a CUDA device at construction time - confirm.
        self.norm = (
            Qwen3_5RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
            if FusedRMSNormGated is None
            else FusedRMSNormGated(
                self.head_v_dim,
                eps=self.layer_norm_epsilon,
                activation=self.activation,
                device=torch.cuda.current_device(),
                dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
            )
        )

        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

        # Kernel selection: `causal_conv1d_fn` may stay None (no fallback is
        # substituted here); the other three fall back to the torch
        # implementations when the fast version is falsy.
        self.causal_conv1d_fn = causal_conv1d_fn
        self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
        self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
        self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of the required library is not installed. Falling back to "
                "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

        # Input projections: fused q/k/v, the norm gate z, and the per-value-head b and a.
        self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False)
        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Delegate to `Qwen3_5GatedDeltaNet_forward` with the same arguments."""
        return Qwen3_5GatedDeltaNet_forward(self, hidden_states=hidden_states, cache_params=cache_params, attention_mask=attention_mask)
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last axis is cut at `size // 2`; the result is the negated second part
    followed by the first part.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Only the leading `cos.shape[-1]` features of the last axis are rotated; any
    remaining features are passed through unchanged and concatenated back.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which `cos` and `sin` are unsqueezed so that they
            broadcast against `q` and `k`. For example, if `cos` and `sin` have the
            shape [batch_size, seq_len, head_dim] and `q` and `k` have the shape
            [batch_size, heads, seq_len, head_dim], use `unsqueeze_dim=1`; if `q`
            and `k` have the shape [batch_size, seq_len, heads, head_dim], use
            `unsqueeze_dim=2`.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]

    def _rotate_leading_features(states):
        # Keep half or full tensor for later concatenation
        to_rotate, passthrough = states[..., :rotary_dim], states[..., rotary_dim:]
        rotated = (to_rotate * cos) + (rotate_half(to_rotate) * sin)
        # Concatenate back to full shape
        return torch.cat([rotated, passthrough], dim=-1)

    return _rotate_leading_features(q), _rotate_leading_features(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

    With `n_rep == 1` the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, then fold it into the heads.
    repeated = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return repeated.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain matmul-softmax-matmul attention.

    Key and value heads are repeated `module.num_key_value_groups` times.  The
    mask is added to the scaled scores; when `attention_mask` is a dict, the
    entry for `module.layer_type` is used (no mask if there is no such entry).
    Softmax runs in float32.

    Returns:
        `(attn_output, attn_weights)`; the output has axes 1 and 2 swapped back
        and is contiguous.
    """
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    mask = attention_mask
    if isinstance(mask, dict):
        # Per-layer-type masks: pick the one matching this module, if any.
        mask = mask.get(getattr(module, 'layer_type', None), None)
    if mask is not None:
        scores = scores + mask

    probs = nn.functional.softmax(scores, dim=-1, dtype = torch.float32)
    attn_weights = probs.to(scores.dtype).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights
def Qwen3_5Attention_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: torch.Tensor | None,
    past_key_values: Cache | None = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Gated self-attention forward pass.

    `q_proj` yields the query and an output gate in one projection (split in
    half per head).  Query and key are normalized, rotated with the supplied
    `(cos, sin)`, optionally merged with the key/value cache, and passed to the
    attention implementation selected by `config._attn_implementation`
    (defaulting to `eager_attention_forward`).  The attention output is
    multiplied by `sigmoid(gate)` before `o_proj`.

    Returns:
        `(attn_output, attn_weights)`.
    """
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    # Query and gate share one projection: each head carries 2 * head_dim features.
    q_and_gate = self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2)
    query_states, gate = torch.chunk(q_and_gate, 2, dim=-1)
    gate = gate.reshape(*input_shape, -1)

    query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
    key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    # Unsloth: align V dtype with Q after RoPE (fixes 4-bit dtype mismatch)
    if value_states.dtype != query_states.dtype:
        value_states = value_states.to(query_states.dtype)

    if past_key_values is not None:
        key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

    attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
        self.config._attn_implementation, eager_attention_forward
    )

    # Dropout is only active while training.
    dropout_p = self.attention_dropout if self.training else 0.0
    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=dropout_p,
        scaling=self.scaling,
        **kwargs,
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = attn_output * torch.sigmoid(gate)
    return self.o_proj(attn_output), attn_weights
class Qwen3_5Attention(nn.Module):
    """Multi-head attention with per-head Q/K RMSNorm and a gated output.

    The query projection is twice as wide as usual (``head_dim * 2`` per head);
    ``Qwen3_5Attention_forward`` splits it into query and gate.
    """

    def __init__(self, config: Qwen3_5Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Fall back to hidden_size / num_attention_heads when the config has no head_dim.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        model_width = config.hidden_size
        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        use_bias = config.attention_bias

        # Projections are created in q, k, v, o order (kept as in the original module).
        self.q_proj = nn.Linear(model_width, query_width * 2, bias=use_bias)
        self.k_proj = nn.Linear(model_width, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(model_width, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(query_width, model_width, bias=use_bias)

        # unlike olmo, the norms act only on the head dim, so no reshape is needed after q_norm
        self.q_norm = Qwen3_5RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3_5RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Delegate to the module-level ``Qwen3_5Attention_forward``."""
        return Qwen3_5Attention_forward(
            self,
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            **kwargs,
        )
def Qwen3_5MLP_forward(self, x):
    """Gated MLP: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""
    activated_gate = self.act_fn(self.gate_proj(x))
    return self.down_proj(activated_gate * self.up_proj(x))
class Qwen3_5MLP(nn.Module):
    """Bias-free gated feed-forward block (gate / up / down projections)."""

    def __init__(self, config: Qwen3_5Config, intermediate_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size

        narrow, wide = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(narrow, wide, bias=False)
        self.up_proj = nn.Linear(narrow, wide, bias=False)
        self.down_proj = nn.Linear(wide, narrow, bias=False)
        # Activation is looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Delegate to the module-level ``Qwen3_5MLP_forward``."""
        return Qwen3_5MLP_forward(self, x=x)
def Qwen3_5VisionMLP_forward(self, hidden_state):
    """Two-layer MLP: ``linear_fc2(act_fn(linear_fc1(hidden_state)))``."""
    expanded = self.linear_fc1(hidden_state)
    activated = self.act_fn(expanded)
    return self.linear_fc2(activated)
class Qwen3_5VisionMLP(nn.Module):
    """Feed-forward block of the vision tower: two biased linear layers around an activation."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        narrow, wide = self.hidden_size, self.intermediate_size
        self.linear_fc1 = nn.Linear(narrow, wide, bias=True)
        self.linear_fc2 = nn.Linear(wide, narrow, bias=True)
        # Activation is looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        """Delegate to the module-level ``Qwen3_5VisionMLP_forward``."""
        return Qwen3_5VisionMLP_forward(self, hidden_state=hidden_state)
def Qwen3_5VisionPatchEmbed_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Embed flattened patches with ``self.proj``.

    The input is viewed as
    ``(-1, in_channels, temporal_patch_size, patch_size, patch_size)``, cast to
    the dtype of ``self.proj.weight``, projected, and flattened to
    ``(-1, embed_dim)``.
    """
    patch_shape = (-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
    weight_dtype = self.proj.weight.dtype

    patches = hidden_states.view(*patch_shape).to(dtype=weight_dtype)
    embedded = self.proj(patches)
    return embedded.view(-1, self.embed_dim)
class Qwen3_5VisionPatchEmbed(nn.Module):
    """Patch embedding for the vision tower, implemented as a non-overlapping 3D convolution."""

    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        # Stride equals kernel size, so every (temporal, height, width) patch is seen exactly once.
        patch_extent = (self.temporal_patch_size, self.patch_size, self.patch_size)
        self.proj = nn.Conv3d(
            self.in_channels,
            self.embed_dim,
            kernel_size=patch_extent,
            stride=patch_extent,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Delegate to the module-level ``Qwen3_5VisionPatchEmbed_forward``."""
        return Qwen3_5VisionPatchEmbed_forward(self, hidden_states=hidden_states)
def Qwen3_5VisionPatchMerger_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Normalise, regroup to ``(-1, self.hidden_size)`` and apply the two-layer projection.

    When ``self.use_postshuffle_norm`` is set the regrouping happens *before*
    the norm; otherwise the norm sees ``x`` as given and regrouping follows.
    """
    if self.use_postshuffle_norm:
        normed = self.norm(x.view(-1, self.hidden_size))
    else:
        normed = self.norm(x)
    merged = normed.view(-1, self.hidden_size)
    return self.linear_fc2(self.act_fn(self.linear_fc1(merged)))
class Qwen3_5VisionPatchMerger(nn.Module):
    """Merge groups of ``spatial_merge_size ** 2`` patch features and project to ``out_hidden_size``."""

    def __init__(self, config: Qwen3_5VisionConfig, use_postshuffle_norm=False) -> None:
        super().__init__()
        merged_width = config.hidden_size * (config.spatial_merge_size**2)
        self.hidden_size = merged_width
        self.use_postshuffle_norm = use_postshuffle_norm

        # The norm width depends on whether it runs after (merged width) or before (per-patch width) regrouping.
        norm_width = merged_width if use_postshuffle_norm else config.hidden_size
        self.norm = nn.LayerNorm(norm_width, eps=1e-6)
        self.linear_fc1 = nn.Linear(merged_width, merged_width)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(merged_width, config.out_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the module-level ``Qwen3_5VisionPatchMerger_forward``."""
        return Qwen3_5VisionPatchMerger_forward(self, x=x)
def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to ``q`` and ``k`` in float32.

    ``cos`` and ``sin`` get a new axis at position ``-2`` so they broadcast
    against the inputs. Each result is cast back to the dtype its input had.
    """
    cos_f = cos.unsqueeze(-2).float()
    sin_f = sin.unsqueeze(-2).float()

    def _rotate(tensor: torch.Tensor) -> torch.Tensor:
        # Compute in float32 and restore the caller's dtype afterwards.
        upcast = tensor.float()
        rotated = (upcast * cos_f) + (rotate_half(upcast) * sin_f)
        return rotated.to(tensor.dtype)

    return _rotate(q), _rotate(k)
def Qwen3_5VisionAttention_forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor | None = None,
    position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
    **kwargs,
) -> torch.Tensor:
    """Non-causal self-attention over packed vision sequences.

    Args:
        hidden_states: packed tokens; dim 0 is the total sequence length.
        cu_seqlens: cumulative sequence boundaries; consecutive differences give
            the length of each packed sequence.
        rotary_pos_emb: accepted for signature compatibility; not read here.
        position_embeddings: ``(cos, sin)`` pair for ``apply_rotary_pos_emb_vision``.
            Despite the ``None`` default it is unpacked unconditionally, so it
            must be provided.
        **kwargs: forwarded to the attention backend.

    Returns:
        The attention output after ``self.proj``, with ``seq_length`` rows.
    """
    seq_length = hidden_states.shape[0]
    query_states, key_states, value_states = (
        self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
    )

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

    # Swap the first two dims and add a leading batch dim of 1 for the backend.
    query_states = query_states.transpose(0, 1).unsqueeze(0)
    key_states = key_states.transpose(0, 1).unsqueeze(0)
    value_states = value_states.transpose(0, 1).unsqueeze(0)

    attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
        self.config._attn_implementation, eager_attention_forward
    )
    dropout_p = 0.0 if not self.training else self.attention_dropout

    if is_flash_attention_requested(self.config):
        # Flash Attention: Use cu_seqlens for variable length attention
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=None,
            scaling=self.scaling,
            dropout=dropout_p,
            cu_seq_lens_q=cu_seqlens,
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )
    else:
        # Other implementations: Process each chunk separately.
        # Convert the lengths to a Python list once; doing it per tensor would
        # repeat the same tensor-to-host transfer three times.
        split_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        q_chunks, k_chunks, v_chunks = (
            torch.split(tensor, split_sizes, dim=2) for tensor in (query_states, key_states, value_states)
        )
        attn_outputs = [
            attention_interface(
                self,
                q,
                k,
                v,
                attention_mask=None,
                scaling=self.scaling,
                dropout=dropout_p,
                is_causal=False,
                **kwargs,
            )[0]
            for q, k, v in zip(q_chunks, k_chunks, v_chunks)
        ]
        attn_output = torch.cat(attn_outputs, dim=1)

    attn_output = attn_output.reshape(seq_length, -1).contiguous()
    return self.proj(attn_output)
class Qwen3_5VisionAttention(nn.Module):
    """Vision-tower self-attention with a fused QKV projection; non-causal, no dropout."""

    def __init__(self, config: Qwen3_5VisionConfig) -> None:
        super().__init__()
        width = config.hidden_size
        self.dim = width
        self.num_heads = config.num_heads
        self.head_dim = width // self.num_heads
        # needed for eager attention
        self.num_key_value_groups = 1

        # One linear layer yields query, key and value (3 * dim features).
        self.qkv = nn.Linear(width, width * 3, bias=True)
        self.proj = nn.Linear(width, width)

        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Delegate to the module-level ``Qwen3_5VisionAttention_forward``."""
        return Qwen3_5VisionAttention_forward(
            self,
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
def Qwen3_5ForCausalLM_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    labels: torch.LongTensor | None = None,
    use_cache: bool | None = None,
    logits_to_keep: int | torch.Tensor = 0,
    **kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
    r"""
    Run the text model and produce logits and/or a loss.

    Behaviour is steered by two environment variables read on every call:
    `UNSLOTH_RETURN_HIDDEN_STATES == "1"` returns the sliced hidden states in the `logits` field, and
    `UNSLOTH_RETURN_LOGITS` decides whether real logits accompany a fused loss (otherwise `EMPTY_LOGITS`
    is returned in training branches that never materialise them).

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Qwen3_5ForCausalLM

    >>> model = Qwen3_5ForCausalLM.from_pretrained("Qwen/Qwen3_5-8B")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3_5-8B")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    # Module-level counter of inference calls; reassigned in the inference branch below.
    global INFERENCE_RUNS

    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        **kwargs,
    )
    hidden_states = outputs.last_hidden_state

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep

    # The two flags are deliberately separate: any value other than "0"/"1" makes both False.
    RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '1'
    NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0'
    RETURN_HIDDEN_STATES = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1"

    # Token count for loss normalisation, if the caller supplied one via kwargs.
    n_items = kwargs.get("num_items_in_batch", None)
    if n_items is None:
        n_items = kwargs.get("n_items", None)

    # Treated as "needs grad" when lm_head is trainable or stored in float32.
    requires_grad_ = self.lm_head.weight.requires_grad or self.lm_head.weight.dtype == torch.float32

    # Logits are only materialised in the branch that actually returns them, so the
    # lm_head projection is never computed and then thrown away.
    logits = EMPTY_LOGITS
    loss = None

    if RETURN_HIDDEN_STATES:
        logits = hidden_states[:, slice_indices, :]
    elif labels is None:
        # Set compiler stance to fail on recompiles for inference
        if torch_dynamo_eval_frame is not None:
            old_stance = torch_dynamo_eval_frame._stance.stance
        else:
            old_stance = None
        if old_stance is not None and INFERENCE_RUNS == 1:
            # Skip guards and return to eager -> we still need guards!
            torch_compiler_set_stance(stance = "eager_on_recompile", skip_guard_eval_unsafe = False)
            if UNSLOTH_ENABLE_LOGGING:
                logger_compiler.info(
                    f"Unsloth: Removing compiler guards after 1 inference run. "\
                    f"DYNAMO_STANCE.stance = {torch_dynamo_eval_frame._stance.stance} "\
                    f"DYNAMO_STANCE.skip_guard_eval_unsafe = {torch_dynamo_eval_frame._stance.skip_guard_eval_unsafe}"
                )
        elif old_stance == "eager_on_recompile":
            pass
        elif old_stance == "default" and INFERENCE_RUNS > 1:
            # Reset compiler stance
            torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
            if UNSLOTH_ENABLE_LOGGING:
                logger_compiler.info(
                    f"Unsloth: Reseting guards. "\
                    f"DYNAMO_STANCE.stance = {torch_dynamo_eval_frame._stance.stance} "\
                    f"DYNAMO_STANCE.skip_guard_eval_unsafe = {torch_dynamo_eval_frame._stance.skip_guard_eval_unsafe}"
                )
            INFERENCE_RUNS = 0
        INFERENCE_RUNS += 1

        logits = self.lm_head(hidden_states[:, slice_indices, :])
    elif (
        UNSLOTH_ENABLE_CCE
        and NOT_RETURN_LOGITS
        and self.loss_function.__name__.endswith("ForCausalLMLoss")
        and not requires_grad_
    ):
        # Fused linear + cross-entropy: logits are never materialised.
        loss = fused_linear_cross_entropy(
            hidden_states = hidden_states[:, slice_indices, :],
            lm_weight = self.lm_head.weight,
            labels = labels.to(self.lm_head.weight.device),
            num_items_in_batch = n_items,
            logit_softcapping = None,
        )
    elif self.loss_function.__name__.endswith("ForCausalLMLoss"):
        # Fused loss; real logits are produced only when explicitly requested.
        if RETURN_LOGITS:
            logits = self.lm_head(hidden_states[:, slice_indices, :])
        lm_head_weight = self.lm_head.weight
        lm_head_bias = getattr(self.lm_head, "bias", None)

        # ========= NEW fused =========
        _hidden_states = hidden_states[:, slice_indices, :]
        torch._dynamo.mark_dynamic(_hidden_states, 1)
        torch._dynamo.mark_dynamic(labels, 1)
        loss = unsloth_fused_ce_loss(
            trainer = None,
            hidden_states = _hidden_states,
            lm_head_weight = lm_head_weight,
            lm_head_bias = lm_head_bias,
            labels = labels,
            mask = None,
            n_items = n_items,
            scaling = getattr(self, "accelerator_scaler", None),
            target_gb = None,
            torch_compile = not UNSLOTH_COMPILE_DISABLE,
            logit_scale_multiply = 0,
            logit_scale_divide = 0,
            logit_softcapping = 0,
        )
    else:
        # Any other loss function: materialise logits and defer to it.
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = self.loss_function(
            logits=logits,
            labels=labels.to(self.lm_head.weight.device),
            vocab_size=self.config.vocab_size,
            **kwargs,
        )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
class Qwen3_5ForCausalLM(Qwen3_5PreTrainedModel, GenerationMixin):
    """Text-only causal language model: ``Qwen3_5TextModel`` followed by a bias-free ``lm_head``."""

    # Class-level metadata read by the base classes / loading machinery.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config: Qwen3_5TextConfig
    _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3_5TextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        """Delegate to the module-level ``Qwen3_5ForCausalLM_forward``."""
        return Qwen3_5ForCausalLM_forward(
            self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
def Qwen3_5ForConditionalGeneration_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    labels: torch.LongTensor | None = None,
    pixel_values: torch.Tensor | None = None,
    pixel_values_videos: torch.FloatTensor | None = None,
    image_grid_thw: torch.LongTensor | None = None,
    video_grid_thw: torch.LongTensor | None = None,
    mm_token_type_ids: torch.IntTensor | None = None,
    logits_to_keep: int | torch.Tensor = 0,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | Qwen3_5CausalLMOutputWithPast:
    r"""
    Run the multimodal model and produce logits and/or a loss.

    Behaviour is steered by two environment variables read on every call:
    `UNSLOTH_RETURN_HIDDEN_STATES == "1"` returns the sliced hidden states in the `logits` field, and
    `UNSLOTH_RETURN_LOGITS` decides whether real logits accompany a fused loss (otherwise `EMPTY_LOGITS`
    is returned in training branches that never materialise them).

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.

    Example:

    ```python
    >>> from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration

    >>> model = Qwen3_5ForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")

    >>> messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
                },
                {"type": "text", "text": "Describe the image."},
            ],
        }
    ]

    >>> inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    )

    >>> # Generate
    >>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
    >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    >>> print(output_text)
    ```
    """
    # Module-level counter of inference calls; reassigned in the inference branch below.
    global INFERENCE_RUNS

    outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        mm_token_type_ids=mm_token_type_ids,
        **kwargs,
    )
    hidden_states = outputs[0]

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep

    # The two flags are deliberately separate: any value other than "0"/"1" makes both False.
    RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '1'
    NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0'
    RETURN_HIDDEN_STATES = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1"

    # Token count for loss normalisation, if the caller supplied one via kwargs.
    n_items = kwargs.get("num_items_in_batch", None)
    if n_items is None:
        n_items = kwargs.get("n_items", None)

    # Treated as "needs grad" when lm_head is trainable or stored in float32.
    requires_grad_ = self.lm_head.weight.requires_grad or self.lm_head.weight.dtype == torch.float32

    # Logits are only materialised in the branch that actually returns them, so the
    # lm_head projection is never computed and then thrown away.
    logits = EMPTY_LOGITS
    loss = None

    if RETURN_HIDDEN_STATES:
        logits = hidden_states[:, slice_indices, :]
    elif labels is None:
        # Set compiler stance to fail on recompiles for inference
        if torch_dynamo_eval_frame is not None:
            old_stance = torch_dynamo_eval_frame._stance.stance
        else:
            old_stance = None
        if old_stance is not None and INFERENCE_RUNS == 1:
            # Skip guards and return to eager -> we still need guards!
            torch_compiler_set_stance(stance = "eager_on_recompile", skip_guard_eval_unsafe = False)
            if UNSLOTH_ENABLE_LOGGING:
                logger_compiler.info(
                    f"Unsloth: Removing compiler guards after 1 inference run. "\
                    f"DYNAMO_STANCE.stance = {torch_dynamo_eval_frame._stance.stance} "\
                    f"DYNAMO_STANCE.skip_guard_eval_unsafe = {torch_dynamo_eval_frame._stance.skip_guard_eval_unsafe}"
                )
        elif old_stance == "eager_on_recompile":
            pass
        elif old_stance == "default" and INFERENCE_RUNS > 1:
            # Reset compiler stance
            torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
            if UNSLOTH_ENABLE_LOGGING:
                logger_compiler.info(
                    f"Unsloth: Reseting guards. "\
                    f"DYNAMO_STANCE.stance = {torch_dynamo_eval_frame._stance.stance} "\
                    f"DYNAMO_STANCE.skip_guard_eval_unsafe = {torch_dynamo_eval_frame._stance.skip_guard_eval_unsafe}"
                )
            INFERENCE_RUNS = 0
        INFERENCE_RUNS += 1

        logits = self.lm_head(hidden_states[:, slice_indices, :])
    elif (
        UNSLOTH_ENABLE_CCE
        and NOT_RETURN_LOGITS
        and self.loss_function.__name__.endswith("ForCausalLMLoss")
        and not requires_grad_
    ):
        # Fused linear + cross-entropy: logits are never materialised.
        loss = fused_linear_cross_entropy(
            hidden_states = hidden_states[:, slice_indices, :],
            lm_weight = self.lm_head.weight,
            labels = labels.to(self.lm_head.weight.device),
            num_items_in_batch = n_items,
            logit_softcapping = None,
        )
    elif self.loss_function.__name__.endswith("ForCausalLMLoss"):
        # Fused loss; real logits are produced only when explicitly requested.
        if RETURN_LOGITS:
            logits = self.lm_head(hidden_states[:, slice_indices, :])
        lm_head_weight = self.lm_head.weight
        lm_head_bias = getattr(self.lm_head, "bias", None)

        # ========= NEW fused =========
        _hidden_states = hidden_states[:, slice_indices, :]
        torch._dynamo.mark_dynamic(_hidden_states, 1)
        torch._dynamo.mark_dynamic(labels, 1)
        loss = unsloth_fused_ce_loss(
            trainer = None,
            hidden_states = _hidden_states,
            lm_head_weight = lm_head_weight,
            lm_head_bias = lm_head_bias,
            labels = labels,
            mask = None,
            n_items = n_items,
            scaling = getattr(self, "accelerator_scaler", None),
            target_gb = None,
            torch_compile = not UNSLOTH_COMPILE_DISABLE,
            logit_scale_multiply = 0,
            logit_scale_divide = 0,
            logit_softcapping = 0,
        )
    else:
        # Any other loss function: materialise logits and defer to it.
        # Note that kwargs are not forwarded to the loss function here.
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = self.loss_function(
            logits=logits,
            labels=labels.to(self.lm_head.weight.device),
            vocab_size=self.config.text_config.vocab_size,
        )

    return Qwen3_5CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=outputs.rope_deltas,
    )
| class Qwen3_5ForConditionalGeneration(Qwen3_5PreTrainedModel, GenerationMixin): | |
| _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} | |
| # Reference: fix gemma3 grad acc #37208 | |
| accepts_loss_kwargs = False | |
| config: Qwen3_5Config | |
def __init__(self, config):
    """Build the multimodal backbone and the language-model head from ``config``."""
    super().__init__(config)
    text_cfg = config.text_config
    self.model = Qwen3_5Model(config)
    # Bias-free projection from the text hidden size to the text vocabulary.
    self.lm_head = nn.Linear(text_cfg.hidden_size, text_cfg.vocab_size, bias=False)
    self.post_init()
def get_input_embeddings(self):
    """Return whatever ``self.model.get_input_embeddings()`` returns."""
    backbone = self.model
    return backbone.get_input_embeddings()
def set_input_embeddings(self, value):
    """Forward ``value`` to ``self.model.set_input_embeddings``; returns ``None``."""
    backbone = self.model
    backbone.set_input_embeddings(value)
def get_video_features(
    self,
    pixel_values_videos: torch.FloatTensor,
    video_grid_thw: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | BaseModelOutputWithPooling:
    r"""
    Thin pass-through to `self.model.get_video_features`.

    pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        The tensors corresponding to the input videos.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.
    """
    backbone = self.model
    return backbone.get_video_features(
        pixel_values_videos=pixel_values_videos,
        video_grid_thw=video_grid_thw,
        **kwargs,
    )
def get_image_features(
    self,
    pixel_values: torch.FloatTensor,
    image_grid_thw: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | BaseModelOutputWithPooling:
    r"""
    Thin pass-through to `self.model.get_image_features`.

    pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        The tensors corresponding to the input images.
    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    """
    backbone = self.model
    return backbone.get_image_features(
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
        **kwargs,
    )
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    labels: torch.LongTensor | None = None,
    pixel_values: torch.Tensor | None = None,
    pixel_values_videos: torch.FloatTensor | None = None,
    image_grid_thw: torch.LongTensor | None = None,
    video_grid_thw: torch.LongTensor | None = None,
    mm_token_type_ids: torch.IntTensor | None = None,
    logits_to_keep: int | torch.Tensor = 0,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | Qwen3_5CausalLMOutputWithPast:
    """Delegate to the module-level ``Qwen3_5ForConditionalGeneration_forward``."""
    return Qwen3_5ForConditionalGeneration_forward(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=labels,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        mm_token_type_ids=mm_token_type_ids,
        logits_to_keep=logits_to_keep,
        **kwargs,
    )
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    position_ids=None,
    use_cache=True,
    pixel_values=None,
    pixel_values_videos=None,
    image_grid_thw=None,
    video_grid_thw=None,
    is_first_iteration=False,
    **kwargs,
):
    """Build per-step model inputs, dropping pixel inputs after the first cached step.

    The parent implementation assembles the inputs. When ``use_cache`` is truthy
    and this is not the first iteration, ``pixel_values`` and
    ``pixel_values_videos`` are set to ``None`` in the returned dict.
    """
    # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
    model_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        position_ids=position_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        use_cache=use_cache,
        is_first_iteration=is_first_iteration,
        **kwargs,
    )

    if is_first_iteration or not use_cache:
        return model_inputs

    for vision_key in ("pixel_values", "pixel_values_videos"):
        model_inputs[vision_key] = None
    return model_inputs
def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs):
    """
    Build the position ids used for generation.

    The parent implementation supplies the text positions. When a non-empty
    cache is present and `self.model.rope_deltas` is already set, the text
    positions are simply shifted by those deltas. Otherwise vision positions
    are computed (or derived from the text positions), `self.model.rope_deltas`
    is refreshed as a side effect, and both sets are stacked along dim 0.
    """
    # Overwritten -- requires 3D position ids
    text_positions = super()._prepare_position_ids_for_generation(inputs_tensor, model_kwargs)

    # Early exit in case we are continuing generation from past kv
    cache = model_kwargs.get("past_key_values")
    past_length = 0 if cache is None else cache.get_seq_length()
    if past_length != 0 and self.model.rope_deltas is not None:
        return text_positions[None, ...] + self.model.rope_deltas

    # Otherwise compute 3d position ids for vision tokens and concat with text position ids
    if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
        inputs_tensor = model_kwargs["input_ids"]

    looks_like_token_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in (torch.int, torch.long)
    has_vision_grid = (
        model_kwargs.get("image_grid_thw") is not None or model_kwargs.get("video_grid_thw") is not None
    )
    use_rope_index = (
        looks_like_token_ids and model_kwargs.get("mm_token_type_ids") is not None and has_vision_grid
    )

    if not use_rope_index:
        # No multimodal information: reuse the text positions on each of the 3 vision axes.
        vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1)
        self.model.rope_deltas = torch.zeros(
            inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device
        )
    else:
        # `input_ids` is passed positionally below, so drop it from the keyword arguments.
        rope_kwargs = {name: value for name, value in model_kwargs.items() if name != "input_ids"}
        vision_positions, new_deltas = self.model.get_rope_index(inputs_tensor, **rope_kwargs)
        self.model.rope_deltas = new_deltas

    # Concatenate "text + vision" positions into [4, bs, seq-len]
    return torch.cat([text_positions[None, ...], vision_positions], dim=0)
| def _get_image_nums_and_video_nums( | |
| self, | |
| input_ids: torch.LongTensor | None, | |
| inputs_embeds: torch.Tensor | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Get the number of images and videos for each sample to calculate the separation length of the sample tensor. | |
| These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. | |
| Args: | |
| input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | |
| Indices of input sequence tokens in the vocabulary. | |
| Returns: | |
| image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) | |
| video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) | |
| """ | |
| image_token_id = self.config.image_token_id | |
| video_token_id = self.config.video_token_id | |
| vision_start_token_id = self.config.vision_start_token_id | |
| if inputs_embeds is not None: | |
| vision_start_mask = ( | |
| inputs_embeds | |
| == self.get_input_embeddings()( | |
| torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) | |
| ) | |
| )[..., 0] | |
| image_mask = ( | |
| inputs_embeds | |
| == self.get_input_embeddings()( | |
| torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) | |
| ) | |
| )[..., 0] | |
| video_mask = ( | |
| inputs_embeds | |
| == self.get_input_embeddings()( | |
| torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) | |
| ) | |
| )[..., 0] | |
| else: | |
| vision_start_mask = input_ids == vision_start_token_id | |
| image_mask = input_ids == image_token_id | |
| video_mask = input_ids == video_token_id | |
| vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) | |
| image_nums = torch.sum(vision_first_mask & image_mask, dim=1) | |
| video_nums = torch.sum(vision_first_mask & video_mask, dim=1) | |
| return image_nums, video_nums | |
| def _expand_inputs_for_generation( | |
| self, | |
| expand_size: int = 1, | |
| is_encoder_decoder: bool = False, | |
| input_ids: torch.LongTensor | None = None, | |
| **model_kwargs, | |
| ) -> tuple[torch.LongTensor, dict[str, Any]]: | |
| # Overwritten -- Qwen3_5 use timestamps and remove second_per_grid_ts | |
| # Support for expanding tensors without a batch size dimension | |
| # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw | |
| # pixel_values.shape[0] is sum(seqlen_images for samples) | |
| # image_grid_thw.shape[0] is sum(num_images for samples) | |
| if expand_size == 1: | |
| return input_ids, model_kwargs | |
| visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"] | |
| def _expand_dict_for_generation_visual(dict_to_expand): | |
| image_grid_thw = model_kwargs.get("image_grid_thw", None) | |
| video_grid_thw = model_kwargs.get("video_grid_thw", None) | |
| image_nums, video_nums = self._get_image_nums_and_video_nums( | |
| input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) | |
| ) | |
| # video_nums: (batch_size,) | |
| # since video_nums is the number of videos in the input dependent on the input_ids(vision_start), | |
| # but Qwen3_5 append vision_start to each frame of each video, so we need to recover the real video_nums according to video_grid_thw | |
| if video_grid_thw is not None: | |
| cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0) | |
| cumulative_token_video_counts = torch.cumsum(video_nums, dim=0) | |
| # Find video boundaries in cumulative_frame_counts | |
| video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts) | |
| # example: video_boundary_indices = [3, 5] means video_nums = [4, 2] | |
| video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices])) | |
| def _repeat_interleave_samples(x, lengths, repeat_times): | |
| samples = torch.split(x, lengths) | |
| repeat_args = [repeat_times] + [1] * (x.dim() - 1) | |
| result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) | |
| return result | |
| for key in dict_to_expand: | |
| if key == "pixel_values": | |
| # split images into samples | |
| samples = torch.split(image_grid_thw, list(image_nums)) | |
| # compute the sequence length of images for each sample | |
| lengths = [torch.prod(sample, dim=1).sum() for sample in samples] | |
| dict_to_expand[key] = _repeat_interleave_samples( | |
| dict_to_expand[key], lengths=lengths, repeat_times=expand_size | |
| ) | |
| elif key == "image_grid_thw": | |
| # get the num of images for each sample | |
| lengths = list(image_nums) | |
| dict_to_expand[key] = _repeat_interleave_samples( | |
| dict_to_expand[key], lengths=lengths, repeat_times=expand_size | |
| ) | |
| elif key == "pixel_values_videos": | |
| samples = torch.split(video_grid_thw, list(video_nums)) | |
| lengths = [torch.prod(sample, dim=1).sum() for sample in samples] | |
| dict_to_expand[key] = _repeat_interleave_samples( | |
| dict_to_expand[key], lengths=lengths, repeat_times=expand_size | |
| ) | |
| elif key == "video_grid_thw": | |
| lengths = list(video_nums) | |
| dict_to_expand[key] = _repeat_interleave_samples( | |
| dict_to_expand[key], lengths=lengths, repeat_times=expand_size | |
| ) | |
| return dict_to_expand | |
| def _expand_dict_for_generation(dict_to_expand): | |
| for key in dict_to_expand: | |
| if key == "position_ids" and dict_to_expand[key].ndim == 3: | |
| dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=1) | |
| elif ( | |
| dict_to_expand[key] is not None | |
| and isinstance(dict_to_expand[key], torch.Tensor) | |
| and key not in visual_keys | |
| ): | |
| dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) | |
| return dict_to_expand | |
| model_kwargs = _expand_dict_for_generation_visual(model_kwargs) | |
| if input_ids is not None: | |
| input_ids = input_ids.repeat_interleave(expand_size, dim=0) | |
| model_kwargs = _expand_dict_for_generation(model_kwargs) | |
| if is_encoder_decoder: | |
| if model_kwargs.get("encoder_outputs") is None: | |
| raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") | |
| model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) | |
| return input_ids, model_kwargs | |
if hasattr(logger, "addFilter"):
    import logging

    class HideLoggingMessage(logging.Filter):
        """Logging filter that rejects every record whose message contains `text`."""

        def __init__(self, text):
            self.text = text

        def filter(self, record):
            # A falsy result makes the logger drop the record.
            return self.text not in record.getMessage()

    # Silence the messages that mention `use_cache=True`.
    logger.addFilter(HideLoggingMessage("`use_cache=True`"))
Xet Storage Details
- Size: 77 kB
- Xet hash: 2361ced28c01143cb0dc62f8f8ce0c4803373b24191f6abcac77f053f372b2a0
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.