diff --git "a/modeling_interns2_preview.py" "b/modeling_interns2_preview.py" new file mode 100644--- /dev/null +++ "b/modeling_interns2_preview.py" @@ -0,0 +1,3098 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/interns2_preview/modular_interns2_preview.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_interns2_preview.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers import initialization as init +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, EncoderDecoderCache +from transformers.generation import GenerationMixin +from transformers.integrations import use_experts_implementation, use_kernelized_func +from transformers.masking_utils import create_causal_mask +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check +from transformers.utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from transformers.utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available +from transformers.utils.output_capturing import OutputRecorder, capture_outputs + +from .configuration_interns2_preview import ( + InternS2PreviewConfig, + InternS2PreviewTextConfig, + InternS2PreviewTimeSeriesConfig, + InternS2PreviewVisionConfig, +) + + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +if is_flash_linear_attention_available(): + from fla.modules import FusedRMSNormGated + from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule +else: + chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None + FusedRMSNormGated = None + +logger = logging.get_logger(__name__) + + +class InternS2PreviewVisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class InternS2PreviewDynamicCache: + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention + cache (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + is_compileable = False + + def __init__(self, config: InternS2PreviewConfig): + super().__init__() + self.layer_types = config.layer_types + self.transformer_layers = [ + i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" + ] + self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") + + # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference + self.conv_states = [None for _ in range(config.num_hidden_layers)] + self.recurrent_states = [None for _ in range(config.num_hidden_layers)] + self.key_cache = [None for _ in range(config.num_hidden_layers)] + self.value_cache = [None for _ in range(config.num_hidden_layers)] + + def __len__(self): + return len(self.layer_types) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.key_cache[layer_idx] is None: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx] is not None: + device = self.key_cache[layer_idx].device + beam_idx = beam_idx.to(device) + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) + + if self.conv_states[layer_idx] is not None: + device = self.conv_states[layer_idx].device + beam_idx = beam_idx.to(device) + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) + self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. + """ + kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length(layer_idx) + kv_length = query_length + past_seen_tokens + return kv_length, kv_offset + + @property + def has_previous_state(self): + """We have a previous state if the last linear (conv) layer was already updated.""" + return self.conv_states[self.last_linear_layer] is not None + + +class InternS2PreviewRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6, **kwargs): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + # Norm before gate + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + + return hidden_states.to(input_dtype) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + # NOTE: attention mask is a 2D boolean tensor + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +is_fast_path_available = all( + (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) +) + + +def torch_causal_conv1d_update( + hidden_states, + conv_state, + weight, + bias=None, + activation=None, +): + _, hidden_size, seq_len = hidden_states.shape + state_len = conv_state.shape[-1] + + hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype) + conv_state.copy_(hidden_states_new[:, :, -state_len:]) + out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size) + out = F.silu(out[:, :, -seq_len:]) + out = out.to(hidden_states.dtype) + return out + + +def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): + """This function is intended to align with the l2norm implementation in the FLA library.""" + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_sequence_length = sequence_length + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) + + # for each chunk + for i in range(0, total_sequence_length // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +def torch_recurrent_gated_delta_rule( + query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + + for i in range(sequence_length): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +class InternS2PreviewGatedDeltaNet(nn.Module): + def __init__(self, config: InternS2PreviewConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = layer_idx + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + self.norm = ( + InternS2PreviewRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon) + if FusedRMSNormGated is None + else FusedRMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + activation=self.activation, + device=torch.cuda.current_device(), + dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), + ) + ) + + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + self.causal_conv1d_fn = causal_conv1d_fn + self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update + self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule + self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of the required library is not installed. Falling back to " + "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False) + self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + cache_params: InternS2PreviewDynamicCache | None = None, + cache_position: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + ): + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_position is not None + ) + + # getting projected states from cache if it exists + if cache_params is not None: + conv_state = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + + mixed_qkv = self.in_proj_qkv(hidden_states) + mixed_qkv = mixed_qkv.transpose(1, 2) + + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + if use_precomputed_states: + # 2. Convolution sequence transformation + # NOTE: the conv state is updated in `causal_conv1d_update` + mixed_qkv = self.causal_conv1d_update( + mixed_qkv, + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + else: + if cache_params is not None: + conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.conv_states[self.layer_idx] = conv_state + if self.causal_conv1d_fn is not None: + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + ) + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + if not use_precomputed_states: + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + else: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + # Update cache + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + output = self.out_proj(core_attn_out) + return output + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Removes the interleaving of cos and sin from GLM + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class InternS2PreviewRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst InternS2Preview is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +@use_kernelized_func(apply_rotary_pos_emb) +class InternS2PreviewAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: InternS2PreviewConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = InternS2PreviewRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # unlike olmo, only on the head dim! + self.k_norm = InternS2PreviewRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states, gate = torch.chunk( + self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 + ) + gate = gate.reshape(*input_shape, -1) + + query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output * torch.sigmoid(gate) + + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class InternS2PreviewMLP(nn.Module): + def __init__(self, config: InternS2PreviewConfig, intermediate_size: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +@use_experts_implementation +class InternS2PreviewExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + +class InternS2PreviewTopKRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = router_top_value + return router_logits, router_scores, router_indices + + +class InternS2PreviewSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.gate = InternS2PreviewTopKRouter(config) + self.experts = InternS2PreviewExperts(config) + self.shared_expert = InternS2PreviewMLP(config, intermediate_size=config.shared_expert_intermediate_size) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_reshaped = hidden_states.view(-1, hidden_dim) + shared_expert_output = self.shared_expert(hidden_states_reshaped) + _, routing_weights, selected_experts = self.gate(hidden_states_reshaped) + expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) + + shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output + + expert_output += shared_expert_output + expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim) + return expert_output + + +class InternS2PreviewDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: InternS2PreviewTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + if self.layer_type == "linear_attention": + self.linear_attn = InternS2PreviewGatedDeltaNet(config, layer_idx) + elif self.layer_type == "full_attention": + self.self_attn = InternS2PreviewAttention(config, layer_idx) + self.mlp = InternS2PreviewSparseMoeBlock(config) + self.input_layernorm = InternS2PreviewRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = InternS2PreviewRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Token Mixer + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + ) + elif self.layer_type == "full_attention": + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # For the MoE layers, we need to unpack + if isinstance(hidden_states, tuple): + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + +class InternS2PreviewPreTrainedModel(PreTrainedModel): + config: InternS2PreviewConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["InternS2PreviewDecoderLayer", "InternS2PreviewVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + _keys_to_ignore_on_load_unexpected = [r"^mtp.*"] + _can_record_outputs = { + "router_logits": OutputRecorder(InternS2PreviewTopKRouter, index=0), + "hidden_states": InternS2PreviewDecoderLayer, + "attentions": InternS2PreviewAttention, + } + _is_stateful = True + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, InternS2PreviewGatedDeltaNet): + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.empty_like(module.A_log).uniform_(0, 16).log_()) + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif isinstance(module, InternS2PreviewRMSNorm): + init.zeros_(module.weight) + elif isinstance(module, InternS2PreviewExperts): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, InternS2PreviewSparseMoeBlock): + init.normal_(module.gate.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, InternS2PreviewVisionRotaryEmbedding): + inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim)) + init.copy_(module.inv_freq, inv_freq) + + +class InternS2PreviewVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class InternS2PreviewVisionPatchEmbed(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class InternS2PreviewVisionPatchMerger(nn.Module): + def __init__(self, config: InternS2PreviewVisionConfig, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class InternS2PreviewVisionAttention(nn.Module): + def __init__(self, config: InternS2PreviewVisionConfig) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + if is_flash_attention_requested(self.config): + # Flash Attention: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class InternS2PreviewVisionBlock(GradientCheckpointingLayer): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = InternS2PreviewVisionAttention(config=config) + self.mlp = InternS2PreviewVisionMLP(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class InternS2PreviewVisionModel(InternS2PreviewPreTrainedModel): + config: InternS2PreviewVisionConfig + _no_split_modules = ["InternS2PreviewVisionBlock"] + _can_record_outputs = { + "hidden_states": InternS2PreviewVisionBlock, + "attentions": InternS2PreviewVisionAttention, + } + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = InternS2PreviewVisionPatchEmbed( + config=config, + ) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = InternS2PreviewVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([InternS2PreviewVisionBlock(config) for _ in range(config.depth)]) + self.merger = InternS2PreviewVisionPatchMerger( + config=config, + use_postshuffle_norm=False, + ) + + self.gradient_checkpointing = False + + self.post_init() + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + grid_thw_list = grid_thw.tolist() + + max_hw = max(max(h, w) for _, h, w in grid_thw_list) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = sum(t * h * w for t, h, w in grid_thw_list) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw_list: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + grid_thw_list = grid_thw.tolist() + grid_ts = [row[0] for row in grid_thw_list] + grid_hs = [row[1] for row in grid_thw_list] + grid_ws = [row[2] for row in grid_thw_list] + device = self.pos_embed.weight.device + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in grid_thw_list: + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + @merge_with_config_defaults + @capture_outputs + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for blk in self.blocks: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + **kwargs, + ) + + merged_hidden_states = self.merger(hidden_states) + + return BaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=merged_hidden_states, + ) + + +class InternS2PreviewTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: InternS2PreviewTextConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + self.mrope_section = config.rope_parameters.get("mrope_section", [11, 11, 10]) + + @staticmethod + def compute_default_rope_parameters( + config: InternS2PreviewTextConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, InternS2Preview has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THWTHWTHW...TT], preserving frequency continuity. + args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + +class InternS2PreviewTextModel(InternS2PreviewPreTrainedModel): + def __init__(self, config: InternS2PreviewTextConfig): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.layers = nn.ModuleList( + [InternS2PreviewDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = InternS2PreviewRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = InternS2PreviewTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = InternS2PreviewDynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # mrope: the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return InternS2PreviewModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def _update_linear_attn_mask(self, attention_mask, cache_position): + """ + NOTE: Left-padding is used for linear attention mask. + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + linear_attn_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + linear_attn_mask = None + return linear_attn_mask + + +class InternS2PreviewTimeSeriesAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + layer_idx: int | None = None, + config: InternS2PreviewTimeSeriesConfig | None = None, + num_key_value_groups: int = 1, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + if layer_idx is None and is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.num_key_value_groups = num_key_value_groups + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor | None = None, + past_key_values: Cache | None = None, + attention_mask: torch.Tensor | None = None, + output_attentions: bool = False, + cache_position: torch.Tensor | None = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + + # Scaling is susceptible to floating point arithmetics' inprecisions + # which can lead to different results (this is dependent from model + # to model, e.g. intern_s2_preview_time_series is one such case). We therefore keep the + # original order of scaling to follow the original implementation + # and enforce no scaling (1.0) in the attention call below. + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(*q_input_shape) + query_states = query_states.transpose(1, 2).contiguous() + + # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` + if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_values.is_updated[self.layer_idx] = True + past_key_values = past_key_values.cross_attention_cache + else: + past_key_values = past_key_values.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_values and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_values.layers[self.layer_idx].keys + value_states = past_key_values.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) + value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=1.0, + output_attentions=output_attentions, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class InternS2PreviewTimeSeriesEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: InternS2PreviewTimeSeriesConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = InternS2PreviewTimeSeriesAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states, attn_weights + + +class InternS2PreviewTimeSeriesEncoder(PreTrainedModel): + config: InternS2PreviewTimeSeriesConfig + base_model_prefix = "model" + main_input_name = "input_features" + input_modalities = ("audio", "text") + supports_gradient_checkpointing = True + _no_split_modules = ["InternS2PreviewTimeSeriesEncoderLayer", "WhisperDecoderLayer"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + + def __init__(self, config: InternS2PreviewTimeSeriesConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + self.embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + # self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(self.embed_dim) if config.scale_embedding else 1.0 + + self.conv1 = nn.Conv1d(self.num_mel_bins, self.embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(self.embed_dim, self.embed_dim, kernel_size=3, stride=2, padding=1) + self.embed_positions = nn.Embedding(self.max_source_positions, self.embed_dim) + + self.layers = nn.ModuleList( + [InternS2PreviewTimeSeriesEncoderLayer(config) for _ in range(config.encoder_layers)] + ) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + self.post_init() + + self.mask_type = None + self.chunk_length = None + + self.adapt_in = nn.Linear(config.ts_adapt_in_dim, 80) + self.adapt_out = nn.Linear(self.embed_dim, config.ts_adapt_out_dim) + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + def define_masktype(self, masktype, chunk_length=None): + self.mask_type = masktype + self.chunk_length = chunk_length + + def _make_causal_mask( + self, input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + # Copied from transformers.models.bart.modeling_bart._expand_mask + def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + # print(mask.size()) + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + if input_shape[-1] > 1: + combined_attention_mask = self._make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + return combined_attention_mask + + def prepare_chunk_attention_mask(self, attention_mask, input_shape, inputs_embeds): + + block_size = round(self.chunk_length / 4 * 2) + matrix_size = input_shape[1] + + matrix = torch.ones(matrix_size, matrix_size) + + num_full_blocks = round(matrix_size // block_size) + remainder = matrix_size % block_size + for i in range(num_full_blocks): + row_start = i * block_size + col_start = i * block_size + matrix[row_start : row_start + block_size, col_start : col_start + block_size] = torch.zeros( + block_size, block_size + ) + + if remainder > 0: + last_row_start = num_full_blocks * block_size + last_col_start = num_full_blocks * block_size + matrix[last_row_start : last_row_start + remainder, last_col_start : last_col_start + remainder] = ( + torch.zeros(remainder, remainder) + ) + + matrix = matrix * -65504 + matrix = matrix.unsqueeze(0).unsqueeze(0).repeat(input_shape[0], 1, 1, 1) + attention_mask = matrix.to(inputs_embeds.device) + return attention_mask + + def forward( + self, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # (N, T, C) -> (T, N, C) -> (N, C, T) + input_features = input_features.permute(1, 0, 2) + input_features = self.adapt_in(input_features) + input_features = input_features.permute(1, 2, 0) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # (N, C, T) -> (N, C, T//2) + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + # (N, C, T) -> (N, T, C) + inputs_embeds = inputs_embeds.permute(0, 2, 1) # torch.Size([1, 100, 768]) + embed_pos = self.embed_positions.weight # torch.Size([1500, 768]) + + if inputs_embeds.shape[1] > embed_pos.shape[0]: + target_len = inputs_embeds.shape[1] + padding = [0, 0, 0, target_len - embed_pos.shape[0]] + + embed_pos = nn.functional.pad(embed_pos, pad=padding, mode="constant", value=0) + hidden_states = inputs_embeds[:, : embed_pos.shape[0], :] + embed_pos + else: + hidden_states = inputs_embeds + embed_pos[: inputs_embeds.shape[1], :] + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + input_shape = inputs_embeds.size()[:-1] + past_key_values_length = 0 + attention_mask = None + if self.mask_type == "chunk": + attention_mask = self.prepare_chunk_attention_mask(attention_mask, input_shape, inputs_embeds) + else: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + if head_mask is not None: + assert head_mask.size()[0] == (len(self.layers)), ( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (self.layer_norm(hidden_states),) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + # TODO test whether OK + # layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # (N, T, C) -> (T, N, C) + hidden_states = hidden_states.permute(1, 0, 2) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.adapt_out(hidden_states) + + # (T, N, C) -> (N, T, C) + hidden_states = hidden_states.permute(1, 0, 2) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return ModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions) + + +class InternS2PreviewTimeSeriesConcatSubsampling(nn.Module): + def __init__(self, in_channels: int, concat_size: int): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels * concat_size + + def forward(self, ts_signals: torch.Tensor, ts_lens: torch.Tensor): + if ts_signals.shape[1] % 2 != 0: + ts_signals = ts_signals[:, :-1, :] + even_frames = ts_signals[:, ::2, :] + odd_frames = ts_signals[:, 1::2, :] + ts_signals = torch.cat((even_frames, odd_frames), dim=2) + ts_lens = ts_lens // 2 + return ts_signals, ts_lens + + +class InternS2PreviewTimeSeriesFixPositionalEncoding(nn.Module): + def __init__(self, d_model: int, max_len: int = 20000): + super().__init__() + pe = torch.zeros(max_len, d_model, dtype=torch.float) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1).to(torch.float32) # (max_len, 1, d_model) + self.register_buffer("pe", pe, persistent=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (seq_len, batch_size, d_model) + x = x + self.pe[: x.size(0), :] + return x.clone() + + +class InternS2PreviewTimeSeriesMultiChannelAdaptiveSubsampling(nn.Module): + def __init__(self, hidden_dim: int = 128, nhead: int = 8, num_encoder_layers: int = 1): + super().__init__() + self.conv = nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=5, stride=1, padding=2) + encoder_layers = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead) + self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers) + self.pos_encoder = InternS2PreviewTimeSeriesFixPositionalEncoding(d_model=hidden_dim) + self.subsampling = InternS2PreviewTimeSeriesConcatSubsampling(128, 2) + + def forward( + self, inputs: torch.Tensor, input_lens: torch.Tensor, sr: torch.Tensor, channels: torch.LongTensor + ) -> tuple[torch.Tensor, torch.Tensor]: + features, feature_lens = self.forward_patch(inputs, input_lens, sr, channels) + outputs = features + output_lens = feature_lens + return outputs, output_lens + + def forward_patch( + self, inputs: torch.Tensor, input_lens: torch.Tensor, sr: torch.Tensor, channels: torch.LongTensor + ) -> tuple[torch.Tensor, torch.Tensor]: + sr = sr.float() + strides = torch.floor(160 / ((1 + torch.exp(-sr / 100)) ** 6)) + patch_sizes = strides * 2 + patched_outputs = [] + output_lens = [] + + for i in range(len(inputs)): + le = input_lens[i] + channel = channels[i] + seq = inputs[i, :le, :channel] # [seq_len, channel] + ps = patch_sizes[i].item() + st = strides[i].item() + + output_len = torch.ceil((le - ps) / st) + 1 + pad_len = ((output_len - 1) * st + ps - le).long().item() + if seq.ndim == 1: + seq = seq.unsqueeze(-1) + seq = nn.functional.pad(seq, (0, 0, 0, pad_len), "constant", 0) + assert output_len > 0, (seq.shape, ps, st, le, output_len) + output_lens.append(output_len) + indices = (torch.arange(0, output_len * st, st).unsqueeze(1) + torch.arange(ps)).long() + patched = seq[indices] + + output = self.forward_encoder(patched) # [num_patch, D] + patched_outputs.append(output) + + outputs = nn.utils.rnn.pad_sequence(patched_outputs, batch_first=True) + output_lens = torch.tensor(output_lens).squeeze().to(outputs.device).long() + if output_lens.ndim == 0: + output_lens = output_lens.unsqueeze(0) + + outputs, output_lens = self.subsampling(outputs.clone(), output_lens.clone()) + return outputs, output_lens + + def forward_encoder(self, x: torch.Tensor) -> torch.Tensor: + num_patch, patch_len, C = x.shape + # conv1 + x = x.reshape(num_patch * C, 1, patch_len) # each channel is treated as an independent sample into conv1 + x = nn.functional.relu(self.conv(x)) # [B*C, D1, L] + x = x.permute(2, 0, 1) # [L, B*C, D1] + + x = self.pos_encoder(x) # [L, B*C, D1] + x = self.transformer_encoder(x.to(torch.bfloat16)) + x = x.mean(0) + + x = x.reshape(num_patch, C, -1) + + return x.mean(1) + + +class InternS2PreviewTimeSeriesProjector(nn.Module): + def __init__(self, config: InternS2PreviewTimeSeriesConfig): + super().__init__() + self.layer_norm = nn.LayerNorm(config.ts_hidden_dim) + self.linear_1 = nn.Linear(config.ts_hidden_dim, config.out_hidden_size) + self.act = ACT2FN[config.activation_function] + self.linear_2 = nn.Linear(config.out_hidden_size, config.out_hidden_size) + + def forward(self, ts_features): + hidden_states = self.layer_norm(ts_features) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@auto_docstring +class InternS2PreviewTimeSeriesModel(PreTrainedModel): + main_input_name = "time_series_signals" + _supports_flash_attn_2 = False + config_class = InternS2PreviewTimeSeriesConfig + _no_split_modules = ["InternS2PreviewTimeSeriesEncoderLayer"] + + def __init__(self, config: InternS2PreviewTimeSeriesConfig): + super().__init__(config) + self.config = config + self.encoder_embed = InternS2PreviewTimeSeriesMultiChannelAdaptiveSubsampling() + self.encoder = InternS2PreviewTimeSeriesEncoder(config) + self.projector = InternS2PreviewTimeSeriesProjector(config) + + def get_input_embeddings(self): + return self.encoder_embed + + def make_pad_mask(self, lengths: torch.Tensor) -> torch.Tensor: + r""" + lengths (`torch.Tensor` of shape `(batch_size,)`): + A 1-D tensor containing sentence lengths. + + Returns: + A 2-D bool tensor, where masked positions are filled with `True` and non-masked positions are filled with `False`. + Example: + ```python + >>> lengths = torch.tensor([1, 3, 2, 5]) + >>> self.make_pad_mask(lengths) + tensor([[False, True, True, True, True], + [False, False, False, True, True], + [False, False, True, True, True], + [False, False, False, False, False]]) + ``` + """ + assert lengths.ndim == 1, lengths.ndim + max_len = lengths.max() + n = lengths.size(0) + seq_range = torch.arange(0, max_len, device=lengths.device) + expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) + return expaned_lengths >= lengths.unsqueeze(-1) + + def forward( + self, + time_series_signals: torch.FloatTensor = None, + ts_lens: torch.Tensor = None, + sr: torch.Tensor = None, + channels: torch.LongTensor | None = None, + output_hidden_states: bool = None, + return_dict: bool = None, + time_series_embeds: torch.FloatTensor = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if time_series_signals is None and time_series_embeds is None: + raise ValueError("You have to specify time_series_signals or time_series_embeds") + + if ( + time_series_embeds is not None + and len(time_series_embeds.shape) == 3 + and time_series_embeds.shape[-1] == self.config.ts_adapt_in_dim + ): + time_series_embeds = time_series_embeds + else: + if time_series_signals.ndim == 3: + time_series_embeds, ts_lens = self.encoder_embed(time_series_signals, ts_lens, sr, channels) + else: + raise ValueError(f"wrong time_series_signals size: {time_series_signals[0].shape}") + + # [B, 64000, 1] -> [B, 200, 256] -> [B, 100, 1024] + encoder_outputs = self.encoder( + input_features=time_series_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # ts_lens after encoder + ts_lens = (ts_lens + 1) // 2 + assert torch.all(ts_lens > 0), f"The length of time_series_embeds is so small. ts_lens: {ts_lens}" + + src_key_padding_mask = self.make_pad_mask(ts_lens) + last_hidden_state = encoder_outputs.last_hidden_state + + ts_pad_mask = src_key_padding_mask + ts_embeds = self.projector(last_hidden_state) + + return ts_embeds, ts_pad_mask + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava outputs, with hidden states and attentions. + """ +) +class InternS2PreviewModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + last_hidden_state: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + router_logits: tuple[torch.FloatTensor] | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for InternS2Preview causal language model (or autoregressive) outputs. + """ +) +class InternS2PreviewCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + router_logits: tuple[torch.FloatTensor] | None = None + aux_loss: torch.FloatTensor | None = None + + +@auto_docstring +class InternS2PreviewModel(InternS2PreviewPreTrainedModel): + base_model_prefix = "model" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + + config: InternS2PreviewConfig + _no_split_modules = [ + "Qwen3_5MoeTextDecoderLayer", + "Qwen3_5MoeVisionBlock", + "InternS2PreviewTimeSeriesEncoderLayer", + ] + + def __init__(self, config): + super().__init__(config) + self.visual = InternS2PreviewVisionModel._from_config(config.vision_config) + self.language_model = InternS2PreviewTextModel._from_config(config.text_config) + self.rope_deltas = None # cache rope_deltas here + self.time_series = InternS2PreviewTimeSeriesModel._from_config(config.ts_config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_rope_index( + self, + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, InternS2Preview use timestamps rather than absolute time position ids.""" + + # Since we use timestamps to separate videos, like , the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + image_grid_thw_list = image_grid_thw.tolist() if image_grid_thw is not None else None + video_grid_thw_list = video_grid_thw.tolist() if video_grid_thw is not None else None + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.zeros( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = image_grid_thw_list[image_index] + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = video_grid_thw_list[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + + @can_return_tuple + @auto_docstring + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + # Same implementation as for images + return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + + @can_return_tuple + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + vision_output: BaseModelOutputWithPooling = self.visual( + pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + ) + image_embeds = vision_output.pooler_output + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + vision_output.pooler_output = image_embeds + + return vision_output + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor | None = None, + video_features: torch.FloatTensor | None = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None: + torch_compilable_check( + inputs_embeds[special_image_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None: + torch_compilable_check( + inputs_embeds[special_video_mask].numel() == video_features.numel(), + f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}", + ) + return special_image_mask, special_video_mask + + def compute_3d_position_ids( + self, + input_ids: torch.Tensor | None, + inputs_embeds: torch.Tensor | None, + image_grid_thw: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: torch.Tensor | None = None, + ) -> torch.Tensor | None: + past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length() + can_compute_mrope = input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None) + + if can_compute_mrope and (self.rope_deltas is None or past_key_values_length == 0): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + # Use pre-calculated rope-deltas to infer correct 3D position ids + elif self.rope_deltas is not None: + batch_size, seq_length, _ = inputs_embeds.shape + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids = position_ids.masked_fill(attention_mask == 0, 0) + position_ids = position_ids.view(1, batch_size, -1).repeat(3, 1, 1).to(inputs_embeds.device) + else: + position_ids = torch.arange(past_key_values_length, past_key_values_length + seq_length) + position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1).to(inputs_embeds.device) + delta = self.rope_deltas.repeat_interleave(batch_size // self.rope_deltas.shape[0], dim=0) + position_ids = position_ids + delta.to(device=position_ids.device) + else: + # Can't build correct 3D positions. Let the model infer it from `cache_position` + position_ids = None + return position_ids + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + ts_values: torch.FloatTensor | list[torch.FloatTensor] = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + ts_lens: torch.Tensor | list[torch.Tensor] = None, + ts_sr: torch.FloatTensor | list[torch.FloatTensor] = None, + ts_channels: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | InternS2PreviewModelOutputWithPast: + r""" + ts_values (`torch.FloatTensor` of shape `(batch_size, seq_len, num_channels)`, *optional*): + The tensors corresponding to the input time series signals. + ts_lens (`torch.Tensor` of shape `(batch_size,)`, *optional*): + The valid lengths of each time series signal in the batch. + ts_sr (`torch.FloatTensor` of shape `(batch_size,)`, *optional*): + The sampling rates of each time series signal in the batch. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_outputs: BaseModelOutputWithPooling = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True + ) + image_embeds = image_outputs.pooler_output + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_outputs: BaseModelOutputWithPooling = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True + ) + video_embeds = video_outputs.pooler_output + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if ts_values is not None: + ts_features, ts_pad_mask = self.get_ts_feature(ts_values, ts_lens, ts_sr, ts_channels) # [B, T, C], [B, T] + ts_features = ts_features[~ts_pad_mask].to( + inputs_embeds.device, inputs_embeds.dtype + ) # [num_valid_ts_tokens, C] + B, N, C = inputs_embeds.shape + input_ids = input_ids.reshape(B * N) + inputs_embeds = inputs_embeds.reshape(B * N, C) + # replace ts_token in inputs_embeds and attention_mask + ts_placeholder = input_ids == self.config.ts_token_id + n_ts_placeholders = ts_placeholder.sum().item() + n_ts_tokens = ts_features.size(0) + assert n_ts_placeholders == n_ts_tokens, ( + f"[ERROR]: Mismatch: tokens={n_ts_placeholders}, ts_embeds_valid={n_ts_tokens}" + ) + + try: + # TODO why not scatter? + inputs_embeds[ts_placeholder] = inputs_embeds[ts_placeholder] * 0.0 + ts_features + except Exception as e: + print( + f"warning: {e}, inputs_embeds[selected].shape={inputs_embeds[ts_placeholder].shape}, ts_embeds_valid.shape={ts_features.shape}" + ) + inputs_embeds[ts_placeholder] = inputs_embeds[ts_placeholder] * 0.0 + n_ts_tokens[:n_ts_placeholders] + + inputs_embeds = inputs_embeds.reshape(B, N, C) + + if position_ids is None: + position_ids = self.compute_3d_position_ids( + input_ids=input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + return InternS2PreviewModelOutputWithPast( + **outputs, + rope_deltas=self.rope_deltas, + ) + + @can_return_tuple + @auto_docstring + def get_ts_feature( + self, + ts_values: torch.FloatTensor | list[torch.FloatTensor], + ts_lens: torch.Tensor | list[torch.Tensor], + sr: torch.FloatTensor | list[torch.FloatTensor], + ts_channels: torch.LongTensor | None = None, + ) -> tuple[torch.FloatTensor, torch.Tensor]: + r""" + ts_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, input_size)`): + The time series values to be fed to a model. + ts_lens (`torch.Tensor` of shape `(batch_size,)`): + The length of the time series. + sr (`torch.FloatTensor` of shape `(batch_size,)`): + The sampling rate of the time series. + """ + ts_embeds, ts_pad_mask = self.time_series( + time_series_signals=ts_values, + ts_lens=ts_lens, + sr=sr, + channels=ts_channels, + output_hidden_states=False, + return_dict=True, + ) + return ts_embeds, ts_pad_mask + + +# NOTE: Cannot inherit from `Qwen3_5MoeForCausalLM` here due to a converter limitation: when the modular +# `__init__` contains statements before `super().__init__()`, `_fix_init_location` unconditionally moves +# `super().__init__()` to the top, which prevents pre-processing `config` (e.g. extracting `config.text_config`) +# before the parent call. The full body is therefore written out explicitly. +class InternS2PreviewForCausalLM(InternS2PreviewPreTrainedModel, GenerationMixin): + # NOTE: `config` is annotated as `InternS2PreviewConfig` rather than `InternS2PreviewTextConfig` because remote + # code only exposes a single AutoConfig entry. The auto factory checks `model_class.config_class` against + # the loaded config type; a mismatch raises ValueError, so the annotation must match `InternS2PreviewConfig`. + config: InternS2PreviewConfig + + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"] + + def __init__(self, config: InternS2PreviewConfig | InternS2PreviewTextConfig): + if isinstance(config, InternS2PreviewConfig): + config = config.text_config + super().__init__(config) + self.model = InternS2PreviewTextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + +def load_balancing_loss_func( + gate_logits: torch.Tensor | tuple[torch.Tensor] | None, + num_experts: int | None = None, + top_k=2, + attention_mask: torch.Tensor | None = None, +) -> torch.Tensor | int: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class InternS2PreviewForConditionalGeneration(InternS2PreviewPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + + config: InternS2PreviewConfig + + def __init__(self, config): + super().__init__(config) + self.model = InternS2PreviewModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @auto_docstring + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + return self.model.get_video_features( + pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + ) + + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + ts_values: torch.FloatTensor | list[torch.FloatTensor] = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + ts_lens: torch.Tensor | list[torch.Tensor] = None, + ts_sr: torch.FloatTensor | list[torch.FloatTensor] = None, + ts_channels: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | InternS2PreviewCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + + Example: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration + + >>> model = Qwen3_5MoeForConditionalGeneration.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct", dtype="auto", device_map="auto") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image in short."}, + ], + } + ] + + >>> # Preparation for inference + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ) + >>> inputs = inputs.to(model.device) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=128) + >>> generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + >>> processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "A woman in a plaid shirt sits on a sandy beach at sunset, smiling as she gives a high-five to a yellow Labrador Retriever wearing a harness. The ocean waves roll in the background." + ```""" + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + ts_values=ts_values, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + ts_lens=ts_lens, + ts_sr=ts_sr, + ts_channels=ts_channels, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + aux_loss = None + if kwargs.get("output_router_logits", False): + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.config.text_config.num_experts, + self.config.text_config.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.config.text_config.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + return InternS2PreviewCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + ts_values=None, + image_grid_thw=None, + video_grid_thw=None, + ts_lens=None, + ts_sr=None, + ts_channels=None, + is_first_iteration=False, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + ts_values=ts_values, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + ts_lens=ts_lens, + ts_sr=ts_sr, + ts_channels=ts_channels, + is_first_iteration=is_first_iteration, + **kwargs, + ) + + if not is_first_iteration and use_cache: + model_inputs["ts_values"] = None + model_inputs["ts_lens"] = None + model_inputs["ts_sr"] = None + model_inputs["ts_channels"] = None + + return model_inputs + + def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): + # Overwritten -- requires 3D position ids + + text_positions = super()._prepare_position_ids_for_generation(inputs_tensor, model_kwargs) + + # Early exit in case we are continuing generation from past kv + past_length = 0 + if (cache := model_kwargs.get("past_key_values")) is not None: + past_length = cache.get_seq_length() + if past_length != 0 and self.model.rope_deltas is not None: + text_positions += self.model.rope_deltas + return text_positions + + # Otherwise compute 3d position ids for vision tokens and concat with text position ids + if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0: + inputs_tensor = model_kwargs["input_ids"] + + is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long] + if is_input_ids and ( + model_kwargs.get("image_grid_thw") is not None or model_kwargs.get("video_grid_thw") is not None + ): + model_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"} + vision_positions, rope_deltas = self.model.get_rope_index(inputs_tensor, **model_kwargs) + self.model.rope_deltas = rope_deltas + else: + vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1) + self.model.rope_deltas = torch.zeros( + inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device + ) + + # Concatenate "text + vision" positions into [4, bs, seq-len] + text_positions = text_positions[None, ...] + position_ids = torch.cat([text_positions, vision_positions], dim=0) + + return position_ids + + def _get_image_nums_and_video_nums( + self, + input_ids: torch.LongTensor | None, + inputs_embeds: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- InternS2Preview use timestamps and remove second_per_grid_ts + # Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) + + # video_nums: (batch_size,) + # since video_nums is the number of videos in the input dependent on the input_ids(vision_start), + # but InternS2Preview append vision_start to each frame of each video, so we need to recover the real video_nums according to video_grid_thw + if video_grid_thw is not None: + cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0) + cumulative_token_video_counts = torch.cumsum(video_nums, dim=0) + # Find video boundaries in cumulative_frame_counts + video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts) + # example: video_boundary_indices = [3, 5] means video_nums = [4, 2] + video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices])) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if key == "position_ids" and dict_to_expand[key].ndim == 3: + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=1) + elif ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + @property + def time_series(self): + return self.model.time_series + + @auto_docstring + def get_ts_feature( + self, + ts_values: torch.FloatTensor | list[torch.FloatTensor], + ts_lens: torch.Tensor | list[torch.Tensor], + sr: torch.FloatTensor | list[torch.FloatTensor], + ): + r""" + ts_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, input_size)`): + The time series values to be fed to a model. + ts_lens (`torch.Tensor` of shape `(batch_size,)`): + The length of the time series. + sr (`torch.FloatTensor` of shape `(batch_size,)`): + The sampling rate of the time series. + """ + return self.model.get_ts_feature(ts_values, ts_lens, sr) + + +__all__ = [ + "InternS2PreviewVisionModel", + "InternS2PreviewTextModel", + "InternS2PreviewModel", + "InternS2PreviewForCausalLM", + "InternS2PreviewForConditionalGeneration", + "InternS2PreviewPreTrainedModel", +]