bochen2079/katherine-k0 / logs /unsloth_compiled_cache /unsloth_compiled_module_qwen3_5.py
bochen2079's picture
download
raw
77 kB
"""
2026.5.1
2026.5.2
5.5.0
0.24.0
__UNSLOTH_VERSIONING__
"""
# Unsloth auto generated code
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import os
import sys
import torch
import importlib.util
import math
# Unsloth Studio counts as enabled only when the `unsloth_studio` package can
# be located AND the UNSLOTH_STUDIO_DISABLED environment variable is "0" (the
# default when it is unset).
UNSLOTH_STUDIO_ENABLED = (
    False
    if importlib.util.find_spec("unsloth_studio") is None
    else os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
)
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
import math
# Feature switches, read once from the environment when the module is imported.
_getenv = os.environ.get
UNSLOTH_ENABLE_LOGGING = _getenv("UNSLOTH_ENABLE_LOGGING", "0") == "1"
UNSLOTH_ENABLE_CCE = _getenv("UNSLOTH_ENABLE_CCE", "1") == "1"
UNSLOTH_COMPILE_DISABLE = _getenv("UNSLOTH_COMPILE_DISABLE", "0") in ("1", "partial")
UNSLOTH_COMPILE_LOCATION = _getenv("UNSLOTH_COMPILE_LOCATION", "unsloth_compiled_cache")

# Put the compiled-cache directory at the front of the import search path
# (only once) so modules stored there can be imported by name.
if UNSLOTH_COMPILE_LOCATION not in sys.path:
    sys.path.insert(0, UNSLOTH_COMPILE_LOCATION)

import logging

# Module logger; switched to DEBUG only when logging was requested above.
logger_compiler = logging.getLogger(__name__)
if UNSLOTH_ENABLE_LOGGING:
    logger_compiler.setLevel(logging.DEBUG)

# Module-level counter, initialised to zero at import time.
INFERENCE_RUNS = 0
# Probe for the torch.compile "stance" API. Both names are bound to None when
# any part of the probe fails, so callers can test them against None.
try:
    import torch._dynamo.eval_frame as torch_dynamo_eval_frame
    # Pure attribute access: fails here when the stance machinery is missing.
    torch_dynamo_eval_frame._stance.stance
    torch_compiler_set_stance = torch.compiler.set_stance
except Exception:
    # `except Exception` (rather than a bare `except`) keeps the best-effort
    # fallback while letting KeyboardInterrupt / SystemExit propagate.
    torch_dynamo_eval_frame = None
    torch_compiler_set_stance = None
from unsloth_zoo import DEVICE_TYPE_TORCH, DEVICE_COUNT
from unsloth_zoo.loss_utils import (
fused_linear_cross_entropy,
unsloth_fused_ce_loss,
)
scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
@torch.compiler.disable(recursive = False)
def disable_compile_scaled_dot_product_attention(*args, **kwargs):
return scaled_dot_product_attention(*args, **kwargs)
pass
from transformers.modeling_flash_attention_utils import is_flash_attn_available
# Optional flash-attention helpers. Each name is imported independently so
# that one missing symbol does not hide the others; any name that cannot be
# imported is bound to None. `except Exception` (instead of a bare `except`)
# keeps this best-effort behaviour without swallowing KeyboardInterrupt or
# SystemExit.
if is_flash_attn_available():
    try:
        from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask
    except Exception:
        flash_attn_supports_top_left_mask = None
    try:
        from transformers.modeling_flash_attention_utils import _flash_attention_forward
    except Exception:
        _flash_attention_forward = None
    try:
        from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
    except Exception:
        FlashAttentionKwargs = None
    try:
        from transformers.modeling_flash_attention_utils import flash_attn_varlen_func
    except Exception:
        flash_attn_varlen_func = None
else:
    # Flash attention is reported unavailable: expose the same names as None.
    flash_attn_supports_top_left_mask = None
    _flash_attention_forward = None
    FlashAttentionKwargs = None
    flash_attn_varlen_func = None
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False, 'debug': False, 'dce': True, 'memory_planning': True, 'coordinate_descent_tuning': False, 'trace.graph_diagram': False, 'compile_threads': 32, 'group_fusion': True, 'disable_progress': True, 'verbose_progress': False, 'triton.multi_kernel': 0, 'triton.use_block_ptr': False, 'triton.enable_persistent_tma_matmul': True, 'triton.autotune_at_compile_time': False, 'triton.cooperative_reductions': False, 'cuda.compile_opt_level': '-O2', 'cuda.enable_cuda_lto': True, 'combo_kernels': False, 'benchmark_combo_kernel': True, 'combo_kernel_foreach_dynamic_shapes': True}
from torch.nn import CrossEntropyLoss
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def normal_cross_entropy_loss(self, hidden_states, labels):
    """Compute the shifted next-token cross-entropy loss from hidden states.

    Args:
        self: object exposing ``lm_head`` and ``config.vocab_size``.
        hidden_states: input fed to ``self.lm_head``.
        labels: target ids; position ``t + 1`` is the target for logit ``t``.

    Returns:
        ``(loss, logits)`` where ``logits`` are the full (unshifted) float32
        outputs of ``lm_head``.
    """
    logits = self.lm_head(hidden_states).float()
    # Shift so that tokens < n predict n, then flatten the tokens
    flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
    flat_labels = labels[..., 1:].contiguous().view(-1)
    # Enable model parallelism
    flat_labels = flat_labels.to(flat_logits.device)
    criterion = CrossEntropyLoss()
    return criterion(flat_logits, flat_labels), logits
# We need an empty logits flag to warn people logits will not be returned anymore unless asked ie
# os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
LOGITS_ERROR_STRING = (
    "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "
    'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'
    "```\nimport os\n"
    "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"
    "trainer.train()\n```\n"
    "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!"
)


def raise_logits_error(*args, **kwargs):
    """Raise ``NotImplementedError(LOGITS_ERROR_STRING)``; every argument is ignored."""
    raise NotImplementedError(LOGITS_ERROR_STRING)


def return_none(*args, **kwargs):
    """Accept any arguments and return ``None``."""
    return None


class EmptyLogits:
    """Stand-in object used in place of real logits.

    Indexing raises ``NotImplementedError`` with ``LOGITS_ERROR_STRING``.
    Looking up a missing attribute returns ``raise_logits_error`` (so calling
    it raises the same error), except ``to``, which resolves to
    ``return_none``. ``repr`` and ``str`` both produce the explanatory message.
    """

    def __init__(self):
        return

    def raise_getattr_error(self, attr):
        # `.to(...)` is the single attribute whose call is allowed to succeed.
        if attr == "to":
            return return_none
        return raise_logits_error

    __getitem__ = raise_logits_error
    __getattr__ = raise_getattr_error

    def __repr__(self):
        return LOGITS_ERROR_STRING

    def __str__(self):
        return LOGITS_ERROR_STRING


EMPTY_LOGITS = EmptyLogits()
# For every dunder name that torch.Tensor exposes, define a module-level
# function `raise_<index>` that prints that dunder name, then try to attach it
# to the EMPTY_LOGITS instance under the same dunder name.
functions = dir(torch.Tensor)
for j, function in enumerate(functions):
    if function.startswith("__") and function.endswith("__"):
        # NOTE(review): `exec` is fed only names produced by dir(torch.Tensor),
        # never external input; kept so the generated `raise_<index>` globals
        # stay exactly as before.
        exec(f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals())
        try:
            exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
        except Exception:
            # Assignment can be rejected for some dunder names; skip those.
            # `except Exception` (not a bare `except`) lets KeyboardInterrupt
            # and SystemExit propagate.
            continue
def mask_attention_mask_out(labels = None, attention_mask = None):
    """Set labels at masked-out positions to -100, in place.

    Args:
        labels: tensor of target ids, or ``None``.
        attention_mask: tensor whose zero entries mark positions to ignore, or
            ``None``. It is moved to ``labels.device`` before use.

    Returns:
        ``labels`` itself (mutated in place when both arguments are given),
        or the untouched ``labels`` value when either argument is ``None``.
    """
    if labels is None or attention_mask is None:
        return labels
    ignored = attention_mask.to(device = labels.device) == 0
    labels[ignored] = -100
    return labels
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from unsloth_zoo.temporary_patches.common import torch_compile
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
from transformers.models.qwen3_5.modeling_qwen3_5 import (F, Callable, Any, Optional, torch, nn, init, ACT2FN, Cache, GenerationMixin, FlashAttentionKwargs, BaseModelOutputWithPast, ModelOutput, BaseModelOutputWithPooling, CausalLMOutputWithPast, ROPE_INIT_FUNCTIONS, dynamic_rope_update, ALL_ATTENTION_FUNCTIONS, PreTrainedModel, Unpack, TransformersKwargs, can_return_tuple, is_flash_attention_requested, maybe_autocast, Qwen3_5Config, Qwen3_5TextConfig, Qwen3_5VisionConfig, causal_conv1d_fn, causal_conv1d_update, FusedRMSNormGated, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule, logger, __name__, is_fast_path_available, Qwen3_5PreTrainedModel, Qwen3_5Model, Qwen3_5TextModel, Qwen3_5ForCausalLM, Qwen3_5CausalLMOutputWithPast, Qwen3_5ForConditionalGeneration, Qwen3_5GatedDeltaNet)
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def Qwen3_5VisionRotaryEmbedding_forward(self, seqlen: int) -> torch.Tensor:
    """Return the outer product of positions ``0..seqlen-1`` with ``self.inv_freq``.

    The position vector is created on the device and in the dtype of
    ``self.inv_freq``.
    """
    inv_freq = self.inv_freq
    positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
    return torch.outer(positions, inv_freq)
class Qwen3_5VisionRotaryEmbedding(nn.Module):
    """Holds the inverse-frequency buffer used for vision rotary embeddings.

    ``inv_freq`` is ``1 / theta ** (k / dim)`` for even ``k`` in ``[0, dim)``
    and is registered as a non-persistent buffer.
    """

    inv_freq: torch.Tensor # fix linting for `register_buffer`

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        frequencies = 1.0 / (theta ** exponents)
        self.register_buffer("inv_freq", frequencies, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Delegate to the module-level compiled forward function."""
        return Qwen3_5VisionRotaryEmbedding_forward(self, seqlen=seqlen)
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def Qwen3_5TextRotaryEmbedding_forward(self, x, position_ids):
    """Build the ``(cos, sin)`` rotary tables for ``position_ids``.

    ``x`` is only consulted for its device type and dtype; the returned
    tensors are cast to ``x.dtype``.
    """
    # In contrast to other models, Qwen3_5 has different position ids for the grids,
    # so 2-D ids are broadcast along a new leading axis of size 3.
    if position_ids.ndim == 2:
        position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
    batch = position_ids.shape[1]
    inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, batch, -1, 1)
    position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)

    # Autocast is keyed on the device type; "mps" and non-string types use "cpu".
    raw_device_type = x.device.type
    if isinstance(raw_device_type, str) and raw_device_type != "mps":
        device_type = raw_device_type
    else:
        device_type = "cpu"

    with maybe_autocast(device_type=device_type, enabled=False): # Force float32
        angles = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
        angles = self.apply_interleaved_mrope(angles, self.mrope_section)
        emb = torch.cat((angles, angles), dim=-1)
        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling
    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Qwen3_5TextRotaryEmbedding(nn.Module):
    """Rotary embedding module for the text model, with interleaved MRoPE support."""

    inv_freq: torch.Tensor # fix linting for `register_buffer`

    def __init__(self, config: Qwen3_5TextConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]

        # "default" uses the local static method; any other type is looked up
        # in ROPE_INIT_FUNCTIONS.
        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
        self.mrope_section = config.rope_parameters.get("mrope_section", [11, 11, 10])

    @staticmethod
    def compute_default_rope_parameters(
        config: Qwen3_5TextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        rope_parameters = config.rope_parameters
        base = rope_parameters["rope_theta"]
        partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
        # Fall back to hidden_size // num_attention_heads when head_dim is missing or falsy.
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        dim = int(head_dim * partial_rotary_factor)

        # Compute the inverse frequencies
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
        inv_freq = 1.0 / (base ** exponents)
        # The scaling factor is unused in this type of RoPE
        return inv_freq, 1.0

    def forward(self, x, position_ids):
        """Delegate to the module-level compiled forward function."""
        return Qwen3_5TextRotaryEmbedding_forward(self, x=x, position_ids=position_ids)

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary embeddings.
        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THWTHWTHW...TT], preserving frequency continuity.
        args:
            x: (3, bs, seq_len, head_dim // 2)
            mrope_section: (3,)
        returns:
            x_t: (bs, seq_len, head_dim // 2)
        """
        # `merged` is freqs[0], so the writes below also modify `freqs` in place.
        merged = freqs[0]
        for axis in (1, 2): # H, W
            stop = mrope_section[axis] * 3
            picked = slice(axis, stop, 3)
            merged[..., picked] = freqs[axis, ..., picked]
        return merged
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def Qwen3_5RMSNorm_forward(self, x):
    """Apply ``self._norm`` in float32, scale by ``1 + weight``, and cast back to ``x``'s type."""
    normed = self._norm(x.float())
    # Llama does x.to(float16) * w whilst Qwen3_5 is (x * w).to(float16)
    # See https://github.com/huggingface/transformers/pull/29402
    scaled = normed * (1.0 + self.weight.float())
    return scaled.type_as(x)
class Qwen3_5RMSNorm(nn.Module):
    """RMS normalisation whose learnable ``weight`` starts at zero.

    The compiled forward multiplies the normalised input by ``1 + weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal root of its mean square over the last dim (plus ``eps``)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Delegate to the module-level compiled forward function."""
        return Qwen3_5RMSNorm_forward(self, x=x)

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.eps}"
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def Qwen3_5RMSNormGated_forward(self, hidden_states, gate=None):
    """RMS-normalise ``hidden_states``, scale by ``self.weight``, then gate with SiLU.

    NOTE: although ``gate`` defaults to ``None``, ``gate.to(...)`` is called
    unconditionally, so a gate tensor must be supplied.
    The result is cast back to the input dtype of ``hidden_states``.
    """
    input_dtype = hidden_states.dtype
    hidden_fp32 = hidden_states.to(torch.float32)
    variance = hidden_fp32.pow(2).mean(-1, keepdim=True)
    # Norm before gate
    normed = hidden_fp32 * torch.rsqrt(variance + self.variance_epsilon)
    weighted = self.weight * normed.to(input_dtype)
    gated = weighted * F.silu(gate.to(torch.float32))
    return gated.to(input_dtype)
class Qwen3_5RMSNormGated(nn.Module):
    """Gated RMS norm: holds a ones-initialised ``weight`` and the variance epsilon.

    Extra keyword arguments to the constructor are accepted and ignored.
    """

    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states, gate=None):
        """Delegate to the module-level compiled forward function."""
        return Qwen3_5RMSNormGated_forward(self, hidden_states=hidden_states, gate=gate)
def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66

    The mask is applied only when it is given and both of its first two
    dimensions are larger than 1; otherwise ``hidden_states`` is returned
    untouched. The masked result keeps the dtype of ``hidden_states``.
    """
    # NOTE: attention mask is a 2D boolean tensor
    if attention_mask is None:
        return hidden_states
    if attention_mask.shape[1] <= 1 or attention_mask.shape[0] <= 1:
        return hidden_states
    original_dtype = hidden_states.dtype
    masked = hidden_states * attention_mask[:, :, None]
    return masked.to(original_dtype)
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    """Pure-torch depthwise conv update that also refreshes ``conv_state`` in place.

    ``conv_state`` and ``hidden_states`` are concatenated along the last dim;
    the trailing ``conv_state.shape[-1]`` columns are copied back into
    ``conv_state``. The grouped convolution output for the last ``seq_len``
    positions is passed through SiLU and cast to ``hidden_states.dtype``.

    NOTE: ``activation`` is accepted but not read; SiLU is always applied.
    """
    _batch, hidden_size, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]
    window = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    conv_state.copy_(window[:, :, -state_len:])
    conv_out = F.conv1d(window, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
    activated = F.silu(conv_out[:, :, -seq_len:])
    return activated.to(hidden_states.dtype)
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """This function is intended to align with the l2norm implementation in the FLA library.

    Multiplies ``x`` by ``rsqrt(sum(x * x) + eps)`` taken along ``dim``.
    """
    squared_sum = (x * x).sum(dim=dim, keepdim=True)
    return x * torch.rsqrt(squared_sum + eps)
@torch.compiler.disable(recursive = False)
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """Pure-torch chunked implementation of the gated delta rule.

    Inputs are transposed on dims 1 and 2 and computed in float32. The
    sequence is zero-padded to a multiple of ``chunk_size``, processed chunk
    by chunk while carrying a recurrent state, and the padding is cut off
    again before returning.

    Args:
        query, key, value: tensors whose dims 1 and 2 are swapped before use;
            after the swap ``key`` is unpacked as
            ``(batch, heads, sequence, k_head_dim)``.
        g: per-position log-decay; it is cumulatively summed within each chunk
            and exponentiated.
        beta: per-position factor multiplied into ``key`` and ``value``.
        chunk_size: number of positions per chunk.
        initial_state: optional starting recurrent state; zeros of shape
            ``(batch, heads, k_head_dim, v_head_dim)`` when ``None``.
        output_final_state: when falsy, ``None`` is returned instead of the
            final recurrent state.
        use_qk_l2norm_in_kernel: when truthy, ``query`` and ``key`` are passed
            through ``l2norm`` first.

    Returns:
        ``(core_attn_out, last_recurrent_state)``; the output has dims 1 and 2
        swapped back and is cast to the original dtype of ``query``.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    # Swap dims 1 and 2 and upcast everything to float32 for the computation.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    # Right-pad the sequence dimension up to the next multiple of chunk_size.
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    total_sequence_length = sequence_length + pad_size
    # Scale queries by 1 / sqrt(last dimension of query).
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    # reshape to chunks
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    # Upper-triangular mask INCLUDING the diagonal (diagonal=0).
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
    # chunk decay
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    # Row-by-row in-place update of the strictly lower triangle; each row uses
    # the rows already updated before it (hence the clones and the fixed order).
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    # Start from zeros unless the caller supplied an initial state.
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    # Upper-triangular mask EXCLUDING the diagonal (diagonal=1).
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
    # for each chunk
    for i in range(0, total_sequence_length // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        v_new = v_i - v_prime
        # Contribution of the state carried over from earlier chunks.
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        # Decay the carried state by the chunk's final cumulative g, then add
        # this chunk's contribution.
        last_recurrent_state = (
            last_recurrent_state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
        )
    if not output_final_state:
        last_recurrent_state = None
    # Merge the chunk dimensions back together and drop the padded positions.
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :sequence_length]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
@torch.compiler.disable(recursive = False)
def torch_recurrent_gated_delta_rule(
query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
initial_dtype = query.dtype
if use_qk_l2norm_in_kernel:
query = l2norm(query, dim=-1, eps=1e-6)
key = l2norm(key, dim=-1, eps=1e-6)
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
]
batch_size, num_heads, sequence_length, k_head_dim = key.shape
v_head_dim = value.shape[-1]
scale = 1 / (query.shape[-1] ** 0.5)
query = query * scale
core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
last_recurrent_state = (
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
if initial_state is None
else initial_state.to(value)
)
for i in range(sequence_length):
q_t = query[:, :, i]
k_t = key[:, :, i]
v_t = value[:, :, i]
g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
beta_t = beta[:, :, i].unsqueeze(-1)
last_recurrent_state = last_recurrent_state * g_t
kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
delta = (v_t - kv_mem) * beta_t
last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
if not output_final_state:
last_recurrent_state = None
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
return core_attn_out, last_recurrent_state
@torch.compiler.disable(recursive = False)
def Qwen3_5GatedDeltaNet_forward(
    self,
    hidden_states: torch.Tensor,
    cache_params: Cache | None = None,
    attention_mask: torch.Tensor | None = None,
):
    """Forward pass of the gated delta-net layer.

    Projects ``hidden_states`` to mixed QKV plus the ``z``, ``b`` and ``a``
    branches, runs the causal convolution, applies the gated delta rule
    (chunked, or recurrent when cached state is reused), normalises the
    result gated by ``z`` and applies the output projection.

    Args:
        hidden_states: input tensor, unpacked as ``(batch, seq_len, _)``.
        cache_params: optional cache; when given, its conv and recurrent
            states for ``self.layer_idx`` are read and/or updated.
        attention_mask: optional mask forwarded to
            ``apply_mask_to_padding_states``.

    Returns:
        The output of ``self.out_proj``.
    """
    hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
    # Set up dimensions for reshapes later
    batch_size, seq_len, _ = hidden_states.shape
    # Cached states are reused only for a single-position step on a layer
    # that already has previous state in the cache.
    use_precomputed_states = (
        cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
    )
    # getting projected states from cache if it exists
    if use_precomputed_states:
        conv_state = cache_params.layers[self.layer_idx].conv_states
        recurrent_state = cache_params.layers[self.layer_idx].recurrent_states
    mixed_qkv = self.in_proj_qkv(hidden_states)
    # Swap dims 1 and 2 before the convolution.
    mixed_qkv = mixed_qkv.transpose(1, 2)
    z = self.in_proj_z(hidden_states)
    z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
    b = self.in_proj_b(hidden_states)
    a = self.in_proj_a(hidden_states)
    if use_precomputed_states:
        # 2. Convolution sequence transformation
        # NOTE: the conv state is updated in `causal_conv1d_update`
        mixed_qkv = self.causal_conv1d_update(
            mixed_qkv,
            conv_state,
            self.conv1d.weight.squeeze(1),
            self.conv1d.bias,
            self.activation,
        )
    else:
        if cache_params is not None:
            # Left-pad (or left-trim, when the pad amount is negative) the last
            # dim to conv_kernel_size and store the result as the conv state.
            conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
            conv_state = cache_params.update_conv_state(conv_state, self.layer_idx)
        if self.causal_conv1d_fn is not None:
            mixed_qkv = self.causal_conv1d_fn(
                x=mixed_qkv,
                weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias,
                activation=self.activation,
                seq_idx=None,
            )
        else:
            # Fallback: plain conv1d, keeping only the first seq_len outputs, then SiLU.
            mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
    mixed_qkv = mixed_qkv.transpose(1, 2)
    query, key, value = torch.split(
        mixed_qkv,
        [
            self.key_dim,
            self.key_dim,
            self.value_dim,
        ],
        dim=-1,
    )
    query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
    key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
    value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
    beta = b.sigmoid()
    # If the model is loaded in fp16, without the .float() here, A might be -inf
    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
    # Repeat query/key heads when there are more value heads than key heads.
    if self.num_v_heads // self.num_k_heads > 1:
        query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
        key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
    if not use_precomputed_states:
        # Chunked rule, starting from an empty state.
        core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=None,
            output_final_state=cache_params is not None,
            use_qk_l2norm_in_kernel=True,
        )
    else:
        # Recurrent rule, continuing from the cached recurrent state.
        core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=recurrent_state,
            output_final_state=cache_params is not None,
            use_qk_l2norm_in_kernel=True,
        )
    # Update cache
    if cache_params is not None:
        cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx)
    # reshape input data into 2D tensor
    core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
    z = z.reshape(-1, self.head_v_dim)
    core_attn_out = self.norm(core_attn_out, z)
    core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
    output = self.out_proj(core_attn_out)
    return output
class Qwen3_5GatedDeltaNet(nn.Module):
    """Gated delta-net layer: input projections, depthwise conv, gated norm and output projection.

    The computation itself lives in the module-level
    ``Qwen3_5GatedDeltaNet_forward``; this class owns the parameters and picks
    between fast-path kernels and the torch fallbacks.
    """

    def __init__(self, config: Qwen3_5Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = layer_idx
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps
        # QKV
        # Query, key and value share one depthwise convolution (groups == channels).
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=self.conv_kernel_size - 1,
        )
        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
        # A is drawn uniformly from [0, 16) and stored as its logarithm.
        A = torch.empty(self.num_v_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))
        # Use the torch gated norm when FusedRMSNormGated is None; otherwise
        # build the fused norm on the current CUDA device.
        self.norm = (
            Qwen3_5RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
            if FusedRMSNormGated is None
            else FusedRMSNormGated(
                self.head_v_dim,
                eps=self.layer_norm_epsilon,
                activation=self.activation,
                device=torch.cuda.current_device(),
                dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
            )
        )
        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
        # Kernel selection: each imported fast-path callable falls back to the
        # torch implementation when it is falsy. `causal_conv1d_fn` has no
        # fallback here; the forward function checks it against None.
        self.causal_conv1d_fn = causal_conv1d_fn
        self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
        self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
        self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of the required library is not installed. Falling back to "
                "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )
        self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False)
        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        # Delegate to the module-level forward function.
        return Qwen3_5GatedDeltaNet_forward(self, hidden_states=hidden_states, cache_params=cache_params, attention_mask=attention_mask)
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    # The negated second half goes first, followed by the first half.
    return torch.cat((-back, front), dim=-1)
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.
    Removes the interleaving of cos and sin from GLM
    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
            For example, if `cos` and `sin` have the shape [batch_size, seq_len, head_dim] and `q` and `k` have
            the shape [batch_size, heads, seq_len, head_dim], use unsqueeze_dim=1; if `q` and `k` have the shape
            [batch_size, seq_len, heads, head_dim], use unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # Only the first `rotary_dim` features are rotated; the rest pass through.
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

    q_rotated = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_rotated = (k_rot * cos) + (rotate_half(k_rot) * sin)

    # Concatenate back to full shape
    q_embed = torch.cat([q_rotated, q_pass], dim=-1)
    k_embed = torch.cat([k_rotated, k_pass], dim=-1)
    return q_embed, k_embed
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference attention: scaled QK^T, optional additive mask, softmax, dropout, then @ V.

    ``attention_mask`` may also be a dict; in that case the entry keyed by
    ``module.layer_type`` (or ``None`` when the module has no such attribute)
    is used, and no mask is added when that entry is missing.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has dims 1 and 2
        swapped and is made contiguous.
    """
    kv_groups = module.num_key_value_groups
    key_states = repeat_kv(key, kv_groups)
    value_states = repeat_kv(value, kv_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    # Resolve a per-layer-type mask dict to a single mask (or None).
    if isinstance(attention_mask, dict):
        attention_mask = attention_mask.get(getattr(module, 'layer_type', None), None)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # Softmax in float32, then back to the score dtype and on to the query dtype.
    probabilities = nn.functional.softmax(attn_weights, dim=-1, dtype = torch.float32)
    attn_weights = probabilities.to(attn_weights.dtype).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights
@torch.compiler.disable(recursive = False)
def Qwen3_5Attention_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: torch.Tensor | None,
    past_key_values: Cache | None = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Gated self-attention forward pass.

    ``q_proj`` produces twice ``head_dim`` features per head: the first half
    is the query, the second half a gate whose sigmoid multiplies the
    attention output before ``o_proj``.

    Returns:
        ``(attn_output, attn_weights)`` as produced by the selected attention
        interface, with ``attn_output`` gated and projected by ``o_proj``.
    """
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    # Split the doubled query projection into the query itself and the output gate.
    projected_q = self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2)
    query_states, gate = torch.chunk(projected_q, 2, dim=-1)
    gate = gate.reshape(*input_shape, -1)

    query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
    key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    # Unsloth: align V dtype with Q after RoPE (fixes 4-bit dtype mismatch)
    if value_states.dtype != query_states.dtype:
        value_states = value_states.to(query_states.dtype)

    if past_key_values is not None:
        key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

    # Use the configured attention implementation, defaulting to the eager one.
    attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
        self.config._attn_implementation, eager_attention_forward
    )
    # Dropout is applied only in training mode.
    dropout_p = self.attention_dropout if self.training else 0.0
    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=dropout_p,
        scaling=self.scaling,
        **kwargs,
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    gated_output = attn_output * torch.sigmoid(gate)
    return self.o_proj(gated_output), attn_weights
class Qwen3_5Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3_5Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        hidden_size = config.hidden_size
        use_bias = config.attention_bias
        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim

        # q_proj is twice as wide as the attention output: the forward pass
        # splits it into the query and an output gate.
        self.q_proj = nn.Linear(hidden_size, query_width * 2, bias=use_bias)
        self.k_proj = nn.Linear(hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(query_width, hidden_size, bias=use_bias)
        self.q_norm = Qwen3_5RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
        self.k_norm = Qwen3_5RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Delegate to the module-level forward function."""
        return Qwen3_5Attention_forward(
            self,
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            **kwargs,
        )
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def Qwen3_5MLP_forward(self, x):
    """Gated feed-forward: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""
    gated = self.act_fn(self.gate_proj(x))
    lifted = self.up_proj(x)
    return self.down_proj(gated * lifted)
class Qwen3_5MLP(nn.Module):
    """Holds the three bias-free projections and the activation of the gated MLP.

    The computation is performed by the compiled ``Qwen3_5MLP_forward``.
    """

    def __init__(self, config: Qwen3_5Config, intermediate_size: int):
        super().__init__()
        model_width = config.hidden_size

        self.config = config
        self.hidden_size = model_width
        self.intermediate_size = intermediate_size
        # Creation order (gate, up, down) is kept so parameter registration is unchanged.
        self.gate_proj = nn.Linear(model_width, intermediate_size, bias=False)
        self.up_proj = nn.Linear(model_width, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, model_width, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Delegate to ``Qwen3_5MLP_forward``."""
        return Qwen3_5MLP_forward(self, x=x)
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def Qwen3_5VisionMLP_forward(self, hidden_state):
    """Two-layer MLP: ``linear_fc2(act_fn(linear_fc1(hidden_state)))``."""
    expanded = self.linear_fc1(hidden_state)
    activated = self.act_fn(expanded)
    return self.linear_fc2(activated)
class Qwen3_5VisionMLP(nn.Module):
    """Holds the two biased linear layers and the activation of the vision MLP.

    The computation is performed by the compiled ``Qwen3_5VisionMLP_forward``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        inner_width = config.intermediate_size

        self.hidden_size = width
        self.intermediate_size = inner_width
        self.linear_fc1 = nn.Linear(width, inner_width, bias=True)
        self.linear_fc2 = nn.Linear(inner_width, width, bias=True)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        """Delegate to ``Qwen3_5VisionMLP_forward``."""
        return Qwen3_5VisionMLP_forward(self, hidden_state=hidden_state)
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def Qwen3_5VisionPatchEmbed_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Project flattened patches through ``self.proj`` and return rows of width ``embed_dim``.

    The input is viewed as
    ``(-1, in_channels, temporal_patch_size, patch_size, patch_size)``, cast to
    the dtype of the projection weight, projected, and flattened to
    ``(-1, embed_dim)``.
    """
    patch_shape = (-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
    patches = hidden_states.view(*patch_shape)
    # Match the projection weight's dtype before the convolution.
    patches = patches.to(dtype=self.proj.weight.dtype)
    return self.proj(patches).view(-1, self.embed_dim)
class Qwen3_5VisionPatchEmbed(nn.Module):
    """Patch embedding built on a single ``nn.Conv3d`` whose stride equals its kernel.

    The computation is performed by the compiled
    ``Qwen3_5VisionPatchEmbed_forward``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        # Same list for kernel and stride: consecutive patches do not overlap.
        patch_dims = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(
            self.in_channels,
            self.embed_dim,
            kernel_size=patch_dims,
            stride=patch_dims,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Delegate to ``Qwen3_5VisionPatchEmbed_forward``."""
        return Qwen3_5VisionPatchEmbed_forward(self, hidden_states=hidden_states)
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def Qwen3_5VisionPatchMerger_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Normalise, regroup to rows of ``self.hidden_size``, then apply the two-layer MLP.

    With ``use_postshuffle_norm`` the input is regrouped *before* the norm;
    otherwise the norm sees the input as given and the regrouping happens after.
    """
    if self.use_postshuffle_norm:
        x = x.view(-1, self.hidden_size)
    merged = self.norm(x).view(-1, self.hidden_size)
    activated = self.act_fn(self.linear_fc1(merged))
    return self.linear_fc2(activated)
class Qwen3_5VisionPatchMerger(nn.Module):
    """Modules for merging ``spatial_merge_size**2`` patch features into one output row.

    ``hidden_size`` here is ``config.hidden_size * spatial_merge_size**2``. The
    LayerNorm spans that merged width when ``use_postshuffle_norm`` is set and
    the per-patch ``config.hidden_size`` otherwise.
    """

    def __init__(self, config: Qwen3_5VisionConfig, use_postshuffle_norm=False) -> None:
        super().__init__()
        merged_width = config.hidden_size * (config.spatial_merge_size**2)
        if use_postshuffle_norm:
            norm_width = merged_width
        else:
            norm_width = config.hidden_size

        self.hidden_size = merged_width
        self.use_postshuffle_norm = use_postshuffle_norm
        self.norm = nn.LayerNorm(norm_width, eps=1e-6)
        self.linear_fc1 = nn.Linear(merged_width, merged_width)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(merged_width, config.out_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to ``Qwen3_5VisionPatchMerger_forward``."""
        return Qwen3_5VisionPatchMerger_forward(self, x=x)
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to ``q`` and ``k``.

    The rotation is computed in float32 and each result is cast back to the
    dtype its input arrived in. ``cos`` / ``sin`` get a singleton axis inserted
    at position -2 so they broadcast against ``q`` / ``k``.
    """
    q_dtype, k_dtype = q.dtype, k.dtype
    q_f32 = q.float()
    k_f32 = k.float()
    cos_f32 = cos.unsqueeze(-2).float()
    sin_f32 = sin.unsqueeze(-2).float()

    q_rotated = (q_f32 * cos_f32) + (rotate_half(q_f32) * sin_f32)
    k_rotated = (k_f32 * cos_f32) + (rotate_half(k_f32) * sin_f32)
    return q_rotated.to(q_dtype), k_rotated.to(k_dtype)
@torch.compiler.disable(recursive = False)
def Qwen3_5VisionAttention_forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor | None = None,
    position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
    **kwargs,
) -> torch.Tensor:
    """Non-causal attention over packed vision sequences delimited by ``cu_seqlens``.

    Args:
        hidden_states: Packed tokens; the first dimension is the total sequence length.
        cu_seqlens: Cumulative boundaries; consecutive differences are the chunk lengths.
        rotary_pos_emb: Accepted for interface compatibility; not read by this function.
        position_embeddings: ``(cos, sin)`` pair. It is unpacked unconditionally, so it
            must be provided despite the ``None`` default.
        **kwargs: Forwarded unchanged to the selected attention implementation.

    Returns:
        The attention output after ``self.proj``, with first dimension ``seq_length``.
    """
    seq_length = hidden_states.shape[0]
    query_states, key_states, value_states = (
        self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
    )
    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

    # (seq, heads, head_dim) -> (1, heads, seq, head_dim)
    query_states = query_states.transpose(0, 1).unsqueeze(0)
    key_states = key_states.transpose(0, 1).unsqueeze(0)
    value_states = value_states.transpose(0, 1).unsqueeze(0)

    attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
        self.config._attn_implementation, eager_attention_forward
    )
    # Same value in both branches; computed once.
    dropout = 0.0 if not self.training else self.attention_dropout

    if is_flash_attention_requested(self.config):
        # Flash Attention: Use cu_seqlens for variable length attention
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=None,
            scaling=self.scaling,
            dropout=dropout,
            cu_seq_lens_q=cu_seqlens,
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )
    else:
        # Other implementations: Process each chunk separately
        # Convert the chunk lengths to a Python list once; previously `.tolist()` was
        # re-evaluated for each of q/k/v, repeating the tensor-to-host transfer.
        chunk_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        splits = [
            torch.split(tensor, chunk_lengths, dim=2) for tensor in (query_states, key_states, value_states)
        ]
        attn_outputs = [
            attention_interface(
                self,
                q,
                k,
                v,
                attention_mask=None,
                scaling=self.scaling,
                dropout=dropout,
                is_causal=False,
                **kwargs,
            )[0]
            for q, k, v in zip(*splits)
        ]
        attn_output = torch.cat(attn_outputs, dim=1)

    attn_output = attn_output.reshape(seq_length, -1).contiguous()
    attn_output = self.proj(attn_output)
    return attn_output
class Qwen3_5VisionAttention(nn.Module):
    """Modules and constants for vision attention: fused QKV projection plus output projection.

    The computation is performed by the module-level
    ``Qwen3_5VisionAttention_forward``.
    """

    def __init__(self, config: Qwen3_5VisionConfig) -> None:
        super().__init__()
        embed_dim = config.hidden_size

        self.dim = embed_dim
        self.num_heads = config.num_heads
        self.head_dim = embed_dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        # One linear layer produces q, k and v together (3 * dim outputs).
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Delegate to ``Qwen3_5VisionAttention_forward`` with every argument passed by keyword."""
        return Qwen3_5VisionAttention_forward(
            self,
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
@torch.compiler.disable(recursive = False)
@can_return_tuple
def Qwen3_5ForCausalLM_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    labels: torch.LongTensor | None = None,
    use_cache: bool | None = None,
    logits_to_keep: int | torch.Tensor = 0,
    **kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    Example:
    ```python
    >>> from transformers import AutoTokenizer, Qwen3_5ForCausalLM
    >>> model = Qwen3_5ForCausalLM.from_pretrained("Qwen/Qwen3_5-8B")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3_5-8B")
    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")
    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    # Generated forward: run the backbone, then choose exactly one logits/loss path
    # from the if/elif ladder below. Literal `()` tokens are template placeholders
    # that were left empty for this model, so comparisons such as `() != ()` are
    # constant and the guarded statements never run.
    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        **kwargs,
    )
    hidden_states = outputs.last_hidden_state
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    # An int keeps the last `logits_to_keep` positions (0 gives slice(0, None), i.e. all);
    # any other value is used directly as the index.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    # Logits are materialised up front only when UNSLOTH_RETURN_LOGITS=1; otherwise the
    # EMPTY_LOGITS placeholder is returned unless a branch below overwrites it.
    logits = self.lm_head(hidden_states[:, slice_indices, :]) if os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '1' else EMPTY_LOGITS
    loss = None
    NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0'
    RETURN_HIDDEN_STATES = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1"
    # Look up the item count for loss normalisation: "num_items_in_batch" first, then "n_items".
    n_items = None
    if (kwargs) != () and type(kwargs) is dict:
        n_items = (kwargs).get("num_items_in_batch", None)
        if n_items is None: n_items = (kwargs).get("n_items", None)
    if n_items is None:
        # Fallback search through the local variables: a `loss_kwargs` dict, then
        # `kwargs` again, then the first local whose type is exactly `dict`.
        all_locals = locals()
        if 'loss_kwargs' in all_locals:
            __kwargs = all_locals['loss_kwargs']
            if type(__kwargs) is dict:
                n_items = __kwargs.get("num_items_in_batch", None)
                if n_items is None: n_items = __kwargs.get("n_items", None)
        if n_items is None and 'kwargs' in all_locals:
            __kwargs = all_locals['kwargs']
            if type(__kwargs) is dict:
                n_items = __kwargs.get("num_items_in_batch", None)
                if n_items is None: n_items = __kwargs.get("n_items", None)
        if n_items is None:
            all_locals = all_locals.values()
            for __kwargs in all_locals:
                if type(__kwargs) is dict:
                    n_items = __kwargs.get("num_items_in_batch", None)
                    if n_items is None: n_items = __kwargs.get("n_items", None)
                    # Only the first dict found is inspected.
                    break
    pass
    # True when the lm_head weight is trainable or stored in float32; in that case the
    # fused_linear_cross_entropy branch below is skipped.
    requires_grad_ = self.lm_head.weight.requires_grad
    requires_grad_ = requires_grad_ or self.lm_head.weight.dtype == torch.float32
    if RETURN_HIDDEN_STATES:
        # Return the (sliced) hidden states in the `logits` field instead of projecting them.
        logits = hidden_states[:, slice_indices, :]
    elif labels is None:
        # Set compiler stance to fail on recompiles for inference
        # INFERENCE_RUNS is a module-level counter of label-free calls.
        global INFERENCE_RUNS
        if torch_dynamo_eval_frame is not None:
            old_stance = torch_dynamo_eval_frame._stance.stance
        else:
            old_stance = None
        if old_stance is not None and INFERENCE_RUNS == 1:
            # Skip guards and return to eager -> we still need guards!
            torch_compiler_set_stance(stance = "eager_on_recompile", skip_guard_eval_unsafe = False)
            if UNSLOTH_ENABLE_LOGGING:
                logger_compiler.info(
                    f"Unsloth: Removing compiler guards after 1 inference run. "\
                    f"DYNAMO_STANCE.stance = {torch_dynamo_eval_frame._stance.stance} "\
                    f"DYNAMO_STANCE.skip_guard_eval_unsafe = {torch_dynamo_eval_frame._stance.skip_guard_eval_unsafe}"
                )
        elif old_stance == "eager_on_recompile":
            pass
        elif old_stance == "default" and INFERENCE_RUNS > 1:
            # Reset compiler stance
            torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
            if UNSLOTH_ENABLE_LOGGING:
                logger_compiler.info(
                    f"Unsloth: Reseting guards. "\
                    f"DYNAMO_STANCE.stance = {torch_dynamo_eval_frame._stance.stance} "\
                    f"DYNAMO_STANCE.skip_guard_eval_unsafe = {torch_dynamo_eval_frame._stance.skip_guard_eval_unsafe}"
                )
            INFERENCE_RUNS = 0
        INFERENCE_RUNS += 1
        # NOTE(review): with UNSLOTH_RETURN_LOGITS=1 this repeats the projection done above.
        logits = self.lm_head(hidden_states[:, slice_indices, :])
    elif (() == () and () == ()) and (UNSLOTH_ENABLE_CCE) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and not requires_grad_:
        # Loss computed by fused_linear_cross_entropy directly from hidden states and the
        # lm_head weight; `logits` keeps the value assigned above.
        loss = fused_linear_cross_entropy(
            hidden_states = hidden_states[:, slice_indices, :],
            lm_weight = self.lm_head.weight,
            labels = labels.to(self.lm_head.weight.device),
            num_items_in_batch = n_items,
            logit_softcapping = None if () == () else (),
        )
    elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None:
        lm_head_weight = self.lm_head.weight
        lm_head_bias = getattr(self.lm_head, "bias", None)
        # ========= NEW fused =========
        _hidden_states = hidden_states[:, slice_indices, :]
        # Mark dimension 1 of both tensors as dynamic for torch._dynamo.
        torch._dynamo.mark_dynamic(_hidden_states, 1)
        torch._dynamo.mark_dynamic(labels, 1)
        # The three logit_* arguments evaluate to 0 because their placeholders are empty.
        loss = unsloth_fused_ce_loss(
            trainer = None,
            hidden_states = _hidden_states,
            lm_head_weight = lm_head_weight,
            lm_head_bias = lm_head_bias,
            labels = labels,
            mask = None,
            n_items = n_items,
            scaling = getattr(self, "accelerator_scaler", None),
            target_gb = None,
            torch_compile = not UNSLOTH_COMPILE_DISABLE,
            logit_scale_multiply = () if () != () else 0,
            logit_scale_divide = () if () != () else 0,
            logit_softcapping = () if () != () else 0,
        )
    else:
        # Labels are present but the loss function name does not end in "ForCausalLMLoss":
        # materialise logits and call the model's own loss function.
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # The three `if` blocks below are dead: their placeholder conditions are constant False.
        if () != ():
            logits = logits * ()
        if () != ():
            logits = logits / ()
        if () not in (None, (),):
            logits = logits / ()
            logits = torch.tanh(logits)
            logits = logits * ()
        loss = self.loss_function(logits=logits, labels=labels.to(self.lm_head.weight.device), vocab_size=self.config.vocab_size, **kwargs)
    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
class Qwen3_5ForCausalLM(Qwen3_5PreTrainedModel, GenerationMixin):
    """Text-only causal LM: a ``Qwen3_5TextModel`` backbone plus a bias-free ``lm_head``.

    ``forward`` delegates to the module-level ``Qwen3_5ForCausalLM_forward``.
    """

    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config: Qwen3_5TextConfig
    _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"]

    def __init__(self, config):
        super().__init__(config)
        vocab_size = config.vocab_size
        self.model = Qwen3_5TextModel(config)
        self.vocab_size = vocab_size
        self.lm_head = nn.Linear(config.hidden_size, vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        """Delegate to ``Qwen3_5ForCausalLM_forward`` with every argument passed by keyword."""
        return Qwen3_5ForCausalLM_forward(
            self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
@torch.compiler.disable(recursive = False)
@can_return_tuple
def Qwen3_5ForConditionalGeneration_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    labels: torch.LongTensor | None = None,
    pixel_values: torch.Tensor | None = None,
    pixel_values_videos: torch.FloatTensor | None = None,
    image_grid_thw: torch.LongTensor | None = None,
    video_grid_thw: torch.LongTensor | None = None,
    mm_token_type_ids: torch.IntTensor | None = None,
    logits_to_keep: int | torch.Tensor = 0,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | Qwen3_5CausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.
    Example:
    ```python
    >>> from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration
    >>> model = Qwen3_5ForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
    >>> messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
                },
                {"type": "text", "text": "Describe the image."},
            ],
        }
    ]
    >>> inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    )
    >>> # Generate
    >>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
    >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    >>> print(output_text)
    ```
    """
    # Generated forward: run the multimodal backbone, then choose exactly one logits/loss
    # path from the if/elif ladder below. Literal `()` tokens are template placeholders
    # that were left empty for this model, so comparisons such as `() != ()` are
    # constant and the guarded statements never run.
    outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        mm_token_type_ids=mm_token_type_ids,
        **kwargs,
    )
    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    # An int keeps the last `logits_to_keep` positions (0 gives slice(0, None), i.e. all);
    # any other value is used directly as the index.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    # Logits are materialised up front only when UNSLOTH_RETURN_LOGITS=1; otherwise the
    # EMPTY_LOGITS placeholder is returned unless a branch below overwrites it.
    logits = self.lm_head(hidden_states[:, slice_indices, :]) if os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '1' else EMPTY_LOGITS
    loss = None
    NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0'
    RETURN_HIDDEN_STATES = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1"
    n_items = None
    # Dead block: `() != ()` is False, so the `and` short-circuits and `type()` / `().get`
    # are never evaluated. The item count is found by the locals() search below instead.
    if () != () and type() is dict:
        n_items = ().get("num_items_in_batch", None)
        if n_items is None: n_items = ().get("n_items", None)
    if n_items is None:
        # Search the local variables: a `loss_kwargs` dict, then `kwargs`, then the
        # first local whose type is exactly `dict`. Keys tried: "num_items_in_batch", "n_items".
        all_locals = locals()
        if 'loss_kwargs' in all_locals:
            __kwargs = all_locals['loss_kwargs']
            if type(__kwargs) is dict:
                n_items = __kwargs.get("num_items_in_batch", None)
                if n_items is None: n_items = __kwargs.get("n_items", None)
        if n_items is None and 'kwargs' in all_locals:
            __kwargs = all_locals['kwargs']
            if type(__kwargs) is dict:
                n_items = __kwargs.get("num_items_in_batch", None)
                if n_items is None: n_items = __kwargs.get("n_items", None)
        if n_items is None:
            all_locals = all_locals.values()
            for __kwargs in all_locals:
                if type(__kwargs) is dict:
                    n_items = __kwargs.get("num_items_in_batch", None)
                    if n_items is None: n_items = __kwargs.get("n_items", None)
                    # Only the first dict found is inspected.
                    break
    pass
    # True when the lm_head weight is trainable or stored in float32; in that case the
    # fused_linear_cross_entropy branch below is skipped.
    requires_grad_ = self.lm_head.weight.requires_grad
    requires_grad_ = requires_grad_ or self.lm_head.weight.dtype == torch.float32
    if RETURN_HIDDEN_STATES:
        # Return the (sliced) hidden states in the `logits` field instead of projecting them.
        logits = hidden_states[:, slice_indices, :]
    elif labels is None:
        # Set compiler stance to fail on recompiles for inference
        # INFERENCE_RUNS is a module-level counter of label-free calls.
        global INFERENCE_RUNS
        if torch_dynamo_eval_frame is not None:
            old_stance = torch_dynamo_eval_frame._stance.stance
        else:
            old_stance = None
        if old_stance is not None and INFERENCE_RUNS == 1:
            # Skip guards and return to eager -> we still need guards!
            torch_compiler_set_stance(stance = "eager_on_recompile", skip_guard_eval_unsafe = False)
            if UNSLOTH_ENABLE_LOGGING:
                logger_compiler.info(
                    f"Unsloth: Removing compiler guards after 1 inference run. "\
                    f"DYNAMO_STANCE.stance = {torch_dynamo_eval_frame._stance.stance} "\
                    f"DYNAMO_STANCE.skip_guard_eval_unsafe = {torch_dynamo_eval_frame._stance.skip_guard_eval_unsafe}"
                )
        elif old_stance == "eager_on_recompile":
            pass
        elif old_stance == "default" and INFERENCE_RUNS > 1:
            # Reset compiler stance
            torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
            if UNSLOTH_ENABLE_LOGGING:
                logger_compiler.info(
                    f"Unsloth: Reseting guards. "\
                    f"DYNAMO_STANCE.stance = {torch_dynamo_eval_frame._stance.stance} "\
                    f"DYNAMO_STANCE.skip_guard_eval_unsafe = {torch_dynamo_eval_frame._stance.skip_guard_eval_unsafe}"
                )
            INFERENCE_RUNS = 0
        INFERENCE_RUNS += 1
        # NOTE(review): with UNSLOTH_RETURN_LOGITS=1 this repeats the projection done above.
        logits = self.lm_head(hidden_states[:, slice_indices, :])
    elif (() == () and () == ()) and (UNSLOTH_ENABLE_CCE) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and not requires_grad_:
        # Loss computed by fused_linear_cross_entropy directly from hidden states and the
        # lm_head weight; `logits` keeps the value assigned above.
        loss = fused_linear_cross_entropy(
            hidden_states = hidden_states[:, slice_indices, :],
            lm_weight = self.lm_head.weight,
            labels = labels.to(self.lm_head.weight.device),
            num_items_in_batch = n_items,
            logit_softcapping = None if () == () else (),
        )
    elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None:
        lm_head_weight = self.lm_head.weight
        lm_head_bias = getattr(self.lm_head, "bias", None)
        # ========= NEW fused =========
        _hidden_states = hidden_states[:, slice_indices, :]
        # Mark dimension 1 of both tensors as dynamic for torch._dynamo.
        torch._dynamo.mark_dynamic(_hidden_states, 1)
        torch._dynamo.mark_dynamic(labels, 1)
        # The three logit_* arguments evaluate to 0 because their placeholders are empty.
        loss = unsloth_fused_ce_loss(
            trainer = None,
            hidden_states = _hidden_states,
            lm_head_weight = lm_head_weight,
            lm_head_bias = lm_head_bias,
            labels = labels,
            mask = None,
            n_items = n_items,
            scaling = getattr(self, "accelerator_scaler", None),
            target_gb = None,
            torch_compile = not UNSLOTH_COMPILE_DISABLE,
            logit_scale_multiply = () if () != () else 0,
            logit_scale_divide = () if () != () else 0,
            logit_softcapping = () if () != () else 0,
        )
    else:
        # Labels are present but the loss function name does not end in "ForCausalLMLoss":
        # materialise logits and call the model's own loss function.
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # The three `if` blocks below are dead: their placeholder conditions are constant False.
        if () != ():
            logits = logits * ()
        if () != ():
            logits = logits / ()
        if () not in (None, (),):
            logits = logits / ()
            logits = torch.tanh(logits)
            logits = logits * ()
        # Unlike the text-only forward, no extra **kwargs are passed to the loss function here.
        loss = self.loss_function(logits=logits, labels=labels.to(self.lm_head.weight.device), vocab_size=self.config.text_config.vocab_size)
    return Qwen3_5CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=outputs.rope_deltas,
    )
class Qwen3_5ForConditionalGeneration(Qwen3_5PreTrainedModel, GenerationMixin):
    """Multimodal generation model: a ``Qwen3_5Model`` backbone plus a bias-free ``lm_head``.

    ``forward`` delegates to the module-level
    ``Qwen3_5ForConditionalGeneration_forward``; the remaining methods adapt the
    generation utilities to image/video inputs and multi-axis position ids.
    """

    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
    # Reference: fix gemma3 grad acc #37208
    accepts_loss_kwargs = False
    config: Qwen3_5Config

    def __init__(self, config):
        """Build the backbone and the language-model head from ``config`` / ``config.text_config``."""
        super().__init__(config)
        self.model = Qwen3_5Model(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's input embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the backbone's input embedding module with ``value``."""
        self.model.set_input_embeddings(value)

    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input videos.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        """
        # Thin pass-through to the backbone.
        return self.model.get_video_features(
            pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs
        )

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        """
        # Thin pass-through to the backbone.
        return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        mm_token_type_ids: torch.IntTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Qwen3_5CausalLMOutputWithPast:
        """Delegate to ``Qwen3_5ForConditionalGeneration_forward`` with every argument passed by keyword."""
        return Qwen3_5ForConditionalGeneration_forward(self, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, pixel_values=pixel_values, pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, mm_token_type_ids=mm_token_type_ids, logits_to_keep=logits_to_keep, **kwargs)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        is_first_iteration=False,
        **kwargs,
    ):
        """Build the per-step model inputs via the parent class, then blank out pixel inputs.

        On every step other than the first, when ``use_cache`` is truthy,
        ``pixel_values`` and ``pixel_values_videos`` are replaced with ``None``.
        """
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            use_cache=use_cache,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )
        if not is_first_iteration and use_cache:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None
        return model_inputs

    def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs):
        """Return position ids with a leading axis stacking text and vision positions.

        Side effect: ``self.model.rope_deltas`` is overwritten whenever the
        early-exit path is not taken.
        """
        # Overwritten -- requires 3D position ids
        text_positions = super()._prepare_position_ids_for_generation(inputs_tensor, model_kwargs)
        # Early exit in case we are continuing generation from past kv
        past_length = 0
        if (cache := model_kwargs.get("past_key_values")) is not None:
            past_length = cache.get_seq_length()
        if past_length != 0 and self.model.rope_deltas is not None:
            # Reuse the stored deltas: offset the text positions instead of recomputing.
            position_ids = text_positions[None, ...] + self.model.rope_deltas
            return position_ids
        # Otherwise compute 3d position ids for vision tokens and concat with text position ids
        if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
            inputs_tensor = model_kwargs["input_ids"]
        # Treat the tensor as token ids only if it is 2-D with an integer dtype.
        is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
        if (
            is_input_ids
            and model_kwargs.get("mm_token_type_ids") is not None
            and (model_kwargs.get("image_grid_thw") is not None or model_kwargs.get("video_grid_thw") is not None)
        ):
            # Drop "input_ids" from the kwargs because the ids are passed positionally below.
            model_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"}
            vision_positions, rope_deltas = self.model.get_rope_index(inputs_tensor, **model_kwargs)
            self.model.rope_deltas = rope_deltas
        else:
            # No usable vision metadata: repeat the text positions on 3 axes and store zero deltas.
            vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1)
            self.model.rope_deltas = torch.zeros(
                inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device
            )
        # Concatenate "text + vision" positions into [4, bs, seq-len]
        text_positions = text_positions[None, ...]
        position_ids = torch.cat([text_positions, vision_positions], dim=0)
        return position_ids

    def _get_image_nums_and_video_nums(
        self,
        input_ids: torch.LongTensor | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
        Returns:
            image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
            video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
        """
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        if inputs_embeds is not None:
            # Only embeddings are available: locate the special tokens by comparing against
            # each token's embedding vector, keeping component 0 of the elementwise result.
            vision_start_mask = (
                inputs_embeds
                == self.get_input_embeddings()(
                    torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            )[..., 0]
            image_mask = (
                inputs_embeds
                == self.get_input_embeddings()(
                    torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            )[..., 0]
            video_mask = (
                inputs_embeds
                == self.get_input_embeddings()(
                    torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            )[..., 0]
        else:
            vision_start_mask = input_ids == vision_start_token_id
            image_mask = input_ids == image_token_id
            video_mask = input_ids == video_token_id
        # Shift the start mask right by one so it marks the position following each vision-start
        # token, then count how many of those positions hold an image / video token.
        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
        return image_nums, video_nums

    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: torch.LongTensor | None = None,
        **model_kwargs,
    ) -> tuple[torch.LongTensor, dict[str, Any]]:
        """Repeat ``input_ids`` and the tensors in ``model_kwargs`` ``expand_size`` times per sample.

        Visual tensors carry no batch dimension, so they are split per sample
        using the image/video counts and each sample's slice is repeated;
        other tensors are expanded with ``repeat_interleave``.

        Raises:
            ValueError: if ``is_encoder_decoder`` is set and ``encoder_outputs`` is missing.
        """
        # Overwritten -- Qwen3_5 use timestamps and remove second_per_grid_ts
        # Support for expanding tensors without a batch size dimension
        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw
        # pixel_values.shape[0] is sum(seqlen_images for samples)
        # image_grid_thw.shape[0] is sum(num_images for samples)
        if expand_size == 1:
            return input_ids, model_kwargs
        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]

        def _expand_dict_for_generation_visual(dict_to_expand):
            # Expands only the four visual keys, in place, and returns the same dict.
            image_grid_thw = model_kwargs.get("image_grid_thw", None)
            video_grid_thw = model_kwargs.get("video_grid_thw", None)
            image_nums, video_nums = self._get_image_nums_and_video_nums(
                input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
            )
            # video_nums: (batch_size,)
            # since video_nums is the number of videos in the input dependent on the input_ids(vision_start),
            # but Qwen3_5 append vision_start to each frame of each video, so we need to recover the real video_nums according to video_grid_thw
            if video_grid_thw is not None:
                cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)
                cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)
                # Find video boundaries in cumulative_frame_counts
                video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)
                # example: video_boundary_indices = [3, 5] means video_nums = [4, 2]
                video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))

            def _repeat_interleave_samples(x, lengths, repeat_times):
                # Split x along dim 0 into per-sample pieces, repeat each piece along dim 0, re-concatenate.
                samples = torch.split(x, lengths)
                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
                return result

            for key in dict_to_expand:
                if key == "pixel_values":
                    # split images into samples
                    samples = torch.split(image_grid_thw, list(image_nums))
                    # compute the sequence length of images for each sample
                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "image_grid_thw":
                    # get the num of images for each sample
                    lengths = list(image_nums)
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "pixel_values_videos":
                    # Same scheme as pixel_values, using the video grid and video counts.
                    samples = torch.split(video_grid_thw, list(video_nums))
                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "video_grid_thw":
                    lengths = list(video_nums)
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
            return dict_to_expand

        def _expand_dict_for_generation(dict_to_expand):
            # Expands every non-visual tensor in place; 3-D position ids are expanded along dim 1.
            for key in dict_to_expand:
                if key == "position_ids" and dict_to_expand[key].ndim == 3:
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=1)
                elif (
                    dict_to_expand[key] is not None
                    and isinstance(dict_to_expand[key], torch.Tensor)
                    and key not in visual_keys
                ):
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        # The visual expansion runs first because it reads the not-yet-expanded input_ids.
        model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        model_kwargs = _expand_dict_for_generation(model_kwargs)
        if is_encoder_decoder:
            if model_kwargs.get("encoder_outputs") is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
        return input_ids, model_kwargs
# Suppress log records mentioning `use_cache=True`, but only if `logger`
# actually supports filters.
if hasattr(logger, "addFilter"):
    import logging

    class HideLoggingMessage(logging.Filter):
        """Logging filter that drops every record whose message contains ``text``."""

        def __init__(self, text):
            # Run the base initializer so the attributes logging.Filter defines
            # are present on the instance, not only `text`.
            super().__init__()
            self.text = text

        def filter(self, x):
            """Return True (keep the record) unless ``self.text`` occurs in its message."""
            return self.text not in x.getMessage()

    logger.addFilter(HideLoggingMessage("`use_cache=True`"))

Xet Storage Details

Size:
77 kB
·
Xet hash:
2361ced28c01143cb0dc62f8f8ce0c4803373b24191f6abcac77f053f372b2a0

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.