Spaces:
Running
Running
| <!DOCTYPE html> | |
| <html lang="en"> | |
| <head> | |
| <meta charset="UTF-8"> | |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
| <title>Python Code Structure Visualizer</title> | |
| <script src="https://cdn.tailwindcss.com"></script> <!-- Tailwind from CDN with no version pin and no integrity (SRI) attribute; NOTE(review): pin a version and add SRI before serving in production --> | |
| <script src="https://d3js.org/d3.v7.min.js"></script> <!-- d3 v7 (major version only) from CDN, also without SRI; loaded synchronously (no defer), presumably so later inline scripts can use the d3 global — verify before adding defer --> | |
| <link rel="preconnect" href="https://fonts.googleapis.com"> | |
| <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin> | |
| <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Fira+Code:wght@400;500&display=swap" rel="stylesheet"> <!-- Inter is the body font; Fira Code is used by .fira-code and the graph node labels --> | |
| <style> | |
| body { | |
| font-family: 'Inter', sans-serif; | |
| } | |
| .fira-code { | |
| font-family: 'Fira Code', monospace; | |
| } | |
| /* Custom styles for D3 graph */ | |
| .graph-container { | |
| width: 100%; | |
| height: 100%; | |
| min-height: 500px; | |
| cursor: grab; | |
| } | |
| .graph-container:active { | |
| cursor: grabbing; | |
| } | |
| .node circle { | |
| stroke: #fff; | |
| stroke-width: 1.5px; | |
| } | |
| .node text { | |
| font-size: 10px; | |
| font-family: 'Fira Code', monospace; | |
| paint-order: stroke; /* paint the stroke beneath the fill so the dark stroke outlines each label */ | |
| stroke: #111827; /* Match dark background */ | |
| stroke-width: 3px; | |
| stroke-linecap: butt; | |
| stroke-linejoin: miter; | |
| pointer-events: none; /* labels never intercept pointer events aimed at the graph */ | |
| } | |
| .link { | |
| stroke-opacity: 0.6; | |
| } | |
| .link.inheritance { | |
| stroke-dasharray: 5, 5; | |
| stroke: #60a5fa; /* blue-400 */ | |
| } | |
| .link.method { | |
| stroke: #4b5563; /* gray-600 */ | |
| } | |
| .node.selected > circle { | |
| stroke: #facc15; /* yellow-400 */ | |
| stroke-width: 3px; | |
| } | |
| /* Tab styling */ | |
| .tab.active { | |
| background-color: #3b82f6; /* blue-500 */ | |
| color: white; | |
| } | |
| </style> | |
| </head> | |
| <!-- Inert data block: browsers do not execute type="text/plain" scripts; its text content is the Python source that follows. NOTE(review): it appears after the closing head tag with no body element opened in this view --> <script id="main-code" type="text/plain"> | |
| # π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨ | |
| # This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py. | |
| # Do NOT edit this file manually as any edits will be overwritten by the generation of | |
| # the file from the modular. If any change should be done, please apply the change to the | |
| # modular_gemma3n.py file directly. One of our CI enforces this. | |
| # π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨ | |
| # coding=utf-8 | |
| # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. | |
| # | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import copy | |
| import math | |
| from collections.abc import Callable, Sequence | |
| from dataclasses import dataclass | |
| from typing import Optional, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ...activations import ACT2FN | |
| from ...cache_utils import Cache, DynamicCache, HybridCache | |
| from ...generation import GenerationMixin | |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask | |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs | |
| from ...modeling_layers import GradientCheckpointingLayer | |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast | |
| from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update | |
| from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel | |
| from ...processing_utils import Unpack | |
| from ...utils import ( | |
| ModelOutput, | |
| auto_docstring, | |
| can_return_tuple, | |
| is_torchdynamo_compiling, | |
| logging, | |
| ) | |
| from ...utils.deprecation import deprecate_kwarg | |
| from ..auto import AutoModel | |
| from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig | |
logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Gemma3n outputs, with hidden states and attentions.
    """
)
class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    audio_hidden_states (`torch.FloatTensor`, *optional*):
        audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
    """

    # Extra multimodal outputs on top of BaseModelOutputWithPast; both default to None when the modality is absent.
    image_hidden_states: Optional[torch.FloatTensor] = None
    audio_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Gemma3n causal language model (or autoregressive) outputs.
    """
)
class Gemma3nCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache` or `list[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
    audio_hidden_states (`torch.FloatTensor`, *optional*):
        audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
    """

    # Field order is significant: ModelOutput supports positional/tuple access in declaration order.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
    audio_hidden_states: Optional[torch.FloatTensor] = None
class Gemma3nRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with an optional learned per-feature scale.

    With ``with_scale=False`` the scale is a constant, non-persistent scalar buffer equal to 1.0, so the layer
    has no trainable parameters and contributes nothing to the state dict.
    """

    def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
        super().__init__()
        self.eps = eps
        self.with_scale = with_scale

        if not with_scale:
            # Unit scale kept as a buffer so `self.weight` exists for forward() and extra_repr() either way.
            self.register_buffer("weight", torch.tensor(1.0), persistent=False)
        else:
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x / torch.sqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize and scale in float32, casting back only once at the end: Gemma2-style (x * w).to(float16)
        # rather than Llama-style x.to(float16) * w. See https://github.com/huggingface/transformers/pull/29402
        scale = self.weight.float()
        normed = self._norm(x.float())
        return (normed * scale).type_as(x)

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.eps}"
| # ==== Audio Encoder ==== | |
class Gemma3nAudioRelativePositionEmbedding(nn.Module):
    """Attention logits with sinusoidal relative-position terms for chunked local audio attention.

    ``forward`` returns ``term_ac + term_bd``. ``term_ac`` is the query-key dot product. ``term_bd`` is the dot
    product of each query with a projected sinusoidal embedding of every relative offset in
    ``[max_backward, max_backward - 1, ..., -max_forward]``, realigned onto key-context positions by
    ``_relative_shift``.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config

        self.num_heads = self.config.conf_num_attention_heads
        self.channels = self.config.hidden_size
        self.head_dim = self.channels // self.num_heads
        # Look-back (L) and look-ahead (R) extents; same formulas as Gemma3nAudioAttention's
        # max_past_horizon / max_future_horizon.
        self.max_backward = max(0, self.config.conf_attention_context_left - 1)
        self.max_forward = self.config.conf_attention_context_right

        self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)

        # Inverse timescales spaced geometrically from 1/min_timescale towards 1/max_timescale.
        min_timescale = 1.0
        max_timescale = 1.0e4
        num_timescales = self.channels // 2
        log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
        inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
        self.register_buffer(
            "inv_timescales",
            inv_timescales.float().unsqueeze(0).unsqueeze(0),
            persistent=False,
        )

    def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Returns ``[sin(p * inv_timescales), cos(p * inv_timescales)]`` for every position ``p``, cast to ``dtype``.

        The signal is computed in float32; for ``position`` of shape [1, F] the result is [1, F, 2 * (channels // 2)].
        """
        position = position.float().unsqueeze(-1)
        scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32)
        timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
        return timing_signal.type(dtype)

    def _relative_shift(
        self,
        term_bd_before_shift: torch.Tensor,
        batch_size: int,
        num_heads: int,
        num_query_blocks: int,
        query_block_size: int,
        key_context_size: int,
        max_span_plus_1: int,
    ) -> torch.Tensor:
        """Realigns relative-offset logits onto key-context positions.

        Right-pads the last axis from F_span to C + 1, flattens each [W, C + 1] slab, keeps the first W * C
        entries and re-reads them as [W, C]. The net effect is
        ``output[..., w, c] = input[..., w, c - w]`` when ``0 <= c - w < F_span`` and 0 otherwise.

        Args:
            term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
                (B), num_heads (N), num_query_blocks (U), query_block_size (W),
                key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).

        Returns:
            Tensor of shape [B, N, U, W, C].
        """
        # Pad only the last dimension, on the right, from F_span to C + 1.
        pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
        term_bd_padded = nn.functional.pad(term_bd_before_shift, (0, pad_amount_last_dim))  # [B, N, U, W, C+1]

        # Flatten each [W, C+1] slab so consecutive rows overlap when re-read with row length C.
        term_bd_reshaped = term_bd_padded.reshape(
            (
                batch_size,
                num_heads,
                num_query_blocks,
                query_block_size * (key_context_size + 1),
            )
        )
        term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]  # [B, N, U, W*C]

        term_bd_shifted = term_bd_sliced.reshape(
            (
                batch_size,
                num_heads,
                num_query_blocks,
                query_block_size,
                key_context_size,
            )
        )
        return term_bd_shifted

    def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
        """Computes attention logits with relative-position terms.

        Args:
            queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim).
            keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim).

        Returns:
            Logits of shape [B, N, U, W, C].
        """
        batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
        _, _, key_context_size, _, _ = keys.shape

        # Relative offsets [L, L-1, ..., -R]; F_span = L + R + 1 entries.
        pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(0)
        max_span_plus_1 = pos_indices.shape[1]

        sin_emb_timing_signal = self._get_timing_signal_1d_pos(pos_indices, dtype=queries.dtype)  # [1, F, channels]
        projected_sin_emb = self.pos_proj(sin_emb_timing_signal)  # [1, F, N*H]
        sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(0)  # [F, N, H]

        # Heads-first view of the queries, shared by both terms (previously permuted twice).
        queries_p = queries.permute(0, 3, 1, 2, 4)  # [B, N, U, W, H]

        # term_ac: query-key content interaction.
        keys_p_t = keys.permute(0, 3, 1, 4, 2)  # [B, N, U, H, C]
        term_ac = torch.matmul(queries_p, keys_p_t)  # [B, N, U, W, C]

        # term_bd: query-position interaction, i.e. einsum('buwnh,fnh->bnuwf', queries, sin_emb),
        # computed as [B, N, U*W, H] @ [N, H, F] with the [N, H, F] operand broadcast over B.
        s_permuted = sin_emb.permute(1, 2, 0)  # [N, H, F]
        q_reshaped = queries_p.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
        term_bd_unshifted = torch.matmul(q_reshaped, s_permuted).reshape(
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            max_span_plus_1,
        )  # [B, N, U, W, F]

        term_bd_shifted = self._relative_shift(
            term_bd_unshifted,
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            key_context_size,
            max_span_plus_1,
        )  # [B, N, U, W, C]

        return term_ac + term_bd_shifted
class Gemma3nAudioAttention(nn.Module):
    """Chunked local self-attention used by the Gemma3n audio encoder.

    The time axis is split into non-overlapping query chunks of ``chunk_size`` frames. Each chunk attends to a
    key/value context of ``context_size = chunk_size + max_past_horizon + max_future_horizon`` frames. Within that
    context, each query frame may only see keys from ``max_past_horizon`` frames before it to
    ``max_future_horizon`` frames after it. Logits come from ``Gemma3nAudioRelativePositionEmbedding`` and are
    tanh soft-capped by ``conf_attention_logit_cap``. They are then masked by both the local window and the
    caller's padding mask before the softmax.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config

        self.num_heads = self.config.conf_num_attention_heads
        self.hidden_size = self.config.hidden_size
        self.head_dim = self.hidden_size // self.num_heads

        self.chunk_size = self.config.conf_attention_chunk_size
        self.max_future_horizon = self.config.conf_attention_context_right
        self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
        self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
        self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon

        self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
        self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)

        # head_dim**-0.5 divided by softplus(0). forward() multiplies by softplus(per_dim_scale), and per_dim_scale
        # starts at zero, so the initial effective query scale is exactly head_dim**-0.5.
        q_scale = self.head_dim**-0.5
        r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
        self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)

        # local_causal_valid_mask[w, c] is True iff w <= c <= w + max_past_horizon + max_future_horizon.
        # Context index c lies (c - max_past_horizon) frames from the chunk start (see _extract_block_context).
        # Query w may therefore see keys at relative offsets in [-max_past_horizon, +max_future_horizon].
        lower_causal_mask = torch.tril(
            torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
            diagonal=0,
        ).T
        upper_causal_mask = torch.tril(
            torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
            diagonal=self.max_past_horizon + self.max_future_horizon,
        )
        local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
        local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
        self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)

        self.register_buffer(
            "softcap",
            torch.tensor(self.attention_logits_soft_cap).float(),
            persistent=False,
        )

    def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
        """Zero-pads ``x`` along dim 1 with ``pad_left`` and ``pad_right`` entries (``False`` for bool tensors)."""
        batch, _, *tail_shape = x.shape
        left = x.new_zeros((batch, pad_left, *tail_shape))
        right = x.new_zeros((batch, pad_right, *tail_shape))
        x = torch.cat([left, x, right], dim=1)
        return x

    def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Splits the time axis (dim 1) into non-overlapping blocks of ``chunk_size`` frames.

        Args:
            hidden_states: a tensor of [batch, time, ...].

        Returns:
            A tensor of [batch, num_blocks, chunk_size, ...] with num_blocks = ceil(time / chunk_size). The tail
            is zero-padded, so output[:, i] == hidden_states[:, i*chunk_size:(i+1)*chunk_size].
        """
        shape = hidden_states.shape
        b, t = shape[:2]
        num_blocks = (t + self.chunk_size - 1) // self.chunk_size

        if (padding_len := num_blocks * self.chunk_size - t) > 0:
            hidden_states = self._pad_dim1(hidden_states, 0, padding_len)

        block_shape = (b, num_blocks, self.chunk_size) + shape[2:]
        hidden_states = hidden_states.reshape(block_shape).contiguous()
        return hidden_states

    def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Extracts the temporal context window for every block.

        Args:
            hidden_states: a tensor of [batch, time, ...].

        Returns:
            A tensor of [batch, num_blocks, context_size, ...], where
            context_size = chunk_size + max_past_horizon + max_future_horizon.
            output[:, i] covers input frames [i*chunk_size - max_past_horizon, (i+1)*chunk_size + max_future_horizon).
            Frames outside the input are zero-filled.
        """
        # _pad_dim1 always pads dim 1 (time) with the two explicit amounts given. The left side gets
        # max_past_horizon frames. The right side gets max_future_horizon + chunk_size - 1 frames, which is
        # the effective padding of the JAX signal.frame reference. Unfolding with step chunk_size then yields
        # ceil(time / chunk_size) windows, the same count _convert_to_block produces.
        pad_left = self.max_past_horizon
        pad_right = self.max_future_horizon + self.chunk_size - 1
        hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right)

        frame_len = self.context_size
        frame_step = self.chunk_size

        # unfold appends the window axis last:
        #   [B, T_pad]       -> [B, U, C]
        #   [B, T_pad, N, H] -> [B, U, N, H, C]
        x_unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step)

        # When there are trailing dims, move the window axis to position 2. Keys/values then come out as
        # [B, U, C, N, H], the layout Gemma3nAudioRelativePositionEmbedding expects.
        if hidden_states.ndim > 2 and x_unfolded.ndim > 3:
            x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)

        return x_unfolded.contiguous()

    def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        """Runs chunked local attention.

        Args:
            hidden_states: [batch, time, hidden_size].
            mask: padding mask of shape [batch, time] that is True at padded (invalid) frames. A trailing
                singleton dimension ([batch, time, 1]) is also accepted.

        Returns:
            Context vectors of shape [batch, time, num_heads, head_dim].

        Raises:
            ValueError: if the blocked mask cannot be brought to [batch, num_blocks, context_size].
        """
        # Per-head projections: [B, T, N, H]. (sl.Dense: jax.numpy.einsum("...a,abcd->...bcd").)
        qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
        query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
        key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
        value_states = self.v_proj(hidden_states).reshape(qkv_shape).contiguous()

        # Learned per-dimension query scaling: q * q_scale * softplus(per_dim_scale).
        per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale)
        broadcast_shape = (1, 1, 1, self.head_dim)
        per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
        query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast

        batch_size, q_time = query_states.shape[:2]

        query_blocks = self._convert_to_block(query_states)  # [B, U, W, N, H]
        key_blocks = self._extract_block_context(key_states)  # [B, U, C, N, H]
        value_blocks = self._extract_block_context(value_states)  # [B, U, C, N, H]
        num_query_blocks = query_blocks.shape[1]

        # 1. Validity mask: True for real frames, False for padding.
        original_valid_mask = ~mask

        # 2. Block the validity mask exactly like the keys. Frames introduced by the zero padding in
        #    _extract_block_context come out False, i.e. invalid.
        extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)

        # A [B, T] mask blocks to [B, U, C]. A [B, T, 1] mask blocks to [B, U, C, 1] (window axis moved to
        # dim 2), so flatten its trailing axes back to [B, U, C].
        if (
            extracted_valid_mask_blocks.ndim == 4
            and extracted_valid_mask_blocks.shape[2] * extracted_valid_mask_blocks.shape[3] == self.context_size
        ):
            extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
                batch_size, num_query_blocks, self.context_size
            )

        if extracted_valid_mask_blocks.shape != (
            batch_size,
            num_query_blocks,
            self.context_size,
        ):
            raise ValueError(
                "Shape of extracted_valid_mask_blocks"
                f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
                f" {num_query_blocks}, {self.context_size}) after potential reshape."
            )

        # 3. Broadcast both conditions towards the logits layout [B, N, U, W, C].
        condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(1).unsqueeze(-2)  # [B, 1, U, 1, C]
        condition_from_causality = (
            self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
        )  # [1, 1, 1, W, C]

        # 4. A key is attendable only if it is a real frame AND inside the local window: [B, 1, U, W, C].
        final_condition_for_where = torch.logical_and(
            condition_from_input_validity,
            condition_from_causality.to(condition_from_input_validity.device),
        )

        logits = self.relative_position_embedding(query_blocks, key_blocks)  # [B, N, U, W, C]

        # Soft-cap: softcap * tanh(logits / softcap).
        softcap_val = self.softcap.to(logits.device)
        logits = logits / softcap_val
        logits = torch.tanh(logits)
        logits = logits * softcap_val

        # Masked keys get the most negative finite value so they receive ~0 probability.
        logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
        probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)

        # Weighted sum over the context axis: jax.numpy.einsum("BNuwc,BucNH->BuwNH", probabilities, values),
        # computed as one batched matmul over the flattened (B, U, N) axes.
        b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
        h_dim = value_blocks.shape[-1]
        prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
        v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
        result_bmm = torch.bmm(prob_bun, v_bun)
        context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)  # [B, U, W, N, H]
        context_vectors = context_vectors.reshape(
            (
                batch_size,
                num_query_blocks * self.chunk_size,
                self.num_heads,
                self.head_dim,
            )
        )
        # Drop the frames added when padding the queries up to a whole number of chunks.
        context_vectors = context_vectors[:, :q_time]

        return context_vectors
class Gemma3nAudioCumulativeGroupNorm(nn.Module):
    """Applies Group Normalization cumulatively over the time dimension.

    For input of shape [B, T, *feature_dims, C], the statistics at time step t are accumulated over all
    feature and channel elements of steps 0..t (dim 1 is time). Each step's squared deviations are measured
    against that step's cumulative mean and then accumulated. A learned per-channel scale (no bias) is applied
    to the last dimension afterwards.

    Every time step is treated as valid; this layer takes no padding mask.

    This behavior is similar to JAX's `GroupNormalization` with `num_groups=1` and `cumulative=True`.
    """

    def __init__(
        self,
        num_channels: int,  # Number of channels (size of the last dimension)
        feature_dims: Sequence[int],  # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
        eps: float = 1e-3,
    ):
        super().__init__()
        self.num_channels = num_channels
        self.feature_dims = tuple(feature_dims)
        self.eps = eps

        # Scale parameter depends only on the channel dimension
        self.weight = nn.Parameter(torch.ones(num_channels))

        # Axes for normalization: all dimensions except Batch (0) and Time (1).
        # For input [B, T, *feature_dims, C], these are dims from 2 onwards.
        self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Applies cumulative group norm.

        Args:
            hidden_states: Input tensor, shape [B, T, *feature_dims, C].

        Returns:
            Normalized tensor with the same shape and dtype as ``hidden_states``.

        Raises:
            ValueError: if the trailing shape is not ``feature_dims + (num_channels,)``.
        """
        expected_input_suffix = self.feature_dims + (self.num_channels,)
        if hidden_states.shape[2:] != expected_input_suffix:
            raise ValueError(
                f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
                f" suffix (feature_dims + num_channels) {expected_input_suffix}"
            )

        input_dtype = hidden_states.dtype
        # Calculations are performed in float32 for numerical stability.
        calc_dtype = torch.float32
        x_calc = hidden_states.to(calc_dtype)

        # 1-2. Cumulative sum of values over the reduction axes.
        sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
        cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)

        # 3-4. Cumulative element count. Every element is valid, so each step contributes exactly
        # prod(feature_dims) * num_channels elements. A constant [B, T, 1, ..., 1] tensor gives the same counts
        # as summing an all-ones mask, without allocating an input-sized tensor.
        elements_in_group_at_t = torch.full_like(sum_values_at_t, float(hidden_states.shape[2:].numel()))
        cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
        # Guard against division by zero (only reachable with a zero-sized feature suffix).
        safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)

        # 5. Cumulative mean.
        cum_mean = cum_sum_values / safe_cum_count_elements

        # 6-8. Cumulative variance from per-step squared deviations against the running mean.
        squared_diff_from_mean = (x_calc - cum_mean).pow(2)
        sum_sq_diff_at_t = torch.sum(squared_diff_from_mean, dim=self.reduction_axes, keepdim=True)
        cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
        cum_variance = cum_sum_sq_diff / safe_cum_count_elements

        # (x - E[x]) / sqrt(Var[x] + eps)
        normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)

        # Per-channel scale, reshaped for broadcasting: [C] -> [1, ..., 1, C].
        scale = self.weight.to(calc_dtype)
        scale_view_shape = [1] * (hidden_states.dim() - 1) + [self.num_channels]
        normalized_x = normalized_x * scale.view(scale_view_shape)

        return normalized_x.to(input_dtype)
class Gemma3nAudioSSCPConvBlock(nn.Module):
    """One stage of the sub-sample convolution projection.

    The pipeline is: explicit zero padding -> ``nn.Conv2d`` (no built-in padding, no bias) ->
    ``Gemma3nAudioCumulativeGroupNorm`` over (frequency, channel) -> ReLU. Kernel and stride for stage ``idx``
    come from ``config.sscp_conv_kernel_size[idx]`` and ``config.sscp_conv_stride_size[idx]``. Stage 0 takes a
    single input channel; later stages take ``config.sscp_conv_channel_size[idx - 1]`` channels.
    """

    def __init__(
        self,
        config: Gemma3nAudioConfig,
        idx: int,
        input_freq_dim: int,
        manual_padding: tuple[int, int, int, int] = (0, 0, 0, 0),
    ):
        super().__init__()
        self.config = config
        # Order is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom), as consumed by F.pad in forward().
        self.manual_padding = manual_padding

        channel_sizes = self.config.sscp_conv_channel_size
        in_channels = channel_sizes[idx - 1] if idx != 0 else 1
        out_channels = channel_sizes[idx]
        kernel_t, kernel_f = self.config.sscp_conv_kernel_size[idx]
        stride_t, stride_f = self.config.sscp_conv_stride_size[idx]

        # Kernel/stride are given as (time, frequency); padding is applied manually in forward().
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(kernel_t, kernel_f),
            stride=(stride_t, stride_f),
            padding=(0, 0),
            bias=False,
        )

        # Frequency width after padding and convolution; the norm normalizes over (this width, out_channels).
        pad_f_left, pad_f_right = self.manual_padding[0], self.manual_padding[1]
        padded_freq = input_freq_dim + pad_f_left + pad_f_right
        conv_freq_out = (padded_freq - kernel_f) // stride_f + 1

        self.norm = Gemma3nAudioCumulativeGroupNorm(
            num_channels=out_channels,
            feature_dims=(conv_freq_out,),
            eps=self.config.sscp_conv_group_norm_eps,
        )

        self.activation = nn.ReLU()

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        # audio_encodings: [B, C_in, T_in, F_in]. F.pad pads the last axis (F) first, then T.
        padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0)
        conv_out = self.conv(padded)  # [B, C_out, T_out, F_out]

        # The cumulative norm expects channels last, [B, T_out, F_out, C_out]; restore channels-first afterwards.
        channels_last = conv_out.permute(0, 2, 3, 1).contiguous()
        normed = self.norm(channels_last).permute(0, 3, 1, 2).contiguous()
        return self.activation(normed)
class Gemma3nAudioSubSampleConvProjection(nn.Module):
    """Convolutional subsampler mapping [B, T, F_in] audio features to [B, T_sub, hidden_size].

    Two Gemma3nAudioSSCPConvBlock stages (``conv_0``, ``conv_1``) are built from the first two entries of
    ``config.sscp_conv_kernel_size`` / ``sscp_conv_stride_size`` / ``sscp_conv_channel_size``.

    - Time is padded reverse-causally: 0 before, ``kernel_h - 1`` after.
    - Frequency is padded by (1, 1).
    - The (frequency x channel) features of the last stage are flattened per time step and projected
      to ``config.hidden_size``.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config

        current_f_for_block_input = config.input_feat_size  # frequency width entering the current stage
        calculated_block_padding = []
        calculated_f_out_dims = []  # frequency width leaving each stage
        num_conv_blocks = 2  # must match the conv_0 / conv_1 modules created below
        for i in range(num_conv_blocks):
            kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
            _, stride_w = config.sscp_conv_stride_size[i]

            # Time (Conv2d height): JAX 'reverse_causal' padding, i.e. (0, kernel_size - 1).
            pad_t_top = 0
            pad_t_bottom = kernel_h - 1

            # Frequency (Conv2d width): fixed (1, 1).
            # This matches the JAX effective padding for F_in=10, K_w=3, S_w=2. Other kernel/stride/input
            # widths may need re-evaluation to match generic JAX 'SAME' behavior.
            pad_f_left = 1
            pad_f_right = 1

            # Order expected by F.pad on [..., T, F]: (F_left, F_right, T_top, T_bottom).
            calculated_block_padding.append((pad_f_left, pad_f_right, pad_t_top, pad_t_bottom))

            # Output frequency width of this conv with the padding above (dilation 1).
            f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
            f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1
            calculated_f_out_dims.append(f_out_after_conv)
            current_f_for_block_input = f_out_after_conv

        self.conv_0 = Gemma3nAudioSSCPConvBlock(
            idx=0,
            input_freq_dim=config.input_feat_size,  # original feature dim
            config=config,
            manual_padding=calculated_block_padding[0],
        )
        self.conv_1 = Gemma3nAudioSSCPConvBlock(
            idx=1,
            input_freq_dim=calculated_f_out_dims[0],  # output freq dim of conv_0
            config=config,
            manual_padding=calculated_block_padding[1],
        )

        # Take the channel count from the stage that actually produces the features.
        # sscp_conv_channel_size[-1] only coincides with it when the config lists exactly two stages; with a
        # longer list the Linear in_features would not match the flattened conv_1 output.
        final_c_out = self.conv_1.conv.out_channels
        final_f_out = calculated_f_out_dims[-1]
        self.input_proj_in_features = final_c_out * final_f_out
        self.input_proj_linear = nn.Linear(self.input_proj_in_features, self.config.hidden_size, bias=False)

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        """Projects [B, T, F_in] features to [B, T_out, hidden_size]."""
        # [B, T, F_in] -> [B, 1, T, F_in]: one input channel, time as height, frequency as width.
        x = self.conv_0(audio_encodings.unsqueeze(1))
        x = self.conv_1(x)  # [B, C_out, T_out, F_out]
        b, c_out, t_out, f_out = x.shape
        # Channels-last, then flatten (F_out, C_out) into one feature axis per time step.
        x_permuted = x.permute(0, 2, 3, 1).contiguous()
        output_flattened = x_permuted.view(b, t_out, f_out * c_out)
        return self.input_proj_linear(output_flattened)
class Gemma3nAudioConformerAttention(nn.Module):
    """Pre-norm self-attention sub-block of the audio Conformer, with a residual connection.

    1. The input is clamped to +/- ``gradient_clipping`` and RMS-normalized.
    2. It is attended, the heads are merged, and the result is projected by ``post``.
    3. The projection is clamped again, RMS-normalized, and added to the original (unclamped) input.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        self.post_in_features = hidden
        self.register_buffer("gradient_clipping", torch.tensor(config.gradient_clipping), persistent=False)
        self.pre_attn_norm = Gemma3nRMSNorm(hidden)
        self.attn = Gemma3nAudioAttention(config)
        self.post = nn.Linear(self.post_in_features, hidden, bias=False)
        self.post_norm = Gemma3nRMSNorm(hidden)

    def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
        residual = audio_encodings
        limit = self.gradient_clipping

        normed = self.pre_attn_norm(torch.clamp(audio_encodings, -limit, limit))
        # The attention output is [B, T, num_heads, head_dim].
        attended = self.attn(normed, audio_mel_mask)
        batch, steps, heads, head_dim = attended.shape
        # Merge the heads back into a single feature axis (num_heads * head_dim).
        merged = attended.reshape(batch, steps, heads * head_dim)

        projected = torch.clamp(self.post(merged), -limit, limit)
        return residual + self.post_norm(projected)
class Gemma3nAudioConformerFeedForward(nn.Module):
    """Macaron-style feed-forward sub-block of the audio Conformer.

    clamp -> RMSNorm -> Linear(D, 4D) -> SiLU -> Linear(4D, D) -> clamp -> RMSNorm. The result is scaled by
    ``config.conf_residual_weight`` and added to the input.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config
        self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
        self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
        self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
        self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
        self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
        # A non-persistent buffer (like `gradient_clipping`) instead of a bare tensor attribute, so the scale
        # follows the module through `.to(device)` and device offloading. Non-persistent buffers are not part
        # of the state_dict, so checkpoint loading is unaffected.
        self.register_buffer("post_layer_scale", torch.tensor(self.config.conf_residual_weight), persistent=False)

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        residual = audio_encodings
        audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
        audio_encodings = self.pre_layer_norm(audio_encodings)
        audio_encodings = self.ffw_layer_1(audio_encodings)  # D -> 4D
        audio_encodings = nn.functional.silu(audio_encodings)
        audio_encodings = self.ffw_layer_2(audio_encodings)  # 4D -> D
        audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
        audio_encodings = self.post_layer_norm(audio_encodings)
        return residual + (audio_encodings * self.post_layer_scale)
class Gemma3nAudioConformerLightConv1d(nn.Module):
    """Lightweight causal depthwise-convolution sub-block of the audio Conformer, with a residual connection.

    RMSNorm -> Linear(D, 2D) -> GLU -> left-padded depthwise Conv1d -> clamp -> RMSNorm -> SiLU ->
    Linear(D, D). The block input is then added back.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        kernel = config.conf_conv_kernel_size

        self.pre_layer_norm = Gemma3nRMSNorm(hidden, eps=config.rms_norm_eps)
        self.linear_start = nn.Linear(hidden, hidden * 2, bias=False)
        self.depthwise_conv1d = nn.Conv1d(
            in_channels=hidden,
            out_channels=hidden,
            kernel_size=kernel,
            stride=1,
            padding=0,  # causal padding is applied by hand in forward()
            groups=hidden,  # one filter per channel (depthwise)
            bias=False,
        )
        self.register_buffer("gradient_clipping", torch.tensor(config.gradient_clipping), persistent=False)
        self.conv_norm = Gemma3nRMSNorm(hidden, eps=config.rms_norm_eps)
        self.linear_end = nn.Linear(hidden, hidden, bias=False)
        # Padding only on the left by kernel - 1 keeps the time length and makes each output depend only on
        # the current and earlier steps.
        self.causal_padding = kernel - 1

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        residual = audio_encodings

        x = self.linear_start(self.pre_layer_norm(audio_encodings))
        x = torch.nn.functional.glu(x, dim=-1)  # gates 2D features down to D

        # Conv1d wants channels first: [B, T, D] -> [B, D, T]. Pad the time axis on the left only.
        x = F.pad(x.permute(0, 2, 1), (self.causal_padding, 0))
        x = self.depthwise_conv1d(x).permute(0, 2, 1)  # back to [B, T, D]

        limit = self.gradient_clipping
        x = torch.clamp(x, -limit, limit)
        x = nn.functional.silu(self.conv_norm(x))
        return self.linear_end(x) + residual
class Gemma3nAudioConformerBlock(nn.Module):
    """One audio Conformer layer.

    feed-forward -> attention -> light conv -> feed-forward, followed by a final clamp and RMSNorm.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config
        self.ffw_layer_start = Gemma3nAudioConformerFeedForward(config)
        self.attention = Gemma3nAudioConformerAttention(config)
        self.lconv1d = Gemma3nAudioConformerLightConv1d(config)
        self.ffw_layer_end = Gemma3nAudioConformerFeedForward(config)
        self.register_buffer("gradient_clipping", torch.tensor(config.gradient_clipping), persistent=False)
        self.norm = Gemma3nRMSNorm(config.hidden_size)

    def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
        x = self.ffw_layer_start(audio_encodings)
        x = self.attention(x, audio_mel_mask)

        # audio_mel_mask is True at invalid (padded) steps. Zero those steps before the convolution.
        keep = (~audio_mel_mask).unsqueeze(-1).to(x.dtype)
        x = self.lconv1d(x * keep)

        x = self.ffw_layer_end(x)
        limit = self.gradient_clipping
        return self.norm(torch.clamp(x, -limit, limit))
class Gemma3nAudioEncoder(PreTrainedModel):
    """An audio encoder based on the [Universal Speech Model](https://arxiv.org/abs/2303.01037) architecture."""

    config_class = Gemma3nAudioConfig
    main_input_name = "audio_mel"

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__(config)
        self.config = config
        self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
        self.conformer = nn.ModuleList(
            [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
        )

    def forward(
        self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
    ) -> tuple[torch.Tensor, torch.BoolTensor]:
        """Encodes a batch of MELs.

        Args:
            audio_mel: a torch.Tensor of shape [batch, num_frames, mel_bins]. The subsample projection adds the
                channel dimension itself.
            audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames]; True marks padded frames.

        Returns:
            audio_encodings: a torch.Tensor of shape [batch, num_frames_out, hidden_size].
                - num_frames_out is num_frames after convolutional subsampling, further strided by
                  ``config.conf_reduction_factor`` when that is > 1.
                - Padded positions are zeroed.
            audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames_out], aligned with audio_encodings.
        """
        audio_encodings = self.subsample_conv_projection(audio_mel)  # [B, T_sub, D]

        # Subsample the mask to T_sub by sampling, for each output step, the input frame at the start of its
        # receptive field.
        t_sub = audio_encodings.shape[1]
        time_stride_product = 1
        for stride_pair in self.config.sscp_conv_stride_size:
            time_stride_product *= stride_pair[0]

        indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
        indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1)  # keep indices inside the mask
        # torch.gather needs an index tensor of the same rank as the [B, T] mask.
        indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1)  # [B, T_sub]
        current_mask = torch.gather(audio_mel_mask, 1, indices)  # [B, T_sub]

        for block in self.conformer:
            audio_encodings = block(audio_encodings, current_mask)

        if self.config.conf_reduction_factor > 1:
            audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
            # Reduce the mask the same way so it stays aligned.
            current_mask = current_mask[:, :: self.config.conf_reduction_factor]

        audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
        return audio_encodings, current_mask
class Gemma3nTextScaledWordEmbedding(nn.Embedding):
    """
    nn.Embedding whose lookups are multiplied by a constant ``embed_scale``.

    The scale is kept in a non-persistent buffer and cast to the weight dtype before multiplying.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids: torch.Tensor):
        looked_up = super().forward(input_ids)
        scale = self.embed_scale.to(self.weight.dtype)
        return looked_up * scale
class Gemma3nTextLaurelBlock(nn.Module):
    """Learned Augmented Residual Layer.

    Adds an RMS-normalized low-rank projection (hidden_size -> laurel_rank -> hidden_size) of the input back
    onto the input.
    """

    def __init__(self, config: Gemma3nTextConfig):
        super().__init__()
        self.config = config
        hidden, rank = config.hidden_size, config.laurel_rank
        self.linear_left = nn.Linear(hidden, rank, bias=False)
        self.linear_right = nn.Linear(rank, hidden, bias=False)
        self.post_laurel_norm = Gemma3nRMSNorm(hidden, eps=config.rms_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        low_rank = self.linear_right(self.linear_left(hidden_states))
        return hidden_states + self.post_laurel_norm(low_rank)
class Gemma3nTextMLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x)).

    The intermediate size and activation sparsity are read per layer from the config lists at ``layer_idx``.
    When the sparsity is positive, the gate pre-activations first go through a Gaussian top-k cutoff.
    """

    def __init__(self, config: Gemma3nTextConfig, layer_idx: int = 0):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size[layer_idx]
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]
        self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate = self.gate_proj(hidden_states)
        if self.activation_sparsity > 0.0:
            gate = self._gaussian_topk(gate)
        gated = self.act_fn(gate) * self.up_proj(hidden_states)
        return self.down_proj(gated)

    def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
        """Returns relu(inputs - cutoff).

        The cutoff is computed along the last dim as mean + z * std (population std), where
        z = Phi^-1(activation_sparsity) is the standard-normal quantile.

        z mirrors jax.scipy.stats.norm.ppf; see
        https://docs.jax.dev/en/latest/_autosummary/jax.scipy.stats.norm.ppf.html and
        https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal
        """
        sparsity = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device)
        standard_normal = torch.distributions.normal.Normal(0, 1)
        z_score: torch.Tensor = standard_normal.icdf(sparsity).type(inputs.dtype)

        mean = torch.mean(inputs, dim=-1, keepdim=True)
        std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
        cutoff = mean + std * z_score
        return nn.functional.relu(inputs - cutoff)
class Gemma3nTextAltUp(nn.Module):
    """Alternating Updates (AltUp)

    The AltUp module wraps transformer layers. The `predict` step modifies the
    input to the transformer layer, and the `correct` step propagates the output
    of the transformer layer to the sparsely updated dimensions.

    See more in the research paper:

    https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
    """

    def __init__(self, config: Gemma3nTextConfig):
        super().__init__()
        self.config = config
        # Per-channel scale applied by `forward` / `scale_corrected_output`; zero-initialized here.
        self.correct_output_scale = nn.Parameter(torch.zeros(self.config.hidden_size))
        # Maps router modalities to one correction coefficient per altup input.
        self.correction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False)
        # Maps router modalities to an (altup_num_inputs x altup_num_inputs) mixing matrix per token.
        self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
        self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
        self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
        # Router inputs are multiplied by 1 / hidden_size after normalization.
        self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)

    def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
        """Returns tanh(modality_router(router_norm(x) / hidden_size)).

        The last dim of the result is altup_num_inputs. The tanh is evaluated in float32 and the result is
        cast back to x's dtype.
        """
        router_inputs = self.router_norm(x) * self.router_input_scale
        routed = self.modality_router(router_inputs)
        return torch.tanh(routed.float()).type_as(x)

    def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Predicts the output of a layer using a trainable map.

        Args:
            hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
                stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.

        Returns:
            A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
        """
        # Router weights come from the active altup input only: [batch_size, num_tokens, altup_num_inputs].
        modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])

        # During training the prediction weights are clamped in place to +/- altup_coef_clip.
        if self.training and self.config.altup_coef_clip is not None:
            self.prediction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)

        # Project to one (altup_num_inputs x altup_num_inputs) matrix per token, then transpose the last two dims
        # so that the matmul below gives the correct result.
        all_coefs: torch.Tensor = (
            self.prediction_coefs(modalities)
            .reshape(*modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs)
            .permute(0, 1, 3, 2)
        )

        # permute hidden_states to [batch_size, num_tokens, hidden_size, altup_num_inputs] and mix the altup inputs
        predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
        predictions = predictions.permute(3, 0, 1, 2)  # undo the permute
        predictions += hidden_states  # add the original input
        return predictions.contiguous().type_as(hidden_states)

    def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
        """Corrects the predictions relative to the activated output of the wrapped transformer layer.

        The innovation (activated minus the active prediction) is scaled per altup input by
        `correction_coefs(modalities) + 1` and added to every prediction.

        Args:
            predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
                stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
            activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.

        Returns:
            A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
            predictions relative to the activated input embeddings.
        """
        modalities = self.compute_router_modalities(activated)
        innovation = activated - predictions[self.config.altup_active_idx]  # (batch, num_tokens, hidden_size)
        innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1)  # Repeat on dim0 to match predictions

        # NOTE(review): unlike `predict`, this in-place clamp of the correction weights is not gated on
        # `self.training`, so it also mutates the weights during inference. Confirm this asymmetry is intended.
        if self.config.altup_coef_clip is not None:
            self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)

        # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
        # Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
        # and expand on dim1 for broadcastability
        all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
        all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)

        corrected = torch.mul(innovation, all_coefs)
        corrected += predictions  # add the original input
        return corrected.contiguous().type_as(activated)

    def forward(self, corrected: torch.Tensor) -> torch.Tensor:
        """
        This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
        (which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
        `scale_corrected_output`
        """
        # Multiply in the parameter's dtype, then return in the input's dtype.
        return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)

    def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
        """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
        return self.forward(corrected)
class Gemma3nTextRotaryEmbedding(nn.Module):
    """Produces RoPE cos/sin tables for given positions.

    Inverse frequencies come from ``ROPE_INIT_FUNCTIONS[rope_type]``. ``rope_type`` is taken from
    ``config.rope_scaling`` (key "rope_type", falling back to the legacy "type") and defaults to "default" when
    no rope scaling is configured.
    """

    def __init__(self, config: Gemma3nTextConfig, device=None):
        super().__init__()
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            self.rope_type = "default"
        else:
            # BC: "rope_type" was originally "type"
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        batch = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is disabled so the angles are computed in float32. "mps" (and any non-string device type)
        # uses "cpu" as the autocast device type.
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
    """Splits the last dim at len // 2 into (first, second) and returns cat(-second, first) along it."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((second.neg(), first), dim=-1)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Same result as torch.repeat_interleave(x, dim=1, repeats=n_rep).

    Expands (batch, num_key_value_heads, seqlen, head_dim) into
    (batch, num_key_value_heads * n_rep, seqlen, head_dim), repeating each key/value head n_rep times
    consecutively. With n_rep == 1 the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference (non-fused) scaled dot-product attention.

    Steps:
    - Keys and values are repeated to the query head count via ``module.num_key_value_groups``.
    - Scores are scaled by ``scaling`` (default ``module.head_dim ** -0.5``).
    - If ``softcap`` is set, scores are soft-capped as ``softcap * tanh(scores / softcap)``.
    - The additive ``attention_mask`` is sliced to the key length and added.
    - Scores are softmaxed in float32, then dropout is applied according to ``module.training``.

    Returns:
        (attn_output, attn_weights): the attention output with dims 1 and 2 swapped (made contiguous), and the
        post-dropout attention weights.
    """
    scale = module.head_dim**-0.5 if scaling is None else scaling
    keys = repeat_kv(key, module.num_key_value_groups)
    values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scale
    if softcap is not None:
        scores = torch.tanh(scores / softcap) * softcap
    if attention_mask is not None:
        # The mask may cover more key positions than are present; only the first key_len columns apply.
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    output = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return output, probs
def apply_rotary_pos_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    unsqueeze_dim: int = 1,
):
    """Applies Rotary Position Embedding to a single tensor (query or key).

    Args:
        x (`torch.Tensor`): The tensor to embed.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which `cos` and `sin` are unsqueezed so they broadcast against `x`.
            - If `cos`/`sin` are [batch_size, seq_len, head_dim] and `x` is
              [batch_size, heads, seq_len, head_dim], use unsqueeze_dim=1.
            - If `x` is [batch_size, seq_len, heads, head_dim], use unsqueeze_dim=2.

    Returns:
        `torch.Tensor`: `x` rotated with the Rotary Position Embedding, i.e.
        `x * cos + rotate_half(x) * sin`. Call once per tensor to rotate both queries and keys.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)
class Gemma3nTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with Gemma3n specifics.

    - q, k and v are RMS-normalized per head (``v_norm`` is built with ``with_scale=False``), and attention is
      called with ``scaling=1.0``.
    - Layers typed "sliding_attention" in ``config.layer_types`` use ``config.sliding_window``.
    - The last ``config.num_kv_shared_layers`` layers do not project keys/values. When a cache is given, they
      read them from the cache of the last earlier layer of the same type.
    """

    def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
        super().__init__()
        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        self.config = config
        self.layer_idx = layer_idx
        # Resolved once. The projections, the per-head norms and the reshapes in forward() all use
        # self.head_dim so they agree, even when the config has no explicit `head_dim` and the
        # hidden_size // num_attention_heads fallback applies.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.sliding_window = config.sliding_window if self.is_sliding else None

        self.q_norm = Gemma3nRMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3nRMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.v_norm = Gemma3nRMSNorm(dim=self.head_dim, eps=config.rms_norm_eps, with_scale=False)

        # Layers with index >= first_kv_shared_layer_idx (when that index is > 0) reuse earlier keys/values.
        first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
        # Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
        layer_type = config.layer_types[layer_idx]
        self.kv_shared_layer_index = (
            first_kv_shared_layer_idx - 1 - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
            if self.is_kv_shared_layer
            else None
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Returns (attn_output, attn_weights) as produced by the selected attention implementation."""
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        cos, sin = position_embeddings

        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)

        if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
            # Device of past layer may be different from current one
            indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
            # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
            if isinstance(past_key_value, HybridCache) and self.is_sliding:
                max_length = past_key_value.sliding_window
                indices = (
                    slice(0, max_length)
                    if cache_position.shape[0] > max_length
                    else cache_position.clamp(min=0, max=max_length - 1)
                )

            # Device of past layer may be different from current one
            key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
            value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
                query_states.device
            )
        else:
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            key_states = self.k_norm(key_states)
            key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
            key_states = key_states.transpose(1, 2)

            value_states = self.v_proj(hidden_states).view(hidden_shape)
            value_states = self.v_norm(value_states)
            value_states = value_states.transpose(1, 2)

            if past_key_value is not None:
                # sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {
                    "sin": sin,
                    "cos": cos,
                    "cache_position": cache_position,
                    "sliding_window": self.sliding_window,
                }
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=1.0,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
    """One Gemma 3n text decoder layer: AltUp predict/correct around attention + MLP, plus a per-layer input gate."""

    def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.attention_type = config.layer_types[layer_idx]
        # Submodules are created in this exact order; it fixes parameter registration and init RNG order.
        self.self_attn = Gemma3nTextAttention(config, layer_idx)
        self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx)
        self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)

        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        self.act_fn = ACT2FN[config.hidden_activation]

        self.altup = Gemma3nTextAltUp(config)
        self.laurel = Gemma3nTextLaurelBlock(config)
        # Down-projection to the per-layer-input width and back up to the hidden size.
        self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
        self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
        self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)

    @deprecate_kwarg("last_cache_position", version="4.53.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        per_layer_input: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply the layer to the stacked AltUp streams in ``hidden_states``.

        Returns a tuple whose first element is the corrected AltUp predictions; when ``output_attentions`` is
        truthy the attention weights returned by ``self_attn`` are appended as the second element.
        """
        active_idx = self.config.altup_active_idx

        # AltUp: predict all streams, then run the transformer block on the active stream only.
        predictions = self.altup.predict(hidden_states)
        active = predictions[active_idx]
        active_normed = self.input_layernorm(active)
        laurel_out = self.laurel(active_normed)

        # Sliding-window layers use the local RoPE table; full-attention layers use the global one.
        rope = position_embeddings_local if self.self_attn.is_sliding else position_embeddings_global

        attn_out, attn_weights = self.self_attn(
            hidden_states=active_normed,
            position_embeddings=rope,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        attn_out = self.post_attention_layernorm(attn_out)

        # Residual add, then merge with the LAuReL branch scaled by 1/sqrt(2).
        residual = active + attn_out
        residual = (residual + laurel_out) / math.sqrt(2)

        ffw_out = self.post_feedforward_layernorm(self.mlp(self.pre_feedforward_layernorm(residual)))
        block_out = residual + ffw_out

        corrected = self.altup.correct(predictions, block_out)

        # Clone so the gate path does not alias the active stream of `corrected`.
        gate_input = corrected[active_idx].clone()
        if self.config.altup_correct_scale:
            gate_input = self.altup.scale_corrected_output(gate_input)

        # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
        gated = torch.multiply(self.act_fn(self.per_layer_input_gate(gate_input)), per_layer_input)
        # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
        per_layer_update = self.post_per_layer_input_norm(self.per_layer_projection(gated))
        # Only the non-first streams receive the per-layer update.
        corrected[1:] += per_layer_update

        if output_attentions:
            return (corrected, attn_weights)
        return (corrected,)
@auto_docstring
class Gemma3nPreTrainedModel(PreTrainedModel):
    config_class = Gemma3nConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    _no_split_modules = ["Gemma3nTextDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize ``module``'s parameters in place.

        This port targets inference and fine-tuning rather than pre-training from scratch, so only a simple scheme
        is applied: normal(0, std) for linear/conv/embedding weights, zero biases and padding rows, ones for norm
        weights, and zeros for the audio ``per_dim_scale`` and AltUp ``correct_output_scale`` parameters.
        """
        # `initializer_range` from the top-level config, falling back to the text sub-config.
        # Note: the fallback expression is evaluated eagerly as getattr's default argument.
        init_std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, Gemma3nRMSNorm):
            # Scale-free norms have no weight to initialize.
            if module.with_scale:
                module.weight.data.fill_(1.0)
            return

        if isinstance(module, Gemma3nAudioCumulativeGroupNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, Gemma3nAudioAttention):
            module.per_dim_scale.data.zero_()
        elif isinstance(module, Gemma3nTextAltUp):
            module.correct_output_scale.data.zero_()
@auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
class Gemma3nTextModel(Gemma3nPreTrainedModel):
    config_class = Gemma3nTextConfig

    def __init__(self, config: Gemma3nTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Gemma3n downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
        self.embed_tokens = Gemma3nTextScaledWordEmbedding(
            config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
        )
        self.layers = nn.ModuleList(
            [Gemma3nTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma3nTextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # TODO (raushan): Fix this after RoPE refactor. For now we hack it by
        # reassigning thetas when we want to create a local RoPE layer. Config
        # defaults should hold values for global RoPE.
        # The modified copy gets its own name instead of shadowing `config`, so the rest of
        # __init__ keeps reading the caller's config (the fields used below are identical).
        local_rope_config = copy.deepcopy(config)
        local_rope_config.rope_theta = local_rope_config.rope_local_base_freq
        local_rope_config.rope_scaling = {"rope_type": "default"}
        self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=local_rope_config)

        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input

        # One flat embedding row per token holds the per-layer inputs for every decoder layer.
        self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
            config.vocab_size_per_layer_input,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            self.padding_idx,
            embed_scale=config.hidden_size_per_layer_input**0.5,
        )

        self.per_layer_model_projection = nn.Linear(
            self.hidden_size,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            bias=False,
        )

        self.per_layer_projection_norm = Gemma3nRMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)

        # One projection per extra AltUp stream (streams 1..altup_num_inputs-1), in and out.
        self.altup_projections = nn.ModuleList(
            [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
        )

        self.altup_unembed_projections = nn.ModuleList(
            [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
        )

        self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
        self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        per_layer_inputs: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        r"""
        per_layer_inputs (torch.Tensor, *optional*, defaults to None):
            Pre-computed per-layer embeddings. When `input_ids` is provided they are recomputed from `input_ids`
            (any value passed here is ignored); when only `inputs_embeds` is given and this is None, only the
            projection of `inputs_embeds` is used as the per-layer input.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)
            per_layer_inputs = self.get_per_layer_inputs(input_ids)

        per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)

        if use_cache and past_key_values is None and not self.training:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states_0 = inputs_embeds

        # RoPE tables: global for full-attention layers, local for sliding-window layers.
        position_embeddings_global = self.rotary_emb(hidden_states_0, position_ids)
        position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)

        # Expand the embeddings into `altup_num_inputs` streams; each extra stream is rescaled to the
        # RMS (over the last dim) of the first stream.
        target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
        epsilon_tensor = torch.tensor(1e-5)
        temp_hidden_states = [hidden_states_0]
        for i in range(1, self.config.altup_num_inputs):
            # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
            altup_proj = self.altup_projections[i - 1](hidden_states_0)
            current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
            temp_hidden_states.append(
                self._rescale_to_magnitude(current_hidden_state, target_magnitude, epsilon_tensor)
            )

        hidden_states = torch.stack(temp_hidden_states, dim=0)  # [num_altup_inputs, batch, seq_len, hidden_size]

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            causal_mask = causal_mask_mapping[decoder_layer.attention_type]
            per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]

            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings_global,
                position_embeddings_local,
                per_layer_input,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Collapse the AltUp streams back to one: unembed streams 1.., rescale them to stream 0's RMS, then average.
        target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
        temp_hidden_states = [hidden_states[0]]
        for i in range(1, self.config.altup_num_inputs):
            # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
            altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
            current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
            temp_hidden_states.append(
                self._rescale_to_magnitude(current_hidden_state, target_magnitude, epsilon_tensor)
            )

        hidden_states = torch.stack(temp_hidden_states)
        hidden_states = torch.mean(hidden_states, dim=0)
        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @staticmethod
    def _rescale_to_magnitude(
        hidden_state: torch.Tensor, target_magnitude: torch.Tensor, epsilon_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Rescale ``hidden_state`` so its RMS over the last dim equals ``target_magnitude``.

        The mean square is floored at ``epsilon_tensor`` (moved to ``target_magnitude``'s device) before the
        square root, so an all-zero state does not divide by zero.
        """
        mean_square = torch.mean(hidden_state**2, dim=-1, keepdim=True)
        new_magnitude = torch.sqrt(torch.maximum(mean_square, epsilon_tensor.to(target_magnitude.device)))
        return hidden_state * target_magnitude / new_magnitude

    def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Look up per-layer embeddings and reshape them to ``(*input_ids.shape, num_hidden_layers, hidden_size_per_layer_input)``."""
        return self.embed_tokens_per_layer(input_ids).reshape(
            *input_ids.shape,
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Project ``inputs_embeds`` to per-layer inputs and optionally combine them with ``per_layer_inputs``.

        The projection is scaled by ``hidden_size**-0.5``, split per layer and normalized. If ``per_layer_inputs``
        is None the projection alone is returned; otherwise ``(projection + per_layer_inputs) * 2**-0.5``.
        """
        per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
        per_layer_projection *= self.per_layer_projection_scale.to(
            dtype=inputs_embeds.dtype, device=per_layer_projection.device
        )
        per_layer_projection = per_layer_projection.reshape(
            *inputs_embeds.shape[:-1],
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(per_layer_projection)

        if per_layer_inputs is None:
            return per_layer_projection

        if per_layer_projection.shape != per_layer_inputs.shape:
            # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
            per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]

        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
            dtype=inputs_embeds.dtype, device=per_layer_projection.device
        )
@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config_class = Gemma3nTextConfig
    base_model_prefix = "model"
    # Lets text-only loading pick up the language model weights stored under `model.language_model`.
    _checkpoint_conversion_mapping = {"model.language_model": "model"}

    def __init__(self, config: Gemma3nTextConfig):
        super().__init__(config)
        self.model = Gemma3nTextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **loss_kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma3nForCausalLM

        >>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-3n-E4B")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-3n-E4B")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""

        if self.training and self.config._attn_implementation != "eager":
            logger.warning_once(
                "It is strongly recommended to train Gemma3n models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Run the base text model; it returns a BaseModelOutputWithPast.
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **loss_kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # Final logit soft-capping: cap * tanh(logits / cap).
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class Gemma3nMultimodalEmbedder(nn.Module):
    """Embeds token ids or soft tokens for multimodal content into language model space."""

    def __init__(
        self,
        multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
        text_config: Gemma3nTextConfig,
    ):
        super().__init__()

        # Sizes and vocabulary bookkeeping read from the two configs.
        self.multimodal_hidden_size = multimodal_config.hidden_size
        self.eps = multimodal_config.rms_norm_eps
        self.vocab_offset = multimodal_config.vocab_offset
        self.vocab_size = multimodal_config.vocab_size
        self.text_hidden_size = text_config.hidden_size

        # Submodules, created in a fixed order (registration / init order).
        self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size)
        self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
        self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
        self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
        self.embedding_post_projection_norm = Gemma3nRMSNorm(self.text_hidden_size, eps=self.eps, with_scale=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embeds token ids or soft tokens for multimodal content into language model space.

        Args:
            input_ids: Hard token ids to look up; values should lie in
                `[vocab_offset, vocab_offset + vocab_size)` because `vocab_offset` is subtracted before the lookup.
            inputs_embeds: Soft tokens whose last dimension is `multimodal_hidden_size`.

        Returns:
            A torch.Tensor whose last dimension is `text_hidden_size` (the text config's `hidden_size`).

        Raises:
            ValueError: if neither or both of `input_ids` and `inputs_embeds` are given.
        """
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            # Hard tokens: shift ids into this embedder's local vocabulary, then normalize.
            normalized = self.hard_embedding_norm(self.embedding(input_ids - self.vocab_offset))
        else:
            # Soft tokens: already embeddings, only normalize.
            normalized = self.soft_embedding_norm(inputs_embeds)

        projected = self.embedding_projection(normalized)
        return self.embedding_post_projection_norm(projected)
| @auto_docstring( | |
| custom_intro=""" | |
| The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a | |
| language modeling head. | |
| """ | |
| ) | |
| class Gemma3nModel(Gemma3nPreTrainedModel): | |
| _checkpoint_conversion_mapping = {} | |
| # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch | |
| accepts_loss_kwargs = False | |
def __init__(self, config: Gemma3nConfig):
    """Build the vision, language and audio backbones plus the two multimodal embedders."""
    super().__init__(config)
    # Backbones are instantiated through AutoModel from their sub-configs (construction order preserved).
    self.vision_tower = AutoModel.from_config(config=config.vision_config)
    self.vocab_size = config.text_config.vocab_size

    self.language_model = AutoModel.from_config(config=config.text_config)

    # -1 is used as the pad id when the config defines none.
    self.pad_token_id = -1 if self.config.pad_token_id is None else self.config.pad_token_id
    self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input

    self.audio_tower = AutoModel.from_config(config.audio_config)
    self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
    self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)

    self.post_init()
def get_input_embeddings(self):
    """Return the token-embedding module of the wrapped language model."""
    lm = self.language_model
    return lm.get_input_embeddings()

def set_input_embeddings(self, value):
    """Replace the token-embedding module of the wrapped language model."""
    lm = self.language_model
    lm.set_input_embeddings(value)

def set_decoder(self, decoder):
    """Swap in a different language model (decoder)."""
    self.language_model = decoder

def get_decoder(self):
    """Return the wrapped language model (decoder)."""
    return self.language_model
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
    """
    Projects the last hidden state from the vision model into language model space.

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
            The tensors corresponding to the input images.

    Returns:
        image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
    """
    vision_outputs = self.vision_tower(
        pixel_values=pixel_values, do_pooling=False, return_dict=True
    ).last_hidden_state
    # Convert from (batch, channels, height, width) to (batch, height * width, channels) where:
    # height == width and height * width == Gemma3nConfig.vision_soft_tokens_per_image.
    vision_outputs = vision_outputs.reshape(
        vision_outputs.shape[0],
        self.config.vision_config.hidden_size,
        self.config.vision_soft_tokens_per_image,
    ).permute(0, 2, 1)
    # Normalize and embed the soft tokens into language model space.
    # Scale out of place: `vision_outputs` is a view of the tower's output, so an in-place multiply would
    # mutate that tensor and can break autograd if the tower's last op saved its output for backward.
    vision_outputs = vision_outputs * self.config.vision_config.hidden_size**0.5
    return self.embed_vision(inputs_embeds=vision_outputs)
| @can_return_tuple | |
| def forward( | |
| self, | |
| input_ids: Optional[torch.LongTensor] = None, # text inputs | |
| pixel_values: Optional[torch.FloatTensor] = None, # vision inputs | |
| input_features: Optional[torch.FloatTensor] = None, # audio inputs | |
| attention_mask: Optional[torch.Tensor] = None, | |
| input_features_mask: Optional[torch.Tensor] = None, | |
| position_ids: Optional[torch.LongTensor] = None, | |
| past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, | |
| token_type_ids: Optional[torch.LongTensor] = None, | |
| cache_position: Optional[torch.LongTensor] = None, | |
| inputs_embeds: Optional[torch.FloatTensor] = None, | |
| labels: Optional[torch.LongTensor] = None, | |
| use_cache: Optional[bool] = None, | |
| output_attentions: Optional[bool] = None, | |
| output_hidden_states: Optional[bool] = None, | |
| **lm_kwargs, | |
| ) -> Gemma3nCausalLMOutputWithPast: | |
| r""" | |
| labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |
| Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., | |
| config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored | |
| (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. | |
| Example: | |
| ```python | |
| >>> from PIL import Image | |
| >>> import requests | |
| >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration | |
| >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma3n2-3b-mix-224") | |
| >>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224") | |
| >>> prompt = "Where is the cat standing?" | |
| >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" | |
| >>> image = Image.open(requests.get(url, stream=True).raw) | |
| >>> inputs = processor(images=image, text=prompt, return_tensors="pt") | |
| >>> # Generate | |
| >>> generate_ids = model.generate(**inputs,) | |
| >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] | |
| "Where is the cat standing?\nsnow" | |
| ``` | |
| """ | |
| if (input_ids is None) ^ (inputs_embeds is not None): | |
| raise ValueError("You must specify exactly one of input_ids or inputs_embeds") | |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
| output_hidden_states = ( | |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
| ) | |
| if input_ids is not None: | |
| inputs_embeds = self.get_input_embeddings()(input_ids) | |
| # Prepare per-layer inputs from inputs_ids | |
| per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input) | |
| per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)) | |
| per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens) | |
| # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset) | |
| vision_mask = torch.logical_and( | |
| input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset | |
| ) | |
| dummy_vision_token_id = self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1 | |
| vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device) | |
| vision_embeds = self.embed_vision(input_ids=vision_input_ids) | |
| expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds) | |
| inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds) | |
| # Handle audio tokens (>= embed_audio.vocab_offset) | |
| audio_mask = input_ids >= self.embed_audio.vocab_offset | |
| dummy_audio_token_id = self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1 | |
| audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device) | |
| audio_embeds = self.embed_audio(input_ids=audio_input_ids) | |
| expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds) | |
| inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds) | |
| else: | |
| per_layer_inputs = None | |
| # Merge text and images | |
| if pixel_values is not None: | |
| image_features = self.get_image_features(pixel_values) | |
| if input_ids is None: | |
| special_image_mask = inputs_embeds == self.get_input_embeddings()( | |
| torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) | |
| ) | |
| else: | |
| special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) | |
| special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) | |
| if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): | |
| image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] | |
| raise ValueError( | |
| f"Number of images does not match number of special image tokens in the input text. " | |
| f"Got {image_tokens_in_text} image tokens in the text and " | |
| f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings." | |
| ) | |
| image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) | |
| inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) | |
| # Merge text and audio | |
| if input_features is not None and input_features_mask is not None: | |
| audio_features, audio_mask = self.get_audio_features(input_features, ~input_features_mask) | |
| # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the | |
| # text to account for this. However, the audio preprocessing and encoder do not gurarantee they will | |
| # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens | |
| # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad | |
| # the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab. | |
| audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device) | |
| audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks) | |
| audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features) | |
| audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape | |
| extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len | |
| extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim) | |
| audio_features = torch.cat((audio_features, extra_padding_features), dim=1) | |
| if input_ids is None: | |
| special_audio_mask = inputs_embeds == self.embed_audio( | |
| input_ids=torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) | |
| ) | |
| else: | |
| special_audio_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) | |
| special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device) | |
| if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel(): | |
| audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0] | |
| raise ValueError( | |
| f"Number of audio input features does not match number of special audio tokens in the input text. " | |
| f"Got {audio_tokens_in_text} audio tokens in the text and " | |
| f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings." | |
| ) | |
| audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) | |
| inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) | |
| outputs = self.language_model( | |
| input_ids=None, | |
| per_layer_inputs=per_layer_inputs, | |
| attention_mask=attention_mask, | |
| position_ids=position_ids, | |
| past_key_values=past_key_values, | |
| inputs_embeds=inputs_embeds, | |
| use_cache=use_cache, | |
| output_attentions=output_attentions, | |
| output_hidden_states=output_hidden_states, | |
| return_dict=True, | |
| cache_position=cache_position, | |
| **lm_kwargs, | |
| ) | |
| return Gemma3nModelOutputWithPast( | |
| last_hidden_state=outputs.last_hidden_state, | |
| past_key_values=outputs.past_key_values if use_cache else None, | |
| hidden_states=outputs.hidden_states, | |
| attentions=outputs.attentions, | |
| image_hidden_states=image_features if pixel_values is not None else None, | |
| audio_hidden_states=audio_features if input_features is not None else None, | |
| ) | |
def get_audio_features(
    self, input_features: torch.Tensor, input_features_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Projects the last hidden state from the audio encoder into language model space.

    Args:
        input_features (`torch.FloatTensor` of shape `(num_audios, seq_length, num_features)`):
            The tensors corresponding to the input audio.
        input_features_mask (`torch.Tensor` of shape `(num_audios, seq_length)`):
            The mask for the input audio, passed unchanged to `self.audio_tower`. Note that `Gemma3nModel.forward`
            calls this method with the logical inverse of the user-facing mask (`~input_features_mask`).

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: A pair `(audio_features, audio_mask)`:
            - `audio_features`: the audio tower output passed through `self.embed_audio(inputs_embeds=...)`,
              of shape `(num_audios, audio_length, embed_dim)`.
            - `audio_mask`: the mask returned by `self.audio_tower` alongside its output. `Gemma3nModel.forward`
              treats `True` entries as positions to overwrite with padding embeddings.
    """
    audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
    # Project encoder outputs into the language model's embedding space via the audio embedder.
    return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
@auto_docstring(
    custom_intro="""
    The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
    head.
    """
)
class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
    # Wraps `Gemma3nModel` (vision + audio + text) and adds a linear language-modeling head on top of it.
    _checkpoint_conversion_mapping = {}
    _tied_weights_keys = ["lm_head.weight"]
    base_model_prefix = "model"

    def __init__(self, config: Gemma3nConfig):
        super().__init__(config)
        self.model = Gemma3nModel(config)
        # Maps text hidden states to vocabulary logits (no bias).
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Returns the input embedding module of the wrapped `Gemma3nModel`."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replaces the input embedding module of the wrapped `Gemma3nModel`."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Returns the language-modeling head (`lm_head`)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replaces the language-modeling head (`lm_head`)."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Delegates to `Gemma3nModel.set_decoder`."""
        self.model.set_decoder(decoder)

    def get_decoder(self):
        """Delegates to `Gemma3nModel.get_decoder`."""
        return self.model.get_decoder()

    def get_image_features(self, pixel_values):
        """Delegates to `Gemma3nModel.get_image_features`."""
        return self.model.get_image_features(pixel_values)

    # Make modules available through the conditional-generation class for backward compatibility (BC).
    @property
    def language_model(self):
        return self.model.language_model

    @property
    def vision_tower(self):
        return self.model.vision_tower

    @property
    def multi_modal_projector(self):
        # Gemma 3n exposes `embed_vision` in place of a `multi_modal_projector`; fail loudly for old call sites.
        raise AttributeError("Use embed_vision instead of multi_modal_projector.")

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,  # text inputs
        pixel_values: Optional[torch.FloatTensor] = None,  # vision inputs
        input_features: Optional[torch.FloatTensor] = None,  # audio inputs
        attention_mask: Optional[torch.Tensor] = None,
        input_features_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **lm_kwargs,
    ) -> Gemma3nCausalLMOutputWithPast:
        r"""
        input_features (torch.Tensor, *optional*, defaults to None):
            The audio inputs to be encoded.
        input_features_mask (torch.Tensor, *optional*, defaults to None):
            The attention mask for the input audio.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the causal language modeling loss (labels are shifted by one position internally, so
            token `n` is predicted from tokens `< n`). Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in
            `[0, ..., config.text_config.vocab_size]`. When `attention_mask` is given, positions where it is 0 are
            also excluded from the loss.

        Example:

        ```python
        >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration

        >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-E4B-it")
        >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-E4B-it")

        >>> messages = [
        ...     {
        ...         "role": "system",
        ...         "content": [
        ...             {"type": "text", "text": "You are a helpful assistant."}
        ...         ]
        ...     },
        ...     {
        ...         "role": "user", "content": [
        ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
        ...             {"type": "text", "text": "Where is the cat standing?"},
        ...         ]
        ...     },
        ... ]

        >>> inputs = processor.apply_chat_template(
        ...     messages,
        ...     tokenize=True,
        ...     return_dict=True,
        ...     return_tensors="pt",
        ...     add_generation_prompt=True,
        ... )

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=32)
        >>> output_text = processor.batch_decode(
        ...     generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        ... )[0]
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            input_features=input_features,
            attention_mask=attention_mask,
            input_features_mask=input_features_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            token_type_ids=token_type_ids,
            cache_position=cache_position,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **lm_kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # Optional soft-capping: logits <- cap * tanh(logits / cap)
        if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
            logits = logits / final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * final_logit_softcapping

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that the logits at position i are scored against the label at position i + 1.
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :]
                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            # Flatten the tokens. Use the actual logits width rather than the config value so a replaced
            # `lm_head` (see `set_output_embeddings`) with a different vocabulary size still works.
            loss_fct = nn.CrossEntropyLoss()
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

        return Gemma3nCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
            audio_hidden_states=outputs.audio_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        pixel_values=None,
        input_features=None,
        attention_mask=None,
        input_features_mask=None,
        token_type_ids=None,
        use_cache=True,
        logits_to_keep=None,
        labels=None,
        **kwargs,
    ):
        """
        Builds the model inputs for one generation step.

        Delegates to `GenerationMixin.prepare_inputs_for_generation`, then attaches the multimodal inputs
        (`pixel_values`, `input_features`, `input_features_mask`) only on the first step (`cache_position[0] == 0`).
        `labels` is captured as an explicit parameter and deliberately not passed on.
        """
        # Overwritten -- custom `position_ids` and `pixel_values` handling
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            **kwargs,
        )

        # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
        # tokens anymore. Otherwise multimodal inputs should be passed to model.
        # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["input_features"] = input_features
            model_inputs["input_features_mask"] = input_features_mask

        return model_inputs

    @property
    def audio_tower(self):
        return self.model.audio_tower
# Public names of this module (what `from <module> import *` exposes), kept in alphabetical order.
__all__ = [
    "Gemma3nAudioEncoder", "Gemma3nForCausalLM", "Gemma3nForConditionalGeneration",
    "Gemma3nModel", "Gemma3nPreTrainedModel", "Gemma3nTextModel",
]
| </script> | |
| <script id="dependencies" type="text/plain"> | |
| # coding=utf-8 | |
| # Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. | |
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import copy | |
| import inspect | |
| import os | |
| import warnings | |
| from dataclasses import dataclass | |
| from typing import TYPE_CHECKING, Any, Callable, Optional, Union | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
| from huggingface_hub import file_exists | |
| from packaging import version | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from ..cache_utils import ( | |
| Cache, | |
| DynamicCache, | |
| EncoderDecoderCache, | |
| HybridChunkedCache, | |
| OffloadedCache, | |
| OffloadedHybridCache, | |
| QuantizedCacheConfig, | |
| ) | |
| from ..configuration_utils import PretrainedConfig | |
| from ..dynamic_module_utils import ( | |
| check_python_requirements, | |
| get_cached_module_file, | |
| get_class_in_module, | |
| resolve_trust_remote_code, | |
| ) | |
| from ..integrations.deepspeed import is_deepspeed_zero3_enabled | |
| from ..integrations.fsdp import is_fsdp_managed_module | |
| from ..masking_utils import create_masks_for_generate | |
| from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput | |
| from ..pytorch_utils import isin_mps_friendly | |
| from ..tokenization_utils import ExtensionsTrie | |
| from ..utils import ( | |
| ModelOutput, | |
| is_accelerate_available, | |
| is_hqq_available, | |
| is_optimum_quanto_available, | |
| is_torchdynamo_exporting, | |
| logging, | |
| ) | |
| from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint | |
| from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer | |
| from .candidate_generator import ( | |
| AssistantVocabTranslatorCache, | |
| AssistedCandidateGenerator, | |
| AssistedCandidateGeneratorDifferentTokenizers, | |
| CandidateGenerator, | |
| EarlyExitCandidateGenerator, | |
| PromptLookupCandidateGenerator, | |
| UniversalSpeculativeDecodingGenerator, | |
| _crop_past_key_values, | |
| _prepare_attention_mask, | |
| _prepare_token_type_ids, | |
| ) | |
| from .configuration_utils import ( | |
| NEED_SETUP_CACHE_CLASSES_MAPPING, | |
| QUANT_BACKEND_CLASSES_MAPPING, | |
| CompileConfig, | |
| GenerationConfig, | |
| GenerationMode, | |
| ) | |
| from .continuous_batching import ContinuousMixin | |
| from .logits_process import ( | |
| EncoderNoRepeatNGramLogitsProcessor, | |
| EncoderRepetitionPenaltyLogitsProcessor, | |
| EpsilonLogitsWarper, | |
| EtaLogitsWarper, | |
| ExponentialDecayLengthPenalty, | |
| ForcedBOSTokenLogitsProcessor, | |
| ForcedEOSTokenLogitsProcessor, | |
| HammingDiversityLogitsProcessor, | |
| InfNanRemoveLogitsProcessor, | |
| LogitNormalization, | |
| LogitsProcessorList, | |
| MinLengthLogitsProcessor, | |
| MinNewTokensLengthLogitsProcessor, | |
| MinPLogitsWarper, | |
| NoBadWordsLogitsProcessor, | |
| NoRepeatNGramLogitsProcessor, | |
| PrefixConstrainedLogitsProcessor, | |
| RepetitionPenaltyLogitsProcessor, | |
| SequenceBiasLogitsProcessor, | |
| SuppressTokensAtBeginLogitsProcessor, | |
| SuppressTokensLogitsProcessor, | |
| TemperatureLogitsWarper, | |
| TopKLogitsWarper, | |
| TopPLogitsWarper, | |
| TypicalLogitsWarper, | |
| UnbatchedClassifierFreeGuidanceLogitsProcessor, | |
| ) | |
| from .stopping_criteria import ( | |
| ConfidenceCriteria, | |
| EosTokenCriteria, | |
| MaxLengthCriteria, | |
| MaxTimeCriteria, | |
| StoppingCriteria, | |
| StoppingCriteriaList, | |
| StopStringCriteria, | |
| ) | |
| if TYPE_CHECKING: | |
| from ..modeling_utils import PreTrainedModel | |
| from ..tokenization_utils_base import PreTrainedTokenizerBase | |
| from .streamers import BaseStreamer | |
# Module-level logger from the library's own `logging` utilities (imported from `..utils` above).
logger = logging.get_logger(__name__)
# `accelerate` is optional: its hook helpers (`AlignDevicesHook`, `add_hook_to_module`) are imported only when
# `is_accelerate_available()` returns True, so this module can still be imported without the package.
if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module
# Variable names used to hold the cache at generation time. Each name is paired below with the model family that
# uses it; only the names (in this exact order) end up in the list.
ALL_CACHE_NAMES = [
    cache_name
    for cache_name, _model_family in (
        ("past_key_values", "default"),
        ("cache_params", "mamba-based models"),
        ("state", "rwkv"),
        ("mems", "xlnet"),
        ("past_buckets_states", "reformer"),
    )
]
@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
    """
    Outputs of decoder-only generation models, when using non-beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor))` or [`~cache_utils.Cache`], *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    sequences: torch.LongTensor
    scores: Optional[tuple[torch.FloatTensor]] = None
    logits: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[tuple[tuple[tuple[torch.FloatTensor]]]] = None
@dataclass
class GenerateEncoderDecoderOutput(ModelOutput):
    """
    Outputs of encoder-decoder generation models, when using non-beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor))` or [`~cache_utils.Cache`], *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    sequences: torch.LongTensor
    scores: Optional[tuple[torch.FloatTensor]] = None
    logits: Optional[tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[tuple[tuple[tuple[torch.FloatTensor]]]] = None
@dataclass
class GenerateBeamDecoderOnlyOutput(ModelOutput):
    """
    Outputs of decoder-only generation models, when using beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, sequence_length)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor))` or [`~cache_utils.Cache`], *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    sequences: torch.LongTensor
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[tuple[torch.FloatTensor]] = None
    logits: Optional[tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[tuple[tuple[tuple[torch.FloatTensor]]]] = None
@dataclass
class GenerateBeamEncoderDecoderOutput(ModelOutput):
    """
    Outputs of encoder-decoder generation models, when using beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, sequence_length)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor))` or [`~cache_utils.Cache`], *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    sequences: torch.LongTensor
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[tuple[torch.FloatTensor]] = None
    logits: Optional[tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[tuple[tuple[tuple[torch.FloatTensor]]]] = None
# TODO (joao): remove the equivalent classes and typing shortcuts below in v5

# Equivalent classes, kept for backward compatibility: every legacy, strategy-specific output name is simply another
# name bound to one of the four unified `Generate*Output` classes.
GreedySearchDecoderOnlyOutput = ContrastiveSearchDecoderOnlyOutput = SampleDecoderOnlyOutput = (
    GenerateDecoderOnlyOutput
)
GreedySearchEncoderDecoderOutput = ContrastiveSearchEncoderDecoderOutput = SampleEncoderDecoderOutput = (
    GenerateEncoderDecoderOutput
)
BeamSearchDecoderOnlyOutput = BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput
BeamSearchEncoderDecoderOutput = BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput

# Legacy per-strategy unions: each covers the encoder-decoder and the decoder-only flavour of one strategy.
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]

# Typing shortcuts
GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput]
GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
| class GenerationMixin(ContinuousMixin): | |
| """ | |
| A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes. | |
| Inheriting from this class causes the model to have special generation-related behavior, such as loading a | |
| `GenerationConfig` at initialization time or ensuring `generate`-related tests are run in `transformers` CI. | |
| A model class should inherit from `GenerationMixin` to enable calling methods like `generate`, or when it | |
| has defined a custom `generate` method that relies on `GenerationMixin`, directly or indirectly, which | |
| approximately shares the same interface to public methods like `generate`. Three examples: | |
| - `LlamaForCausalLM` should inherit from `GenerationMixin` to enable calling `generate` and other public | |
| methods in the mixin; | |
| - `BlipForQuestionAnswering` has a custom `generate` method that approximately shares the same interface as | |
| `GenerationMixin.generate` (it has a few extra arguments, and the same output). That function also calls | |
| `GenerationMixin.generate` indirectly, through an inner model. As such, `BlipForQuestionAnswering` should | |
| inherit from `GenerationMixin` to benefit from all generation-related automation in our codebase; | |
| - `BarkModel` has a custom `generate` method and one of its inner models calls `GenerationMixin.generate`. | |
| However, its `generate` does not share the same interface as `GenerationMixin.generate`. In this case, | |
| `BarkModel` should NOT inherit from `GenerationMixin`, as it breaks the `generate` interface. | |
| The class exposes [`~generation.GenerationMixin.generate`], which can be used for: | |
| - *greedy decoding* if `num_beams=1` and `do_sample=False` | |
| - *contrastive search* if `penalty_alpha>0` and `top_k>1` | |
| - *multinomial sampling* if `num_beams=1` and `do_sample=True` | |
| - *beam-search decoding* if `num_beams>1` and `do_sample=False` | |
| - *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True` | |
| - *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1` | |
| - *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None` | |
| - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()` | |
| To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). | |
| """ | |
def load_custom_generate(
    self,
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
    trust_remote_code: Optional[bool] = None,
    **kwargs,
) -> Callable:
    """
    Loads and returns a custom generate function, given a model repo.

    The location (a local directory or a Hub repo id) must contain `custom_generate/generate.py`, which must define
    a top-level `generate` function. Loading it executes that code, so it is gated by `trust_remote_code`.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Can be either:

                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing model weights saved using
                  [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
        trust_remote_code (`bool`, *optional*):
            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
            should only be set to `True` for repositories you trust and in which you have read the code, as it will
            execute code present on the Hub on your local machine.
        **kwargs:
            Additional keyword arguments for remote code loading. They are forwarded both to the requirements check
            and to the module retrieval step.

    Raises:
        ValueError: If `pretrained_model_name_or_path` is not provided.
        OSError: If `pretrained_model_name_or_path` does not contain a `custom_generate` subdirectory.

    Returns:
        A callable that can be used to generate text.
    """
    # The parameter only defaults to `None` for signature compatibility. Without a location there is nothing to
    # load, and `os.path.exists(None)` would otherwise fail with an unhelpful `TypeError`.
    if pretrained_model_name_or_path is None:
        raise ValueError("`pretrained_model_name_or_path` must be provided to load a custom generate function.")

    # Does `pretrained_model_name_or_path` have a `custom_generate` subdirectory? If not -> OSError
    generate_file = "custom_generate/generate.py"
    is_local_code = os.path.exists(pretrained_model_name_or_path)
    if is_local_code:
        has_custom_generate_folder = os.path.exists(os.path.join(pretrained_model_name_or_path, generate_file))
    else:
        has_custom_generate_folder = file_exists(pretrained_model_name_or_path, generate_file)
    if not has_custom_generate_folder:
        raise OSError(
            f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a "
            "`generate.py` file, can't load the custom generate function."
        )

    # Handle opt-in `trust_remote_code` and related exceptions
    error_message = (
        f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override "
        "the default `generate` method."
    )
    resolve_trust_remote_code(
        trust_remote_code,
        pretrained_model_name_or_path,
        has_local_code=is_local_code,
        has_remote_code=not is_local_code,
        error_message=error_message,
    )

    # Load the custom generate function
    check_python_requirements(
        pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs
    )
    module = get_cached_module_file(pretrained_model_name_or_path, module_file=generate_file, **kwargs)
    custom_generate_function = get_class_in_module("generate", module)
    return custom_generate_function
| def _cache_dependant_input_preparation( | |
| self, | |
| input_ids: torch.LongTensor, | |
| inputs_embeds: Optional[torch.FloatTensor], | |
| cache_position: Optional[torch.LongTensor], | |
| ) -> tuple[torch.FloatTensor, torch.LongTensor]: | |
| """ | |
| Generic cache-dependent input preparation | |
| The code is put in a separate function to allow granular unit testing | |
| as it needs a different implementation to be exportable. | |
| If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens | |
| - Exception 1: when passing input_embeds, input_ids may be missing entries | |
| - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here | |
| - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. | |
| - Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and | |
| generate the first token for each sequence. Later use the generated Input ids for continuation. | |
| The current implementation does not rely on ``self`` and could be | |
| a class method. It is left as a standard method to be easily rewritten. | |
| """ | |
| if is_torchdynamo_exporting(): | |
| return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position) | |
| if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 | |
| inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] | |
| elif ( | |
| inputs_embeds is not None # Exception 1 | |
| or (cache_position[-1] >= input_ids.shape[1]) # Exception 3 | |
| ): | |
| input_ids = input_ids[:, -cache_position.shape[0] :] | |
| elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) | |
| input_ids = input_ids[:, cache_position] | |
| return inputs_embeds, input_ids | |
def _cache_dependant_input_preparation_exporting(
    self,
    input_ids: torch.LongTensor,
    inputs_embeds: Optional[torch.FloatTensor],
    cache_position: Optional[torch.LongTensor],
) -> tuple[torch.FloatTensor, torch.LongTensor]:
    """
    This method implements method ``_cache_dependant_input_preparation``
    with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
    The code is put in a separate function to allow granular unit testing.

    Data-dependent Python `if`/`elif` branches are replaced by nested `torch.cond` calls, each receiving its
    operands as an explicit list.

    Returns:
        `tuple[torch.FloatTensor, torch.LongTensor]`: `(inputs_embeds, input_ids)`, at most one of which is sliced.
    """
    if inputs_embeds is None:
        # Without embeddings, always gather the positions selected by `cache_position`.
        # NOTE(review): unlike the eager `_cache_dependant_input_preparation`, this path has no fallback for
        # `cache_position[-1] >= input_ids.shape[1]` (eager "Exception 3") -- confirm this is intended for export.
        input_ids = input_ids[:, cache_position]
    else:
        # This is the code we need to implement with torch.cond.
        # if input_ids.shape[1] == 0:
        #     inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
        # else:
        #     if cache_position[-1] >= input_ids.shape[1]:
        #         input_ids = input_ids[:, -cache_position.shape[0] :]
        #     else:
        #         if input_ids.shape[1] != cache_position.shape[0]:
        #             input_ids = input_ids[:, cache_position]

        # Keep the last `len(cache_position)` embeddings (empty `input_ids` case).
        def branch_1(inputs_embeds, cache_position):
            return inputs_embeds[:, -cache_position.shape[0] :]

        # Keep the last `len(cache_position)` token ids (cache position beyond `input_ids`).
        def branch_2(input_ids, cache_position):
            return input_ids[:, -cache_position.shape[0] :]

        # Gather exactly the token ids selected by `cache_position`.
        def branch_3(input_ids, cache_position):
            return input_ids[:, cache_position]

        # Outer cond: is `input_ids` empty? Both branches return an `(inputs_embeds, input_ids)` pair.
        inputs_embeds, input_ids = torch.cond(
            input_ids.shape[1] == 0,
            (
                lambda input_ids, inputs_embeds, cache_position: (
                    branch_1(inputs_embeds, cache_position),
                    input_ids,
                )
            ),
            (
                lambda input_ids, inputs_embeds, cache_position: (
                    inputs_embeds,
                    # Inner conds pick how (and whether) `input_ids` is sliced; the innermost identity lambda is
                    # the no-op case where the lengths already match.
                    torch.cond(
                        cache_position[-1] >= input_ids.shape[1],
                        branch_2,
                        lambda input_ids, cache_position: (
                            torch.cond(
                                input_ids.shape[1] != cache_position.shape[0],
                                branch_3,
                                (lambda input_ids, cache_position: input_ids),
                                [input_ids, cache_position],
                            )
                        ),
                        [input_ids, cache_position],
                    ),
                )
            ),
            [input_ids, inputs_embeds, cache_position],
        )
    return inputs_embeds, input_ids
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[Cache] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """
    Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or
    slicing inputs given the existing cache.

    See the forward pass in the model documentation for expected arguments (different models might have different
    requirements for e.g. `past_key_values`). This function should work as is for most LLMs.

    Args:
        input_ids (`torch.LongTensor`):
            Token ids generated so far. When `past_key_values` is given, they are sliced to the unprocessed part
            through `cache_position`.
        past_key_values (`Cache`, *optional*):
            The model cache, forwarded as-is.
        attention_mask (`torch.LongTensor`, *optional*):
            For decoder-only models, the attention mask. For encoder-decoder models, the *encoder* attention mask
            (the decoder mask is taken from `kwargs["decoder_attention_mask"]`).
        inputs_embeds (`torch.FloatTensor`, *optional*):
            Decoder-only models only use them on the first generation step for each prompt.
        cache_position (`torch.LongTensor`, *optional*):
            Positions of the tokens to process. Built on the fly for models without `Cache` support.
        **kwargs:
            Extra model inputs. `position_ids`, `token_type_ids` and `decoder_position_ids` are sliced to the
            current input length; any other key not already set is forwarded unchanged.

    Returns:
        `dict`: Keyword arguments for the model's `forward`. `labels` is never included.
    """
    # 1. Handle BC:
    model_inputs = {}
    # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
    if self._supports_cache_class:
        model_inputs["cache_position"] = cache_position
    # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
    #   function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
    #   (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
    elif cache_position is None:
        past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

    # 2. Generic cache-dependent input preparation
    if past_key_values is not None:
        model_inputs["past_key_values"] = past_key_values
        inputs_embeds, input_ids = self._cache_dependant_input_preparation(
            input_ids, inputs_embeds, cache_position
        )

    # 3. Prepare base model inputs
    input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt.
    if not self.config.is_encoder_decoder:
        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
            model_inputs[input_ids_key] = None
            model_inputs["inputs_embeds"] = inputs_embeds
        else:
            # `clone` calls in this function ensure a consistent stride. See #32227
            model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
            model_inputs["inputs_embeds"] = None
    else:
        model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)

    # 4. Create missing `position_ids` on the fly
    encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None
    attention_mask = (
        kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask
    )
    attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
    position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
    if (
        attention_mask is not None
        and kwargs.get(position_ids_key) is None
        and position_ids_key in set(inspect.signature(self.forward).parameters.keys())
    ):
        # Position = number of attended tokens before this one; masked-out slots get the dummy position 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        kwargs[position_ids_key] = position_ids  # placed in kwargs for further processing (see below)

    # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
    for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
        model_input = kwargs.get(model_input_name)
        if model_input is not None:
            if past_key_values is not None:
                current_input_length = (
                    model_inputs["inputs_embeds"].shape[1]
                    if model_inputs.get("inputs_embeds") is not None
                    else model_inputs[input_ids_key].shape[1]
                )
                model_input = model_input[:, -current_input_length:]
            model_input = model_input.clone(memory_format=torch.contiguous_format)
            model_inputs[model_input_name] = model_input

    # 6. Create 4D attention mask if we are using a compilable cache (important for performant compiled forward
    # pass)
    if (
        isinstance(past_key_values, Cache)
        and past_key_values.is_compileable
        and attention_mask is not None
        and attention_mask.ndim == 2
    ):
        if not self.config.is_encoder_decoder and model_inputs["inputs_embeds"] is not None:
            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
        else:
            batch_size, sequence_length = model_inputs[input_ids_key].shape[:2]

        # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
        # the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
        base_model = getattr(self, self.base_model_prefix, self)
        decoder = base_model.get_decoder() if hasattr(base_model, "get_decoder") else None
        causal_mask_creation_function = getattr(
            base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
        )
        if causal_mask_creation_function is None and decoder is not None:  # it may be in the decoder
            causal_mask_creation_function = getattr(
                decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None
            )

        # If it's not defined, it means the model uses the new general mask API
        if causal_mask_creation_function is None:  # can't be found
            # Take `token_type_ids` from the prepared inputs (already sliced in step 5). The step-5 loop variable
            # `model_input` is a tensor or `None`, never a container of inputs, so it must not be used here.
            token_type_ids = model_inputs.get("token_type_ids")
            # Some models may overwrite the general one
            causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
            attention_mask = causal_mask_creation_function(
                config=self.config,
                # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
                input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype),
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                token_type_ids=token_type_ids,
            )
        else:
            attention_mask = causal_mask_creation_function(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=self.dtype,
                cache_position=cache_position,
                batch_size=batch_size,
                config=self.config,
                past_key_values=past_key_values,
            )
    if attention_mask is not None:
        model_inputs[attention_mask_key] = attention_mask

    if encoder_attention_mask is not None:
        model_inputs["attention_mask"] = encoder_attention_mask

    # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
    for key, value in kwargs.items():
        if key not in model_inputs:
            model_inputs[key] = value

    # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
    model_inputs.pop("labels", None)
    return model_inputs
| def _prepare_model_inputs( | |
| self, | |
| inputs: Optional[torch.Tensor] = None, | |
| bos_token_id: Optional[torch.Tensor] = None, | |
| model_kwargs: Optional[dict[str, torch.Tensor]] = None, | |
| ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]: | |
| """ | |
| This function extracts the model-specific `inputs` for generation. | |
| """ | |
| # 1. retrieve all kwargs that are non-None or non-model input related. | |
| # some encoder-decoder models have different names for model and encoder | |
| if ( | |
| self.config.is_encoder_decoder | |
| and hasattr(self, "encoder") | |
| and self.encoder.main_input_name != self.main_input_name | |
| ): | |
| input_name = self.encoder.main_input_name | |
| else: | |
| input_name = self.main_input_name | |
| model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} | |
| # 2. check whether model_input_name is passed as kwarg | |
| # if yes and `inputs` is None use kwarg inputs | |
| inputs_kwarg = model_kwargs.pop(input_name, None) | |
| if inputs_kwarg is not None and inputs is not None: | |
| raise ValueError( | |
| f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. " | |
| f"Make sure to either pass {inputs} or {input_name}=..." | |
| ) | |
| elif inputs_kwarg is not None: | |
| inputs = inputs_kwarg | |
| # 3. In the presence of `inputs_embeds` for text models: | |
| # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model | |
| # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with | |
| # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`) | |
| # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and | |
| # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states. | |
| if input_name == "input_ids" and "inputs_embeds" in model_kwargs: | |
| if model_kwargs["inputs_embeds"] is None: | |
| model_kwargs.pop("inputs_embeds") | |
| elif not self.config.is_encoder_decoder: | |
| has_inputs_embeds_forwarding = "inputs_embeds" in set( | |
| inspect.signature(self.prepare_inputs_for_generation).parameters.keys() | |
| ) | |
| if not has_inputs_embeds_forwarding: | |
| raise ValueError( | |
| f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} " | |
| "doesn't have its forwarding implemented. See the GPT2 implementation for an example " | |
| "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!" | |
| ) | |
| # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of | |
| # the attention mask) can rely on the actual model input. | |
| model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( | |
| inputs, bos_token_id, model_kwargs=model_kwargs | |
| ) | |
| inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" | |
| else: | |
| if inputs is not None: | |
| raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") | |
| inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" | |
| # 4. if `inputs` is still None, try to create `input_ids` from BOS token | |
| inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) | |
| return inputs, input_name, model_kwargs | |
| def _maybe_initialize_input_ids_for_generation( | |
| self, | |
| inputs: Optional[torch.Tensor] = None, | |
| bos_token_id: Optional[torch.Tensor] = None, | |
| model_kwargs: Optional[dict[str, torch.Tensor]] = None, | |
| ) -> torch.LongTensor: | |
| """Initializes input ids for generation, if necessary.""" | |
| if inputs is not None: | |
| return inputs | |
| encoder_outputs = model_kwargs.get("encoder_outputs") | |
| if self.config.is_encoder_decoder and encoder_outputs is not None: | |
| # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding | |
| shape = encoder_outputs.last_hidden_state.size()[:-1] | |
| return torch.ones(shape, dtype=torch.long, device=self.device) * -100 | |
| # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with | |
| # soft-prompting or in multimodal implementations built on top of decoder-only language models. | |
| batch_size = 1 | |
| for value in model_kwargs.values(): | |
| if isinstance(value, torch.Tensor): | |
| batch_size = value.shape[0] | |
| break | |
| if "inputs_embeds" in model_kwargs: | |
| return torch.ones((batch_size, 0), dtype=torch.long, device=self.device) | |
| if bos_token_id is None: | |
| raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") | |
| return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id | |
def _prepare_attention_mask_for_generation(
    self,
    inputs_tensor: torch.Tensor,
    generation_config: GenerationConfig,
    model_kwargs: dict[str, Any],
) -> torch.LongTensor:
    """
    Builds an attention mask for generation when the user did not pass one.

    If `model_kwargs` holds a non-empty `input_ids`, that tensor is inspected instead of `inputs_embeds`-style main
    inputs (e.g. multimodal models). The mask marks non-pad tokens with 1 only when padding can be inferred, i.e.
    the inputs are 2-D integer token ids, a pad token is configured and present in them, and the pad token is not
    also an EOS token. In every other case an all-ones mask of shape `inputs_tensor.shape[:2]` is returned.

    Returns:
        `torch.LongTensor`: The inferred (or all-ones) attention mask.
    """
    pad_id = generation_config._pad_token_tensor
    eos_ids = generation_config._eos_token_tensor

    # Multimodal models may carry the token ids in `model_kwargs` rather than as the main input.
    if "input_ids" in model_kwargs:
        kwarg_ids = model_kwargs["input_ids"]
        if kwarg_ids.shape[1] > 0:
            inputs_tensor = kwarg_ids

    # All-ones fallback, used whenever there is not enough information to infer padding.
    fallback_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
    if pad_id is None:
        return fallback_mask

    looks_like_token_ids = inputs_tensor.dim() == 2 and inputs_tensor.dtype in (torch.int, torch.long)
    if not looks_like_token_ids:
        return fallback_mask

    pad_present = isin_mps_friendly(elements=inputs_tensor, test_elements=pad_id).any()
    if eos_ids is None:
        pad_differs_from_eos = True
    else:
        pad_differs_from_eos = ~(isin_mps_friendly(elements=eos_ids, test_elements=pad_id).any())

    # Branch-free selection between the padding-derived mask and the fallback (keeps everything as tensor ops).
    infer_from_padding = pad_present * pad_differs_from_eos
    mask_from_padding = inputs_tensor.ne(pad_id).long()
    return mask_from_padding * infer_from_padding + fallback_mask * ~infer_from_padding
def _prepare_encoder_decoder_kwargs_for_generation(
    self,
    inputs_tensor: torch.Tensor,
    model_kwargs,
    model_input_name: Optional[str],
    generation_config: GenerationConfig,
) -> dict[str, Any]:
    """
    Runs the encoder once and stores its output in `model_kwargs["encoder_outputs"]`.

    Keyword arguments whose name starts with `decoder_`, `cross_attn` or `use_cache` are never sent to the encoder.
    Unless the encoder's `forward` has a `kwargs` or `model_kwargs` parameter, only arguments it names explicitly
    are passed. `output_attentions` / `output_hidden_states` come from `generation_config`, `return_dict` is forced
    to `True`, and `inputs_tensor` is passed under `model_input_name` (default: `self.main_input_name`).

    Returns:
        `dict[str, Any]`: The same `model_kwargs` dict, updated in place with `encoder_outputs`.
    """
    encoder = self.get_encoder()

    # Accelerate big-model inference: the encoder must return its outputs on the same device as its inputs.
    if hasattr(self, "hf_device_map"):
        if not hasattr(encoder, "_hf_hook"):
            add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))
        else:
            encoder._hf_hook.io_same_device = True

    skipped_prefixes = ("decoder_", "cross_attn", "use_cache")
    accepted_params = set(inspect.signature(encoder.forward).parameters)
    takes_any_kwarg = not {"kwargs", "model_kwargs"}.isdisjoint(accepted_params)
    encoder_kwargs = {
        name: value
        for name, value in model_kwargs.items()
        if not name.startswith(skipped_prefixes) and (takes_any_kwarg or name in accepted_params)
    }
    encoder_kwargs.update(
        output_attentions=generation_config.output_attentions,
        output_hidden_states=generation_config.output_hidden_states,
        return_dict=True,  # guarantees a `ModelOutput`
    )

    if model_input_name is None:
        model_input_name = self.main_input_name
    encoder_kwargs[model_input_name] = inputs_tensor

    encoder_outputs: ModelOutput = encoder(**encoder_kwargs)  # type: ignore
    model_kwargs["encoder_outputs"] = encoder_outputs
    return model_kwargs
| def _prepare_decoder_input_ids_for_generation( | |
| self, | |
| batch_size: int, | |
| model_input_name: str, | |
| model_kwargs: dict[str, torch.Tensor], | |
| decoder_start_token_id: torch.Tensor, | |
| device: Optional[torch.device] = None, | |
| ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]: | |
| """Prepares `decoder_input_ids` for generation with encoder-decoder models""" | |
| # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, | |
| # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. | |
| if model_kwargs is not None and "decoder_input_ids" in model_kwargs: | |
| decoder_input_ids = model_kwargs.pop("decoder_input_ids") | |
| elif "input_ids" in model_kwargs and model_input_name != "input_ids": | |
| decoder_input_ids = model_kwargs.pop("input_ids") | |
| else: | |
| decoder_input_ids = None | |
| # 2. `decoder_start_token_id` must have shape (batch_size, 1) | |
| if device is None: | |
| device = self.device | |
| if decoder_start_token_id.ndim == 1: | |
| if decoder_start_token_id.shape[0] != batch_size: | |
| raise ValueError( | |
| f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" | |
| ) | |
| decoder_start_token_id = decoder_start_token_id.view(-1, 1) | |
| else: | |
| decoder_start_token_id = ( | |
| torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id | |
| ) | |
| # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. | |
| # no user input -> use decoder_start_token_id as decoder_input_ids | |
| if decoder_input_ids is None: | |
| decoder_input_ids = decoder_start_token_id | |
| # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the | |
| # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic. | |
| # See: https://github.com/huggingface/transformers/pull/31470 | |
| elif "donut" in self.__class__.__name__.lower() or ( | |
| self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower() | |
| ): | |
| pass | |
| elif self.config.model_type in ["whisper"]: | |
| pass | |
| # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust | |
| # decoder_attention_mask if provided) | |
| elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item(): | |
| decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1) | |
| if "decoder_attention_mask" in model_kwargs: | |
| decoder_attention_mask = model_kwargs["decoder_attention_mask"] | |
| decoder_attention_mask = torch.cat( | |
| (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), | |
| dim=-1, | |
| ) | |
| model_kwargs["decoder_attention_mask"] = decoder_attention_mask | |
| return decoder_input_ids, model_kwargs | |
| @staticmethod | |
| def _expand_inputs_for_generation( | |
| expand_size: int = 1, | |
| is_encoder_decoder: bool = False, | |
| input_ids: Optional[torch.LongTensor] = None, | |
| **model_kwargs, | |
| ) -> tuple[torch.LongTensor, dict[str, Any]]: | |
| """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" | |
| # Do not call torch.repeat_interleave if expand_size is 1 because it clones | |
| # the input tensor and thus requires more memory although no change is applied | |
| if expand_size == 1: | |
| return input_ids, model_kwargs | |
| def _expand_dict_for_generation(dict_to_expand): | |
| for key in dict_to_expand: | |
| if ( | |
| key != "cache_position" | |
| and dict_to_expand[key] is not None | |
| and isinstance(dict_to_expand[key], torch.Tensor) | |
| ): | |
| dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) | |
| return dict_to_expand | |
| if input_ids is not None: | |
| input_ids = input_ids.repeat_interleave(expand_size, dim=0) | |
| model_kwargs = _expand_dict_for_generation(model_kwargs) | |
| if is_encoder_decoder: | |
| if model_kwargs.get("encoder_outputs") is None: | |
| raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") | |
| model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) | |
| return input_ids, model_kwargs | |
| def _update_model_kwargs_for_generation( | |
| self, | |
| outputs: ModelOutput, | |
| model_kwargs: dict[str, Any], | |
| is_encoder_decoder: bool = False, | |
| num_new_tokens: int = 1, | |
| ) -> dict[str, Any]: | |
| # update past_key_values keeping its naming used in model code | |
| for possible_cache_name in ALL_CACHE_NAMES: | |
| if possible_cache_name in outputs: | |
| # TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecated | |
| if possible_cache_name in ("past_buckets_states", "mems"): | |
| cache_name = "past_key_values" | |
| else: | |
| cache_name = possible_cache_name | |
| model_kwargs[cache_name] = getattr(outputs, possible_cache_name) | |
| break | |
| # update token_type_ids with last value | |
| if "token_type_ids" in model_kwargs: | |
| token_type_ids = model_kwargs["token_type_ids"] | |
| model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) | |
| if not is_encoder_decoder: | |
| # update attention mask | |
| if "attention_mask" in model_kwargs: | |
| attention_mask = model_kwargs["attention_mask"] | |
| model_kwargs["attention_mask"] = torch.cat( | |
| [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 | |
| ) | |
| else: | |
| # update decoder attention mask | |
| if "decoder_attention_mask" in model_kwargs: | |
| decoder_attention_mask = model_kwargs["decoder_attention_mask"] | |
| model_kwargs["decoder_attention_mask"] = torch.cat( | |
| [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], | |
| dim=-1, | |
| ) | |
| if model_kwargs.get("use_cache", True): | |
| model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens | |
| else: | |
| past_positions = model_kwargs.pop("cache_position") | |
| new_positions = torch.arange( | |
| past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype | |
| ).to(past_positions.device) | |
| model_kwargs["cache_position"] = torch.cat((past_positions, new_positions)) | |
| return model_kwargs | |
| def _reorder_cache(self, past_key_values, beam_idx): | |
| raise NotImplementedError( | |
| f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to" | |
| f" enable beam search for {self.__class__}" | |
| ) | |
| def _get_candidate_generator( | |
| self, | |
| generation_config: GenerationConfig, | |
| input_ids: torch.LongTensor, | |
| inputs_tensor: torch.Tensor, | |
| assistant_model: "PreTrainedModel", | |
| logits_processor: LogitsProcessorList, | |
| target_tokenizer: "PreTrainedTokenizerBase", | |
| assistant_tokenizer: "PreTrainedTokenizerBase", | |
| model_kwargs: dict, | |
| ) -> CandidateGenerator: | |
| """ | |
| Returns the candidate generator to be used in `assisted_generation` | |
| """ | |
| different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer)) | |
| if generation_config.assistant_early_exit is not None: | |
| candidate_generator = EarlyExitCandidateGenerator( | |
| input_ids=input_ids, | |
| assistant_model=self, | |
| generation_config=generation_config, | |
| model_kwargs=model_kwargs, | |
| inputs_tensor=inputs_tensor, | |
| logits_processor=logits_processor, | |
| ) | |
| elif generation_config.prompt_lookup_num_tokens is not None: | |
| candidate_generator = PromptLookupCandidateGenerator( | |
| eos_token_id=generation_config._eos_token_tensor, | |
| num_output_tokens=generation_config.prompt_lookup_num_tokens, | |
| max_matching_ngram_size=generation_config.max_matching_ngram_size, | |
| max_length=generation_config.max_length, | |
| ) | |
| elif different_tokenizers: | |
| if generation_config.do_sample is True: | |
| atm_translator = AssistantVocabTranslatorCache.get_translator( | |
| target_tokenizer, | |
| assistant_tokenizer, | |
| self.config.get_text_config().vocab_size, | |
| assistant_model=assistant_model, | |
| assistant_prune_lm_head=True, # prune LM head of assistant model | |
| ) | |
| # Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismatches between token ids and logits index | |
| assistant_model.generation_config.repetition_penalty = None | |
| candidate_generator = UniversalSpeculativeDecodingGenerator( | |
| input_ids=input_ids, | |
| assistant_model=assistant_model, | |
| generation_config=generation_config, | |
| model_kwargs=model_kwargs, | |
| inputs_tensor=inputs_tensor, | |
| logits_processor=logits_processor, | |
| target_tokenizer=target_tokenizer, | |
| assistant_tokenizer=assistant_tokenizer, | |
| atm_translator=atm_translator, | |
| ) | |
| elif generation_config.do_sample is False: | |
| candidate_generator = AssistedCandidateGeneratorDifferentTokenizers( | |
| input_ids=input_ids, | |
| assistant_model=assistant_model, | |
| generation_config=generation_config, | |
| model_kwargs=model_kwargs, | |
| inputs_tensor=inputs_tensor, | |
| logits_processor=logits_processor, | |
| target_tokenizer=target_tokenizer, | |
| assistant_tokenizer=assistant_tokenizer, | |
| ) | |
| else: | |
| raise ValueError( | |
| f"Invalid value for `do_sample`: expected a boolean, got {type(generation_config.do_sample).__name__}" | |
| ) | |
| else: | |
| candidate_generator = AssistedCandidateGenerator( | |
| input_ids=input_ids, | |
| assistant_model=assistant_model, | |
| generation_config=generation_config, | |
| model_kwargs=model_kwargs, | |
| inputs_tensor=inputs_tensor, | |
| logits_processor=logits_processor, | |
| ) | |
| return candidate_generator | |
| def _get_logits_processor( | |
| self, | |
| generation_config: GenerationConfig, | |
| input_ids_seq_length: Optional[int] = None, | |
| encoder_input_ids: torch.LongTensor = None, | |
| prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None, | |
| logits_processor: Optional[LogitsProcessorList] = None, | |
| device: Optional[str] = None, | |
| model_kwargs: Optional[dict[str, Any]] = None, | |
| negative_prompt_ids: Optional[torch.Tensor] = None, | |
| negative_prompt_attention_mask: Optional[torch.Tensor] = None, | |
| ) -> LogitsProcessorList: | |
| """ | |
| This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] | |
| instances used to modify the scores of the language model head. | |
| """ | |
| # instantiate processors list | |
| processors = LogitsProcessorList() | |
| if logits_processor is None: | |
| logits_processor = [] | |
| if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1: | |
| processors.append( | |
| UnbatchedClassifierFreeGuidanceLogitsProcessor( | |
| generation_config.guidance_scale, | |
| self, | |
| unconditional_ids=negative_prompt_ids, | |
| unconditional_attention_mask=negative_prompt_attention_mask, | |
| use_cache=generation_config.use_cache, | |
| ) | |
| ) | |
| if generation_config.sequence_bias is not None: | |
| processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias)) | |
| if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0: | |
| processors.append( | |
| HammingDiversityLogitsProcessor( | |
| diversity_penalty=generation_config.diversity_penalty, | |
| num_beams=generation_config.num_beams, | |
| num_beam_groups=generation_config.num_beam_groups, | |
| ) | |
| ) | |
| if ( | |
| generation_config.encoder_repetition_penalty is not None | |
| and generation_config.encoder_repetition_penalty != 1.0 | |
| ): | |
| if len(encoder_input_ids.shape) == 2: | |
| processors.append( | |
| EncoderRepetitionPenaltyLogitsProcessor( | |
| penalty=generation_config.encoder_repetition_penalty, | |
| encoder_input_ids=encoder_input_ids, | |
| ) | |
| ) | |
| else: | |
| warnings.warn( | |
| "Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to " | |
| "`generate`, ignoring the argument.", | |
| UserWarning, | |
| ) | |
| if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: | |
| processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) | |
| if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: | |
| processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) | |
| if ( | |
| generation_config.encoder_no_repeat_ngram_size is not None | |
| and generation_config.encoder_no_repeat_ngram_size > 0 | |
| ): | |
| if len(encoder_input_ids.shape) == 2: | |
| processors.append( | |
| EncoderNoRepeatNGramLogitsProcessor( | |
| generation_config.encoder_no_repeat_ngram_size, | |
| encoder_input_ids, | |
| ) | |
| ) | |
| else: | |
| warnings.warn( | |
| "Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to " | |
| "`generate`, ignoring the argument.", | |
| UserWarning, | |
| ) | |
| if generation_config.bad_words_ids is not None: | |
| processors.append( | |
| NoBadWordsLogitsProcessor( | |
| generation_config.bad_words_ids, | |
| generation_config._eos_token_tensor, | |
| ) | |
| ) | |
| if ( | |
| generation_config.min_length is not None | |
| and getattr(generation_config, "_eos_token_tensor", None) is not None | |
| and generation_config.min_length > 0 | |
| ): | |
| processors.append( | |
| MinLengthLogitsProcessor( | |
| generation_config.min_length, | |
| generation_config._eos_token_tensor, | |
| device=device, | |
| ) | |
| ) | |
| if ( | |
| generation_config.min_new_tokens is not None | |
| and getattr(generation_config, "_eos_token_tensor", None) is not None | |
| and generation_config.min_new_tokens > 0 | |
| ): | |
| processors.append( | |
| MinNewTokensLengthLogitsProcessor( | |
| input_ids_seq_length, | |
| generation_config.min_new_tokens, | |
| generation_config._eos_token_tensor, | |
| device=device, | |
| ) | |
| ) | |
| if prefix_allowed_tokens_fn is not None: | |
| processors.append( | |
| PrefixConstrainedLogitsProcessor( | |
| prefix_allowed_tokens_fn, | |
| generation_config.num_beams // generation_config.num_beam_groups, | |
| ) | |
| ) | |
| if generation_config.forced_bos_token_id is not None: | |
| processors.append( | |
| ForcedBOSTokenLogitsProcessor( | |
| generation_config.forced_bos_token_id, | |
| ) | |
| ) | |
| if generation_config.forced_eos_token_id is not None: | |
| processors.append( | |
| ForcedEOSTokenLogitsProcessor( | |
| generation_config.max_length, | |
| generation_config.forced_eos_token_id, | |
| device=device, | |
| ) | |
| ) | |
| if generation_config.remove_invalid_values is True: | |
| processors.append(InfNanRemoveLogitsProcessor()) | |
| if generation_config.exponential_decay_length_penalty is not None: | |
| processors.append( | |
| ExponentialDecayLengthPenalty( | |
| generation_config.exponential_decay_length_penalty, | |
| generation_config._eos_token_tensor, | |
| input_ids_seq_length, | |
| ) | |
| ) | |
| if generation_config.suppress_tokens is not None: | |
| processors.append( | |
| SuppressTokensLogitsProcessor( | |
| generation_config.suppress_tokens, | |
| device=device, | |
| ) | |
| ) | |
| if generation_config.begin_suppress_tokens is not None: | |
| begin_index = input_ids_seq_length | |
| begin_index = ( | |
| begin_index | |
| if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) | |
| else begin_index + 1 | |
| ) | |
| processors.append( | |
| SuppressTokensAtBeginLogitsProcessor( | |
| generation_config.begin_suppress_tokens, | |
| begin_index, | |
| device=device, | |
| ) | |
| ) | |
| # TODO (joao): find a strategy to specify the order of the processors | |
| processors = self._merge_criteria_processor_list(processors, logits_processor) | |
| # Processors previously known as `LogitsWarpers`, only applied with sampling strategies | |
| if generation_config.do_sample: | |
| # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a | |
| # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1) | |
| if generation_config.num_beams > 1: | |
| if isinstance(generation_config._eos_token_tensor, list): | |
| min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 | |
| elif isinstance(generation_config._eos_token_tensor, torch.Tensor): | |
| min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 | |
| else: | |
| min_tokens_to_keep = 2 | |
| else: | |
| min_tokens_to_keep = 1 | |
| # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files | |
| # all samplers can be found in `generation_utils_samplers.py` | |
| if generation_config.temperature is not None and generation_config.temperature != 1.0: | |
| processors.append(TemperatureLogitsWarper(generation_config.temperature)) | |
| if generation_config.top_k is not None and generation_config.top_k != 0: | |
| processors.append( | |
| TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep) | |
| ) | |
| if generation_config.top_p is not None and generation_config.top_p < 1.0: | |
| processors.append( | |
| TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep) | |
| ) | |
| if generation_config.min_p is not None: | |
| # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) | |
| processors.append( | |
| MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep) | |
| ) | |
| if generation_config.typical_p is not None and generation_config.typical_p < 1.0: | |
| processors.append( | |
| TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) | |
| ) | |
| if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: | |
| processors.append( | |
| EpsilonLogitsWarper( | |
| epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep | |
| ) | |
| ) | |
| if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: | |
| processors.append( | |
| EtaLogitsWarper( | |
| epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device | |
| ) | |
| ) | |
| # Watermarking should be after all logits processing is finished (see #34630) | |
| if generation_config.watermarking_config is not None: | |
| processors.append( | |
| generation_config.watermarking_config.construct_processor( | |
| self.config.get_text_config().vocab_size, device | |
| ) | |
| ) | |
| # `LogitNormalization` should always be the last logit processor, when present | |
| if generation_config.renormalize_logits is True: | |
| processors.append(LogitNormalization()) | |
| return processors | |
| def _get_stopping_criteria( | |
| self, | |
| generation_config: GenerationConfig, | |
| stopping_criteria: Optional[StoppingCriteriaList], | |
| tokenizer: Optional["PreTrainedTokenizerBase"] = None, | |
| **kwargs, | |
| ) -> StoppingCriteriaList: | |
| criteria = StoppingCriteriaList() | |
| if generation_config.max_length is not None: | |
| max_position_embeddings = getattr(self.config, "max_position_embeddings", None) | |
| criteria.append( | |
| MaxLengthCriteria( | |
| max_length=generation_config.max_length, | |
| max_position_embeddings=max_position_embeddings, | |
| ) | |
| ) | |
| if generation_config.max_time is not None: | |
| criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) | |
| if generation_config.stop_strings is not None: | |
| if tokenizer is None: | |
| raise ValueError( | |
| "There are one or more stop strings, either in the arguments to `generate` or in the " | |
| "model's generation config, but we could not locate a tokenizer. When generating with " | |
| "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`." | |
| ) | |
| criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer)) | |
| if generation_config._eos_token_tensor is not None: | |
| criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor)) | |
| if ( | |
| generation_config.is_assistant | |
| and generation_config.assistant_confidence_threshold is not None | |
| and generation_config.assistant_confidence_threshold > 0 | |
| ): | |
| criteria.append( | |
| ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold) | |
| ) | |
| criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) | |
| return criteria | |
| def _merge_criteria_processor_list( | |
| self, | |
| default_list: Union[LogitsProcessorList, StoppingCriteriaList], | |
| custom_list: Union[LogitsProcessorList, StoppingCriteriaList], | |
| ) -> Union[LogitsProcessorList, StoppingCriteriaList]: | |
| """ | |
| Merge user-defined processors/criteria with the ones instantiated inside `generate`. In case the same | |
| processor/criteria is present on both lists, use the user-defined one. | |
| (Note: up to v4.49.0, this function threw an exception is the same logit processor was found twice.) | |
| """ | |
| if len(custom_list) == 0: | |
| return default_list | |
| final_list = type(default_list)() | |
| for default in default_list: | |
| using_custom = False | |
| for custom in custom_list: | |
| if type(custom) is type(default): | |
| object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" | |
| logger.warning_once( | |
| f"A custom {object_type} of type {type(custom)} has been passed to `.generate()`, but it " | |
| f"was also created in `.generate()`, given its parameterization. The custom {type(custom)} " | |
| f"will take precedence. Please check the docstring of {type(custom)} to see related " | |
| "`.generate()` flags." | |
| ) | |
| final_list.append(custom) | |
| using_custom = True | |
| break | |
| if not using_custom: | |
| final_list.append(default) | |
| for custom in custom_list: | |
| if custom not in final_list: | |
| final_list.append(custom) | |
| return final_list | |
| def compute_transition_scores( | |
| self, | |
| sequences: torch.Tensor, | |
| scores: tuple[torch.Tensor], | |
| beam_indices: Optional[torch.Tensor] = None, | |
| normalize_logits: bool = False, | |
| ) -> torch.Tensor: | |
| """ | |
| Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was | |
| used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time. | |
| Parameters: | |
| sequences (`torch.LongTensor`): | |
| The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or | |
| shorter if all batches finished early due to the `eos_token_id`. | |
| scores (`tuple(torch.FloatTensor)`): | |
| Transition scores for each vocabulary token at each generation step. Beam transition scores consisting | |
| of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. | |
| Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), | |
| with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. | |
| beam_indices (`torch.LongTensor`, *optional*): | |
| Beam indices of generated token id at each generation step. `torch.LongTensor` of shape | |
| `(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at | |
| generate-time. | |
| normalize_logits (`bool`, *optional*, defaults to `False`): | |
| Whether to normalize the logits (which, for legacy reasons, may be unnormalized). | |
| Return: | |
| `torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing | |
| the transition scores (logits) | |
| Examples: | |
| ```python | |
| >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM | |
| >>> import numpy as np | |
| >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") | |
| >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") | |
| >>> tokenizer.pad_token_id = tokenizer.eos_token_id | |
| >>> inputs = tokenizer(["Today is"], return_tensors="pt") | |
| >>> # Example 1: Print the scores for each token generated with Greedy Search | |
| >>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True) | |
| >>> transition_scores = model.compute_transition_scores( | |
| ... outputs.sequences, outputs.scores, normalize_logits=True | |
| ... ) | |
| >>> # input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for | |
| >>> # encoder-decoder models, like BART or T5. | |
| >>> input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1] | |
| >>> generated_tokens = outputs.sequences[:, input_length:] | |
| >>> for tok, score in zip(generated_tokens[0], transition_scores[0]): | |
| ... # | token | token string | log probability | probability | |
| ... print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}") | |
| | 262 | the | -1.414 | 24.33% | |
| | 1110 | day | -2.609 | 7.36% | |
| | 618 | when | -2.010 | 13.40% | |
| | 356 | we | -1.859 | 15.58% | |
| | 460 | can | -2.508 | 8.14% | |
| >>> # Example 2: Reconstruct the sequence scores from Beam Search | |
| >>> outputs = model.generate( | |
| ... **inputs, | |
| ... max_new_tokens=5, | |
| ... num_beams=4, | |
| ... num_return_sequences=4, | |
| ... return_dict_in_generate=True, | |
| ... output_scores=True, | |
| ... ) | |
| >>> transition_scores = model.compute_transition_scores( | |
| ... outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False | |
| ... ) | |
| >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores. | |
| >>> # Tip 1: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the | |
| >>> # use case, you might want to recompute it with `normalize_logits=True`. | |
| >>> # Tip 2: the output length does NOT include the input length | |
| >>> output_length = np.sum(transition_scores.numpy() < 0, axis=1) | |
| >>> length_penalty = model.generation_config.length_penalty | |
| >>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty) | |
| >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores)) | |
| True | |
| ```""" | |
| # 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent | |
| # to a beam search approach were the first (and only) beam is always selected | |
| if beam_indices is None: | |
| beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device) | |
| beam_indices = beam_indices.expand(-1, len(scores)) | |
| # 2. reshape scores as [batch_size*vocab_size, # generation steps] with # generation steps being | |
| # seq_len - input_length | |
| scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1) | |
| # 3. Optionally normalize the logits (across the vocab dimension) | |
| if normalize_logits: | |
| scores = scores.reshape(-1, self.config.get_text_config().vocab_size, scores.shape[-1]) | |
| scores = torch.nn.functional.log_softmax(scores, dim=1) | |
| scores = scores.reshape(-1, scores.shape[-1]) | |
| # 4. cut beam_indices to longest beam length | |
| beam_indices_mask = beam_indices < 0 | |
| max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max() | |
| beam_indices = beam_indices.clone()[:, :max_beam_length] | |
| beam_indices_mask = beam_indices_mask[:, :max_beam_length] | |
| # 5. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards | |
| beam_indices[beam_indices_mask] = 0 | |
| # 6. multiply beam_indices with vocab size to gather correctly from scores | |
| beam_sequence_indices = beam_indices * self.config.get_text_config().vocab_size | |
| # 7. Define which indices contributed to scores | |
| cut_idx = sequences.shape[-1] - max_beam_length | |
| indices = sequences[:, cut_idx:] + beam_sequence_indices | |
| # 8. Compute scores | |
| transition_scores = scores.gather(0, indices) | |
| # 9. Mask out transition_scores of beams that stopped early | |
| transition_scores[beam_indices_mask] = 0 | |
| return transition_scores | |
| def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer): | |
| if assistant_model is None: | |
| return | |
| if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder: | |
| attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"] | |
| attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check] | |
| are_equal = all( | |
| getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check | |
| ) | |
| if not are_equal: | |
| raise ValueError( | |
| "The main model and the assistant don't have compatible encoder-dependent input shapes. " | |
| "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper." | |
| ) | |
| doc_reference = ( | |
| "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)" | |
| ) | |
| if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size: | |
| if assistant_tokenizer is not None: | |
| raise ValueError( | |
| f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}." | |
| ) | |
| else: | |
| if tokenizer is None or assistant_tokenizer is None: | |
| raise ValueError( | |
| f"The main and assistant moedels have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}." | |
| ) | |
| def _validate_model_kwargs(self, model_kwargs: dict[str, Any]): | |
| """Validates model kwargs for generation. Generate argument typos will also be caught here.""" | |
| # If a `Cache` instance is passed, checks whether the model is compatible with it | |
| if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class: | |
| raise ValueError( | |
| f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please " | |
| "check the model documentation for supported cache formats." | |
| ) | |
| # Excludes arguments that are handled before calling any model function | |
| if self.config.is_encoder_decoder: | |
| for key in ["decoder_input_ids"]: | |
| model_kwargs.pop(key, None) | |
| unused_model_args = [] | |
| model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) | |
| # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If | |
| # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;) | |
| if "kwargs" in model_args or "model_kwargs" in model_args: | |
| model_args |= set(inspect.signature(self.forward).parameters) | |
| # Encoder-Decoder models may also need Encoder arguments from `model_kwargs` | |
| if self.config.is_encoder_decoder: | |
| base_model = getattr(self, self.base_model_prefix, None) | |
| # allow encoder kwargs | |
| encoder = getattr(self, "encoder", None) | |
| # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`. | |
| # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder` | |
| # TODO: A better way to handle this. | |
| if encoder is None and base_model is not None: | |
| encoder = getattr(base_model, "encoder", None) | |
| if encoder is not None: | |
| encoder_model_args = set(inspect.signature(encoder.forward).parameters) | |
| model_args |= encoder_model_args | |
| # allow decoder kwargs | |
| decoder = getattr(self, "decoder", None) | |
| if decoder is None and base_model is not None: | |
| decoder = getattr(base_model, "decoder", None) | |
| if decoder is not None: | |
| decoder_model_args = set(inspect.signature(decoder.forward).parameters) | |
| model_args |= {f"decoder_{x}" for x in decoder_model_args} | |
| for key, value in model_kwargs.items(): | |
| if value is not None and key not in model_args: | |
| unused_model_args.append(key) | |
| if unused_model_args: | |
| raise ValueError( | |
| f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" | |
| " generate arguments will also show up in this list)" | |
| ) | |
| def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): | |
| """Performs validation related to the resulting generated length""" | |
| # 1. Max length warnings related to poor parameterization | |
| if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: | |
| # 20 is the default max_length of the generation config | |
| warnings.warn( | |
| f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " | |
| "generation length. We recommend setting `max_new_tokens` to control the maximum length of the " | |
| "generation.", | |
| UserWarning, | |
| ) | |
| if input_ids_length >= generation_config.max_length: | |
| input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" | |
| raise ValueError( | |
| f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" | |
| f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" | |
| " increasing `max_length` or, better yet, setting `max_new_tokens`." | |
| ) | |
| # 2. Min length warnings due to unfeasible parameter combinations | |
| min_length_error_suffix = ( | |
| " Generation will stop at the defined maximum length. You should decrease the minimum length and/or " | |
| "increase the maximum length." | |
| ) | |
| if has_default_max_length: | |
| min_length_error_suffix += ( | |
| f" Note that `max_length` is set to {generation_config.max_length}, its default value." | |
| ) | |
| if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: | |
| warnings.warn( | |
| f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than" | |
| f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, | |
| UserWarning, | |
| ) | |
| if generation_config.min_new_tokens is not None: | |
| min_length = generation_config.min_new_tokens + input_ids_length | |
| if min_length > generation_config.max_length: | |
| warnings.warn( | |
| f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when " | |
| f"added to the prompt length ({input_ids_length}), is larger than" | |
| f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, | |
| UserWarning, | |
| ) | |
| def _prepare_generated_length( | |
| self, | |
| generation_config, | |
| has_default_max_length, | |
| has_default_min_length, | |
| model_input_name, | |
| input_ids_length, | |
| inputs_tensor, | |
| ): | |
| """Prepared max and min length in generation configs to avoid clashes between similar attributes""" | |
| if generation_config.max_new_tokens is not None: | |
| if not has_default_max_length and generation_config.max_length is not None: | |
| logger.warning( | |
| f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" | |
| f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " | |
| "Please refer to the documentation for more information. " | |
| "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" | |
| ) | |
| generation_config.max_length = generation_config.max_new_tokens + input_ids_length | |
| # if both `inputs_embeds` and `input_ids` are passed, we do not correct the length | |
| # otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length`` | |
| elif ( | |
| model_input_name == "inputs_embeds" | |
| and input_ids_length != inputs_tensor.shape[1] | |
| and not self.config.is_encoder_decoder | |
| ): | |
| generation_config.max_length -= inputs_tensor.shape[1] | |
| elif has_default_max_length: # by default let's always generate 20 new tokens | |
| if generation_config.max_length == GenerationConfig().max_length: | |
| generation_config.max_length = generation_config.max_length + input_ids_length | |
| max_position_embeddings = getattr(self.config, "max_position_embeddings", None) | |
| if max_position_embeddings is not None: | |
| generation_config.max_length = min(generation_config.max_length, max_position_embeddings) | |
| # same for min length | |
| if generation_config.min_new_tokens is not None: | |
| if not has_default_min_length: | |
| logger.warning( | |
| f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(=" | |
| f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. " | |
| "Please refer to the documentation for more information. " | |
| "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" | |
| ) | |
| generation_config.min_length = generation_config.min_new_tokens + input_ids_length | |
| elif ( | |
| model_input_name == "inputs_embeds" | |
| and input_ids_length != inputs_tensor.shape[1] | |
| and not self.config.is_encoder_decoder | |
| ): | |
| generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0) | |
| return generation_config | |
def _prepare_generation_config(
    self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict
) -> tuple[GenerationConfig, dict]:
    """
    Resolve the `GenerationConfig` for one `generate` call and apply the ad-hoc `kwargs` on top of it.

    Precedence, highest first: `kwargs` > non-global-default values of the passed `generation_config` >
    `self.generation_config` > `GenerationConfig()` defaults. The legacy path where generation is controlled
    through the model config is still honoured (conditions listed below).

    Args:
        generation_config: config passed by the user, or `None` to start from `self.generation_config`.
        use_model_defaults: when a config is passed, `True` copies model-specific values into every field the
            passed config leaves at the global default; `None` does the same (and logs the changes) when
            `self.generation_config.transformers_version` is >= 4.50.0; otherwise only missing
            bos/eos/pad/decoder-start token ids are backfilled from `self.generation_config`.
        **kwargs: forwarded to `generation_config.update`.

    Returns:
        `(generation_config, model_kwargs)`: a deep copy of the resolved config and the value returned by
        `generation_config.update(**kwargs)`.
    """
    # TODO (joao): per-model generation config classes.
    started_from_model_config = generation_config is None
    if started_from_model_config:
        # Legacy: users may still steer generation by editing the model config. Honoured only when
        #   1) the model's generation config was created from the model config (`_from_model_config`);
        #   2) it has not been modified since creation (hash unchanged);
        #   3) the model config holds non-default generation parameters;
        #   4) regenerating from the model config yields a different generation config.
        current = self.generation_config
        legacy_candidate = (
            current._from_model_config
            and current._original_object_hash == hash(current)
            and len(self.config._get_non_default_generation_parameters()) > 0
        )
        if legacy_candidate:
            regenerated = GenerationConfig.from_model_config(self.config)
            if regenerated != current:
                warnings.warn(
                    "You have modified the pretrained model configuration to control generation. This is a"
                    " deprecated strategy to control generation and will be removed in v5."
                    " Please use and modify the model generation configuration (see"
                    " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
                    UserWarning,
                )
                self.generation_config = regenerated
        generation_config = self.generation_config

    # `torch.export.export` usually raises an exception if it is called
    # with ``strict=True``. deepcopy can only be processed if ``strict=False``.
    generation_config = copy.deepcopy(generation_config)

    if not started_from_model_config:
        model_defaults = self.generation_config
        saved_with = version.parse(version.parse(model_defaults.transformers_version).base_version)
        adopt_model_defaults = use_model_defaults is True or (
            use_model_defaults is None and saved_with >= version.parse("4.50.0")
        )
        if adopt_model_defaults:
            global_defaults = GenerationConfig()
            modified_values = {}
            # walk the model's config: it may carry custom keys that should be copied too
            for key, model_value in model_defaults.__dict__.items():
                if key.startswith("_") or key == "transformers_version":  # metadata
                    continue
                global_value = getattr(global_defaults, key, None)
                user_value = getattr(generation_config, key, None)
                if user_value == global_value and model_value != global_value:
                    modified_values[key] = model_value
                    setattr(generation_config, key, model_value)
            if use_model_defaults is None and modified_values:
                logger.warning_once(
                    f"`generation_config` default values have been modified to match model-specific defaults: "
                    f"{modified_values}. If this is not desired, please set these values explicitly."
                )
        else:
            # legacy behaviour: only make sure the special tokens are defined
            for token_attr in ("bos_token_id", "eos_token_id", "pad_token_id", "decoder_start_token_id"):
                if getattr(generation_config, token_attr) is None:
                    setattr(generation_config, token_attr, getattr(model_defaults, token_attr))

    # Finally, apply any passed kwargs
    model_kwargs = generation_config.update(**kwargs)
    return generation_config, model_kwargs
| def _get_initial_cache_position(self, seq_length, device, model_kwargs): | |
| """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" | |
| # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` | |
| if "cache_position" in model_kwargs and model_kwargs["cache_position"]: | |
| return model_kwargs | |
| if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder: | |
| cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 | |
| elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder: | |
| cache_position = ( | |
| torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 | |
| ) | |
| else: | |
| cache_position = torch.ones(seq_length, dtype=torch.int64, device=device).cumsum(0) - 1 | |
| past_length = 0 | |
| if model_kwargs.get("past_key_values") is not None: | |
| cache = model_kwargs["past_key_values"] | |
| past_length = 0 | |
| if not isinstance(cache, Cache): | |
| past_length = cache[0][0].shape[2] | |
| elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: | |
| past_length = cache.get_seq_length() | |
| cache_position = cache_position[past_length:] | |
| model_kwargs["cache_position"] = cache_position | |
| return model_kwargs | |
| def _get_layer_device_map_for_cache_init(self) -> Optional[dict[int, Union[str, int]]]: | |
| """ | |
| Returns the device map for each decoder layer, to allocate the cache on the right device. | |
| Inspired from `dispatch_model` in accelerate. | |
| """ | |
| execution_device_map = None | |
| if hasattr(self, "hf_device_map"): | |
| if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}: | |
| main_device = "cpu" | |
| else: | |
| main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0] | |
| execution_device_map = { | |
| name: main_device if device in ["cpu", "disk"] else device | |
| for name, device in self.hf_device_map.items() | |
| } | |
| # No `execution_device_map` -> rely on `self.device` to allocate the cache | |
| if execution_device_map is None: | |
| return None | |
| # Single device for all layers | |
| num_hidden_layers = self.config.get_text_config().num_hidden_layers | |
| if len(execution_device_map) == 1 and "" in execution_device_map: | |
| return dict.fromkeys(range(num_hidden_layers), execution_device_map[""]) | |
| # Multiple devices in `execution_device_map` -> we need to map decoder layers to the correct device. | |
| layer_device_map = {} | |
| # Case 1: The model has a `get_decoder` method, we can use it to find the decoder name. | |
| if hasattr(self, "get_decoder"): | |
| decoder_name = None | |
| for name, module in self.named_modules(): | |
| if module is self.get_decoder(): | |
| decoder_name = name | |
| break | |
| if decoder_name is None: | |
| raise RuntimeError( | |
| "`model.get_decoder()` is not returning a named module of the model. This is unexpected, please " | |
| "open an issue on GitHub." | |
| ) | |
| decoder_mapped_modules = [ | |
| module_name for module_name in execution_device_map.keys() if decoder_name in module_name | |
| ] | |
| # The decoder name may be present in `execution_device_map` in two forms: | |
| # a) each layer has a device mapping | |
| if len(decoder_mapped_modules) >= num_hidden_layers: | |
| for idx in range(num_hidden_layers): | |
| for module_name in decoder_mapped_modules: | |
| if f".{idx}." in f"{module_name}.": | |
| layer_device_map[idx] = execution_device_map[module_name] | |
| break | |
| # b) the whole module is mapped to a single device. If the decoder name is NOT present in the device map, | |
| # then the mapping is done in a parent module | |
| else: | |
| while True: | |
| if decoder_name in execution_device_map: | |
| layer_device_map = dict.fromkeys(range(num_hidden_layers), execution_device_map[decoder_name]) | |
| break | |
| elif "." in decoder_name: | |
| decoder_name = decoder_name.rsplit(".", 1)[0] # gets the name of the parent module | |
| else: | |
| raise RuntimeError(f"Decoder name {decoder_name} not found in execution device map") | |
| # Case 2: Legacy code path: assume the decoder layers are named as `(...).X` (X being the layer index) | |
| else: | |
| for layer in execution_device_map: | |
| for idx in range(num_hidden_layers): | |
| if f".{idx}." in f"{layer}.": | |
| layer_device_map[idx] = execution_device_map[layer] | |
| break | |
| for idx in range(num_hidden_layers): | |
| if idx not in layer_device_map: | |
| raise RuntimeError(f"layer {idx} has not been mapped to a device.") | |
| return layer_device_map | |
def _get_cache(
    self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
) -> Cache:
    """
    Return the cache that `generate` keeps on `self._cache` across calls, building a new one only when needed.

    A new cache is built when none exists yet, when the stored cache is not an instance of the requested class,
    when its `max_batch_size` differs from `batch_size`, when it is shorter than `max_cache_len` (skipped for
    "mamba"), when its cross-attention length differs from the encoder outputs length, or when it is a
    `HybridChunkedCache`/`OffloadedHybridCache` (always re-initialized because of internal slicing).
    Otherwise the stored cache is `reset()` and reused.
    """
    # llama4 models use the chunked variant of the hybrid cache
    if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""):
        cache_implementation = "hybrid_chunked"

    cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
    needs_cross_attention = self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
    has_stored_cache = hasattr(self, "_cache")

    existing = None
    if has_stored_cache:
        existing = self._cache.self_attention_cache if needs_cross_attention else self._cache

    if cache_implementation == "sliding_window":
        max_cache_len = min(self.config.sliding_window, max_cache_len)

    rebuild = (
        not has_stored_cache
        or not isinstance(existing, cache_cls)
        or existing.max_batch_size != batch_size
        or isinstance(existing, (HybridChunkedCache, OffloadedHybridCache))  # internal slicing -> always re-init
    )
    if cache_implementation != "mamba":
        rebuild = rebuild or existing.max_cache_len < max_cache_len
    if needs_cross_attention and has_stored_cache:
        rebuild = (
            rebuild
            or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
        )

    if not rebuild:
        self._cache.reset()
        return self._cache

    if hasattr(self.config, "_pre_quantization_dtype"):
        cache_dtype = self.config._pre_quantization_dtype
    else:
        cache_dtype = self.dtype
    layer_device_map = self._get_layer_device_map_for_cache_init()
    cache_kwargs = {
        "config": self.config.get_text_config(),
        "max_batch_size": batch_size,
        "max_cache_len": max_cache_len,
        "dtype": cache_dtype,
        "device": device,
        "layer_device_map": layer_device_map,
    }
    self._cache = cache_cls(**cache_kwargs)
    if needs_cross_attention:
        # the cross-attention cache is sized to the encoder outputs length
        cross_kwargs = {**cache_kwargs, "max_cache_len": model_kwargs["encoder_outputs"][0].shape[1]}
        self._cache = EncoderDecoderCache(self._cache, cache_cls(**cross_kwargs))
    return self._cache
| def _supports_default_dynamic_cache(self) -> bool: | |
| """ | |
| Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. | |
| This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which | |
| uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in | |
| order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed | |
| for `HybridMambaAttentionDynamicCache`). | |
| """ | |
| return ( | |
| self._supports_cache_class | |
| and "jamba" not in self.__class__.__name__.lower() | |
| and "zamba" not in self.__class__.__name__.lower() | |
| and "bamba" not in self.__class__.__name__.lower() | |
| and "minimax" not in self.__class__.__name__.lower() | |
| ) | |
def _prepare_cache_for_generation(
    self,
    generation_config: GenerationConfig,
    model_kwargs: dict,
    assistant_model: "PreTrainedModel",
    batch_size: int,
    max_cache_length: int,
    device: torch.device,
) -> bool:
    """
    Prepare the cache used by `generate` and write it into `model_kwargs` under the name the model expects
    ("cache_params" for classes whose name contains "mamba"/"falconh1", "past_key_values" otherwise).

    Nothing is created when the user already passed a cache (a legacy tuple is converted when supported), when
    `use_cache` is False, or when the model only supports the legacy cache format. Otherwise the cache type
    follows `generation_config.cache_implementation`, falling back to the text config's `cache_implementation`,
    and finally to a `DynamicCache` (wrapped in an `EncoderDecoderCache` when cross-attention is needed).

    Note: every path returns `None` despite the `-> bool` annotation; the effect is the mutation of
    `model_kwargs` (and possibly of `generation_config.cache_implementation`).

    Raises:
        ValueError: a cache object and `cache_implementation` are both given; "static" on a model without static
            cache support; "quantized" on a model without quantized cache support.
        ImportError: the selected quantized-cache backend is not installed.
    """
    class_name = self.__class__.__name__.lower()
    cache_name = "cache_params" if ("mamba" in class_name or "falconh1" in class_name) else "past_key_values"
    needs_cross_attention = self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None

    # Escape route 1: a user-provided cache -> check conflicts and convert legacy tuples if possible
    user_cache = model_kwargs.get(cache_name)
    if user_cache is not None:
        if generation_config.cache_implementation is not None:
            raise ValueError(
                f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
                "Cache object) is unsupported. Please use only one of the two."
            )
        if isinstance(user_cache, tuple) and self._supports_default_dynamic_cache():
            legacy_converter = EncoderDecoderCache if needs_cross_attention else DynamicCache
            model_kwargs[cache_name] = legacy_converter.from_legacy_cache(user_cache)
        return

    # Escape route 2: caching disabled (conflicting arguments are handled in `generation_config.validate()`)
    if generation_config.use_cache is False:
        return

    # Escape route 3: legacy-cache-only models have nothing to prepare
    if not self._supports_default_dynamic_cache():
        if generation_config.cache_implementation is not None:
            warnings.warn(
                "This model does not support `Cache` instances, it only supports the legacy cache format (tuple "
                f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be "
                "ignored.",
                UserWarning,
            )
        return

    # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
    # which is only supported in dynamic caches atm
    if assistant_model is not None and generation_config.cache_implementation is not None:
        logger.warning_once(
            "An assistant model is provided, using a dynamic cache instead of a cache of type="
            f"'{generation_config.cache_implementation}'."
        )
        generation_config.cache_implementation = None

    generation_config.cache_implementation = generation_config.cache_implementation or getattr(
        self.config.get_text_config(), "cache_implementation", None
    )
    implementation = generation_config.cache_implementation

    if implementation is None:
        # Default: a DynamicCache instance avoids back-and-forth conversions from the legacy format, which keep
        # copying the cache and use much more memory
        model_kwargs[cache_name] = (
            EncoderDecoderCache(DynamicCache(), DynamicCache()) if needs_cross_attention else DynamicCache()
        )
    elif implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
        if implementation == "static" and not self._supports_static_cache:
            raise ValueError(
                "This model does not support `cache_implementation='static'`. Please check the following "
                "issue: https://github.com/huggingface/transformers/issues/28981"
            )
        model_kwargs[cache_name] = self._get_cache(
            cache_implementation=implementation,
            batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
            max_cache_len=max_cache_length,
            device=device,
            model_kwargs=model_kwargs,
        )
    elif implementation == "quantized":
        if not self._supports_quantized_cache:
            raise ValueError(
                "This model does not support the quantized cache. If you want your model to support quantized "
                "cache, please open an issue and tag @zucchini-nlp."
            )
        if generation_config.cache_config is not None:
            cache_config = generation_config.cache_config
        else:
            cache_config = QuantizedCacheConfig()
        quantized_cls = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
        if cache_config.backend == "quanto" and not is_optimum_quanto_available():
            raise ImportError(
                "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
                "Please install it via with `pip install optimum-quanto`"
            )
        elif cache_config.backend == "HQQ" and not is_hqq_available():
            raise ImportError(
                "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                "Please install it via with `pip install hqq`"
            )
        model_kwargs[cache_name] = quantized_cls(cache_config)
    elif implementation == "offloaded":
        model_kwargs[cache_name] = OffloadedCache()
    elif implementation == "dynamic":
        model_kwargs[cache_name] = DynamicCache()
| def _supports_logits_to_keep(self) -> bool: | |
| """ | |
| Return True if the current model supports the keyword argument `logits_to_keep` in forward() | |
| to save memory. Checking it in this way allows to avoid using a new model attribute. | |
| """ | |
| return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys()) | |
def _prepare_special_tokens(
    self,
    generation_config: GenerationConfig,
    kwargs_has_attention_mask: Optional[bool] = None,
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Convert the special token ids of `generation_config` to `torch.long` tensors and store them on the config as
    `_bos_token_tensor`, `_eos_token_tensor`, `_pad_token_tensor` and `_decoder_start_token_tensor`.

    Along the way:
    - encoder-decoder models fall back to `bos_token_id` as decoder start token (BC, #30892);
    - the EOS tensor is made 1D, since several EOS tokens are allowed;
    - a missing pad token defaults to the first EOS token (with a warning);
    - warnings are logged for pad == eos without an attention mask, and for float or negative EOS ids.

    `generation_config` is changed in place and stops being serializable afterwards. That is fine inside
    `generate` (which works on a local copy); outside `generate`, copy the config first.

    Args:
        generation_config: config whose `*_token_id` fields are read and whose `_*_token_tensor` fields are set.
        kwargs_has_attention_mask: whether the caller passed an attention mask; `None` silences the related warnings.
        device: device for the new tensors; defaults to `self.device`.

    Raises:
        ValueError: for encoder-decoder models when neither `decoder_start_token_id` nor `bos_token_id` is set.
    """

    def _as_long_tensor(token_id):
        # None stays None; existing tensors are moved; anything else becomes a long tensor on the target device
        if token_id is None:
            return None
        target = self.device if device is None else device
        if isinstance(token_id, torch.Tensor):
            return token_id.to(target)
        return torch.tensor(token_id, device=target, dtype=torch.long)

    bos_tensor = _as_long_tensor(generation_config.bos_token_id)
    eos_tensor = _as_long_tensor(generation_config.eos_token_id)
    pad_tensor = _as_long_tensor(generation_config.pad_token_id)
    decoder_start_tensor = _as_long_tensor(generation_config.decoder_start_token_id)

    # BC (#30892): encoder-decoder models may rely on `bos_token_id` as the decoder start token
    if self.config.is_encoder_decoder and decoder_start_tensor is None:
        decoder_start_tensor = bos_tensor

    # More than one EOS token is allowed: always handle EOS as a 1D tensor
    if eos_tensor is not None and eos_tensor.ndim == 0:
        eos_tensor = eos_tensor.unsqueeze(0)

    # Fall back to EOS as pad token for open-ended generation
    if pad_tensor is None and eos_tensor is not None:
        if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
            logger.warning(
                "The attention mask and the pad token id were not set. As a consequence, you may observe "
                "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
            )
        pad_tensor = eos_tensor[0]
        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_tensor} for open-end generation.")

    if self.config.is_encoder_decoder and decoder_start_tensor is None:
        raise ValueError(
            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
        )

    if eos_tensor is not None:
        pad_is_eos = isin_mps_friendly(elements=eos_tensor, test_elements=pad_tensor).any()
        if pad_is_eos and kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
            logger.warning_once(
                "The attention mask is not set and cannot be inferred from input because pad token is same as "
                "eos token. As a consequence, you may observe unexpected behavior. Please pass your input's "
                "`attention_mask` to obtain reliable results."
            )
        if torch.is_floating_point(eos_tensor) or (eos_tensor < 0).any():
            logger.warning(
                f"`eos_token_id` should consist of positive integers, but is {eos_tensor}. Your generation "
                "will not stop until the maximum length is reached. Depending on other flags, it may even crash."
            )

    # Written under attribute names distinct from the original (non-tensor) token ids, in order to enable
    # end-to-end compilation. See https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
    generation_config._bos_token_tensor = bos_tensor
    generation_config._eos_token_tensor = eos_tensor
    generation_config._pad_token_tensor = pad_tensor
    generation_config._decoder_start_token_tensor = decoder_start_tensor
def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: GenerationConfig) -> bool:
    """
    Decide whether `generate` should auto-compile the model's forward pass.

    Compilation needs all of the following:
    - `generation_config.disable_compile` is not set;
    - a CUDA device, or `compile_config._compile_all_devices`;
    - `model_kwargs["past_key_values"]` is a `Cache` reporting `is_compileable`;
    - `self._supports_static_cache`.

    It is vetoed by a non-compileable `hf_quantizer`, by CPU offload (CPU mixed with other devices in
    `hf_device_map`), or by any disk offload. Warns once when a `compile_config` was given but compilation is
    not possible.
    """
    # An explicit opt-out wins over everything else
    if generation_config.disable_compile:
        return False

    compile_config = generation_config.compile_config
    on_supported_hardware = self.device.type == "cuda" or bool(
        compile_config is not None and compile_config._compile_all_devices
    )
    cache = model_kwargs.get("past_key_values")
    cache_is_compileable = isinstance(cache, Cache) and cache.is_compileable
    can_compile = on_supported_hardware and cache_is_compileable and self._supports_static_cache

    # Exception 1: some quantization methods do not support compilation
    quantizer = getattr(self, "hf_quantizer", None)
    if quantizer is not None:
        can_compile &= quantizer.is_compileable

    if hasattr(self, "hf_device_map"):
        devices = set(self.hf_device_map.values())
        # Exception 2: CPU offload (as of April 2025, compiling with it results in a crash)
        can_compile &= not ("cpu" in devices and len(devices) > 1)
        # Exception 3: disk offload is not supported for compilation
        can_compile &= not ("disk" in devices)

    if compile_config is not None and not can_compile:
        logger.warning_once(
            "You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation "
            "will be skipped."
        )
    return can_compile
| @torch.no_grad() | |
| def generate( | |
| self, | |
| inputs: Optional[torch.Tensor] = None, | |
| generation_config: Optional[GenerationConfig] = None, | |
| logits_processor: Optional[LogitsProcessorList] = None, | |
| stopping_criteria: Optional[StoppingCriteriaList] = None, | |
| prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None, | |
| synced_gpus: Optional[bool] = None, | |
| assistant_model: Optional["PreTrainedModel"] = None, | |
| streamer: Optional["BaseStreamer"] = None, | |
| negative_prompt_ids: Optional[torch.Tensor] = None, | |
| negative_prompt_attention_mask: Optional[torch.Tensor] = None, | |
| use_model_defaults: Optional[bool] = None, | |
| custom_generate: Optional[str] = None, | |
| **kwargs, | |
| ) -> Union[GenerateOutput, torch.LongTensor]: | |
| r""" | |
| Generates sequences of token ids for models with a language modeling head. | |
| <Tip warning={true}> | |
| Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the | |
| model's default generation configuration. You can override any `generation_config` by passing the corresponding | |
| parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. | |
| For an overview of generation strategies and code examples, check out the [following | |
| guide](../generation_strategies). | |
| </Tip> | |
| Parameters: | |
| inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): | |
| The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the | |
| method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` | |
| should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of | |
| `input_ids`, `input_values`, `input_features`, or `pixel_values`. | |
| generation_config ([`~generation.GenerationConfig`], *optional*): | |
| The generation configuration to be used as base parametrization for the generation call. `**kwargs` | |
| passed to generate matching the attributes of `generation_config` will override them. If | |
| `generation_config` is not provided, the default will be used, which has the following loading | |
| priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model | |
| configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s | |
| default values, whose documentation should be checked to parameterize generation. | |
| logits_processor (`LogitsProcessorList`, *optional*): | |
| Custom logits processors that complement the default logits processors built from arguments and | |
| generation config. If a logit processor is passed that is already created with the arguments or a | |
| generation config an error is thrown. This feature is intended for advanced users. | |
| stopping_criteria (`StoppingCriteriaList`, *optional*): | |
| Custom stopping criteria that complements the default stopping criteria built from arguments and a | |
| generation config. If a stopping criteria is passed that is already created with the arguments or a | |
| generation config an error is thrown. If your stopping criteria depends on the `scores` input, make | |
| sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is | |
| intended for advanced users. | |
| prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*): | |
| If provided, this function constraints the beam search to allowed tokens only at each step. If not | |
| provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and | |
| `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned | |
| on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful | |
| for constrained generation conditioned on the prefix, as described in [Autoregressive Entity | |
| Retrieval](https://huggingface.co/papers/2010.00904). | |
| synced_gpus (`bool`, *optional*): | |
| Whether to continue running the while loop until max_length. Unless overridden, this flag will be set | |
| to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid | |
| deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`. | |
| assistant_model (`PreTrainedModel`, *optional*): | |
| An assistant model that can be used to accelerate generation. The assistant model must have the exact | |
| same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model | |
| is much faster than running generation with the model you're calling generate from. As such, the | |
| assistant model should be much smaller. | |
| streamer (`BaseStreamer`, *optional*): | |
| Streamer object that will be used to stream the generated sequences. Generated tokens are passed | |
| through `streamer.put(token_ids)` and the streamer is responsible for any further processing. | |
| negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |
| The negative prompt needed for some processors such as CFG. The batch size must match the input batch | |
| size. This is an experimental feature, subject to breaking API changes in future versions. | |
| negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |
| Attention_mask for `negative_prompt_ids`. | |
| use_model_defaults (`bool`, *optional*): | |
| When it is `True`, unset parameters in `generation_config` will be set to the model-specific default | |
| generation configuration (`model.generation_config`), as opposed to the global defaults | |
| (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be | |
| `True`. | |
| custom_generate (`str`, *optional*): | |
| A string containing the name of a huggingface.co repository. If provided, the custom `generate` | |
| function defined in that reposity's `custom_generate/generate.py` file will be executed instead of the | |
| standard `generate` method. Note that the logic is for generation is entirely defined in that | |
| repository, and the return type may be different from the standard `generate` method. | |
| kwargs (`dict[str, Any]`, *optional*): | |
| Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be | |
| forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder | |
| specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. | |
| Return: | |
| [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` | |
| or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. | |
| If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible | |
| [`~utils.ModelOutput`] types are: | |
| - [`~generation.GenerateDecoderOnlyOutput`], | |
| - [`~generation.GenerateBeamDecoderOnlyOutput`] | |
| If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible | |
| [`~utils.ModelOutput`] types are: | |
| - [`~generation.GenerateEncoderDecoderOutput`], | |
| - [`~generation.GenerateBeamEncoderDecoderOutput`] | |
| """ | |
| # 0. If requested, load an arbitrary generation recipe from the Hub and run it instead | |
| trust_remote_code = kwargs.pop("trust_remote_code", None) | |
| if custom_generate is not None: | |
| # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them: | |
| # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to | |
| # trigger the custom generation. They can access to methods from `GenerationMixin` through `model`. | |
| global_keys_to_exclude = { | |
| "self", | |
| "kwargs", | |
| "global_keys_to_exclude", | |
| "trust_remote_code", | |
| "custom_generate", | |
| } | |
| generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude} | |
| generate_arguments.update(kwargs) | |
| custom_generate_function = self.load_custom_generate( | |
| custom_generate, trust_remote_code=trust_remote_code, **kwargs | |
| ) | |
| return custom_generate_function(model=self, **generate_arguments) | |
| # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call | |
| tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria | |
| assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation | |
| generation_config, model_kwargs = self._prepare_generation_config( | |
| generation_config, use_model_defaults, **kwargs | |
| ) | |
| self._validate_model_kwargs(model_kwargs.copy()) | |
| self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer) | |
| # 2. Set generation parameters if not already defined | |
| if synced_gpus is None: | |
| synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1 | |
| logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | |
| stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | |
| accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) | |
| requires_attention_mask = "encoder_outputs" not in model_kwargs | |
| kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None | |
| # 3. Define model inputs | |
| inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( | |
| inputs, generation_config.bos_token_id, model_kwargs | |
| ) | |
| batch_size = inputs_tensor.shape[0] | |
| device = inputs_tensor.device | |
| self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) | |
| # decoder-only models must use left-padding for batched generation. | |
| if not self.config.is_encoder_decoder: | |
| # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` | |
| # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. | |
| if ( | |
| generation_config._pad_token_tensor is not None | |
| and batch_size > 1 | |
| and len(inputs_tensor.shape) == 2 | |
| and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0 | |
| ): | |
| logger.warning( | |
| "A decoder-only architecture is being used, but right-padding was detected! For correct " | |
| "generation results, please set `padding_side='left'` when initializing the tokenizer." | |
| ) | |
| # 4. Define other model kwargs | |
| # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are | |
| # generating the first new token or not, and we only want to use the embeddings for the first new token) | |
| if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": | |
| generation_config.use_cache = True | |
| if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: | |
| model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( | |
| inputs_tensor, generation_config, model_kwargs | |
| ) | |
| elif kwargs_has_attention_mask: | |
| # TODO (joao): generalize this check with other types of inputs | |
| if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2: | |
| raise ValueError("`attention_mask` passed to `generate` must be 2D.") | |
| if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: | |
| # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` | |
| model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( | |
| inputs_tensor, model_kwargs, model_input_name, generation_config | |
| ) | |
| # 5. Prepare `input_ids` which will be used for auto-regressive generation | |
| if self.config.is_encoder_decoder: | |
| input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( | |
| batch_size=batch_size, | |
| model_input_name=model_input_name, | |
| model_kwargs=model_kwargs, | |
| decoder_start_token_id=generation_config._decoder_start_token_tensor, | |
| device=inputs_tensor.device, | |
| ) | |
| else: | |
| input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") | |
| if generation_config.token_healing: | |
| input_ids = self.heal_tokens(input_ids, tokenizer) | |
| if streamer is not None: | |
| streamer.put(input_ids.cpu()) | |
| # 6. Prepare `max_length` depending on other stopping criteria. | |
| input_ids_length = input_ids.shape[1] | |
| has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None | |
| has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None | |
| generation_config = self._prepare_generated_length( | |
| generation_config=generation_config, | |
| has_default_max_length=has_default_max_length, | |
| has_default_min_length=has_default_min_length, | |
| model_input_name=model_input_name, | |
| inputs_tensor=inputs_tensor, | |
| input_ids_length=input_ids_length, | |
| ) | |
| # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole | |
| # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding | |
| # dynamically overrides this value as it can need more than the last token logits | |
| if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs: | |
| model_kwargs["logits_to_keep"] = 1 | |
| self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) | |
| # 7. Prepare the cache. | |
| # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. | |
| # - different models have a different cache name expected by the model (default = "past_key_values") | |
| # - `max_length`, prepared above, is used to determine the maximum cache length | |
| max_cache_length = generation_config.max_length - 1 | |
| if ( | |
| inputs_tensor.shape[1] != input_ids_length | |
| and model_input_name == "inputs_embeds" | |
| and not self.config.is_encoder_decoder | |
| ): | |
| max_cache_length += inputs_tensor.shape[1] | |
| self._prepare_cache_for_generation( | |
| generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device | |
| ) | |
| # 8. determine generation mode | |
| generation_mode = generation_config.get_generation_mode(assistant_model) | |
| if streamer is not None and (generation_config.num_beams > 1): | |
| raise ValueError( | |
| "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." | |
| ) | |
| if self.device.type != input_ids.device.type: | |
| warnings.warn( | |
| "You are calling .generate() with the `input_ids` being on a device type different" | |
| f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" | |
| f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." | |
| " Please make sure that you have put `input_ids` to the" | |
| f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" | |
| " running `.generate()`.", | |
| UserWarning, | |
| ) | |
| # 9. prepare logits processors and stopping criteria | |
| prepared_logits_processor = self._get_logits_processor( | |
| generation_config=generation_config, | |
| input_ids_seq_length=input_ids_length, | |
| encoder_input_ids=inputs_tensor, | |
| prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, | |
| logits_processor=logits_processor, | |
| device=inputs_tensor.device, | |
| model_kwargs=model_kwargs, | |
| negative_prompt_ids=negative_prompt_ids, | |
| negative_prompt_attention_mask=negative_prompt_attention_mask, | |
| ) | |
| prepared_stopping_criteria = self._get_stopping_criteria( | |
| generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs | |
| ) | |
| # Set model_kwargs `use_cache` so we can use it later in forward runs | |
| model_kwargs["use_cache"] = generation_config.use_cache | |
| # 10. go into different generation modes | |
| if generation_mode == GenerationMode.ASSISTED_GENERATION: | |
| if generation_config.num_return_sequences > 1: | |
| raise ValueError( | |
| "num_return_sequences has to be 1 when doing assisted generate, " | |
| f"but is {generation_config.num_return_sequences}." | |
| ) | |
| if batch_size > 1: | |
| raise ValueError("assisted generate is only supported for batch_size = 1") | |
| if not model_kwargs["use_cache"]: | |
| raise ValueError("assisted generate requires `use_cache=True`") | |
| if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]: | |
| raise ValueError("assisted generate is not supported with Static cache classes`") | |
| if self._is_stateful: | |
| # In assisted generation we need the ability to confirm whether the model would pick certain tokens, | |
| # which is not possible with stateful models (they can't reset to a previous subset of generated text) | |
| raise ValueError( | |
| f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" | |
| ) | |
| # 11. Get the candidate generator, given the parameterization | |
| candidate_generator = self._get_candidate_generator( | |
| generation_config=generation_config, | |
| input_ids=input_ids, | |
| inputs_tensor=inputs_tensor, | |
| assistant_model=assistant_model, | |
| logits_processor=logits_processor, | |
| target_tokenizer=tokenizer, | |
| assistant_tokenizer=assistant_tokenizer, | |
| model_kwargs=model_kwargs, | |
| ) | |
| # 12. run assisted generate | |
| result = self._assisted_decoding( | |
| input_ids, | |
| candidate_generator=candidate_generator, | |
| logits_processor=prepared_logits_processor, | |
| stopping_criteria=prepared_stopping_criteria, | |
| generation_config=generation_config, | |
| synced_gpus=synced_gpus, | |
| streamer=streamer, | |
| **model_kwargs, | |
| ) | |
| elif generation_mode == GenerationMode.DOLA_GENERATION: | |
| if not trust_remote_code: | |
| logger.warning_once( | |
| "DoLa Decoding is scheduled to be moved to a `custom_generate` repository in v4.55.0. " | |
| "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call." | |
| ) | |
| if self._is_stateful: | |
| # DoLa decoding was not designed for stateful models, and would require some changes | |
| raise ValueError( | |
| f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}" | |
| ) | |
| result = self._dola_decoding( | |
| input_ids, | |
| dola_layers=generation_config.dola_layers, | |
| logits_processor=prepared_logits_processor, | |
| stopping_criteria=prepared_stopping_criteria, | |
| generation_config=generation_config, | |
| synced_gpus=synced_gpus, | |
| streamer=streamer, | |
| **model_kwargs, | |
| ) | |
| elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH: | |
| if not trust_remote_code: | |
| logger.warning_once( | |
| "Contrastive Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. " | |
| "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call." | |
| ) | |
| if not model_kwargs["use_cache"]: | |
| raise ValueError("Contrastive search requires `use_cache=True`") | |
| if self._is_stateful: | |
| # Just like assisted generation, we need to be able to rollback to a previous state (see comment above) | |
| raise ValueError( | |
| f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" | |
| ) | |
| result = self._contrastive_search( | |
| input_ids, | |
| logits_processor=prepared_logits_processor, | |
| stopping_criteria=prepared_stopping_criteria, | |
| generation_config=generation_config, | |
| synced_gpus=synced_gpus, | |
| streamer=streamer, | |
| **model_kwargs, | |
| ) | |
| elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): | |
| # 11. expand input_ids with `num_return_sequences` additional sequences per batch | |
| input_ids, model_kwargs = self._expand_inputs_for_generation( | |
| input_ids=input_ids, | |
| expand_size=generation_config.num_return_sequences, | |
| is_encoder_decoder=self.config.is_encoder_decoder, | |
| **model_kwargs, | |
| ) | |
| # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) | |
| result = self._sample( | |
| input_ids, | |
| logits_processor=prepared_logits_processor, | |
| stopping_criteria=prepared_stopping_criteria, | |
| generation_config=generation_config, | |
| synced_gpus=synced_gpus, | |
| streamer=streamer, | |
| **model_kwargs, | |
| ) | |
| elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): | |
| # 11. interleave input_ids with `num_beams` additional sequences per batch | |
| input_ids, model_kwargs = self._expand_inputs_for_generation( | |
| input_ids=input_ids, | |
| expand_size=generation_config.num_beams, | |
| is_encoder_decoder=self.config.is_encoder_decoder, | |
| **model_kwargs, | |
| ) | |
| # 12. run beam sample | |
| result = self._beam_search( | |
| input_ids, | |
| logits_processor=prepared_logits_processor, | |
| stopping_criteria=prepared_stopping_criteria, | |
| generation_config=generation_config, | |
| synced_gpus=synced_gpus, | |
| **model_kwargs, | |
| ) | |
| elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH: | |
| logger.warning_once( | |
| "Group Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. " | |
| "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call." | |
| ) | |
| # 11. prepare beam search scorer | |
| beam_scorer = BeamSearchScorer( | |
| batch_size=batch_size, | |
| num_beams=generation_config.num_beams, | |
| device=inputs_tensor.device, | |
| length_penalty=generation_config.length_penalty, | |
| do_early_stopping=generation_config.early_stopping, | |
| num_beam_hyps_to_keep=generation_config.num_return_sequences, | |
| num_beam_groups=generation_config.num_beam_groups, | |
| max_length=generation_config.max_length, | |
| ) | |
| # 12. interleave input_ids with `num_beams` additional sequences per batch | |
| input_ids, model_kwargs = self._expand_inputs_for_generation( | |
| input_ids=input_ids, | |
| expand_size=generation_config.num_beams, | |
| is_encoder_decoder=self.config.is_encoder_decoder, | |
| **model_kwargs, | |
| ) | |
| # 13. run beam search | |
| result = self._group_beam_search( | |
| input_ids, | |
| beam_scorer, | |
| logits_processor=prepared_logits_processor, | |
| stopping_criteria=prepared_stopping_criteria, | |
| generation_config=generation_config, | |
| synced_gpus=synced_gpus, | |
| **model_kwargs, | |
| ) | |
| elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH: | |
| logger.warning_once( | |
| "Constrained Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. " | |
| "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call." | |
| ) | |
| final_constraints = [] | |
| if generation_config.constraints is not None: | |
| final_constraints = generation_config.constraints | |
| if generation_config.force_words_ids is not None: | |
| def typeerror(): | |
| raise ValueError( | |
| "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]` " | |
| f"of positive integers, but is {generation_config.force_words_ids}." | |
| ) | |
| if ( | |
| not isinstance(generation_config.force_words_ids, list) | |
| or len(generation_config.force_words_ids) == 0 | |
| ): | |
| typeerror() | |
| for word_ids in generation_config.force_words_ids: | |
| if isinstance(word_ids[0], list): | |
| if not isinstance(word_ids, list) or len(word_ids) == 0: | |
| typeerror() | |
| if any(not isinstance(token_ids, list) for token_ids in word_ids): | |
| typeerror() | |
| if any( | |
| any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) | |
| for token_ids in word_ids | |
| ): | |
| typeerror() | |
| constraint = DisjunctiveConstraint(word_ids) | |
| else: | |
| if not isinstance(word_ids, list) or len(word_ids) == 0: | |
| typeerror() | |
| if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): | |
| typeerror() | |
| constraint = PhrasalConstraint(word_ids) | |
| final_constraints.append(constraint) | |
| # 11. prepare beam search scorer | |
| constrained_beam_scorer = ConstrainedBeamSearchScorer( | |
| constraints=final_constraints, | |
| batch_size=batch_size, | |
| num_beams=generation_config.num_beams, | |
| device=inputs_tensor.device, | |
| length_penalty=generation_config.length_penalty, | |
| do_early_stopping=generation_config.early_stopping, | |
| num_beam_hyps_to_keep=generation_config.num_return_sequences, | |
| max_length=generation_config.max_length, | |
| ) | |
| # 12. interleave input_ids with `num_beams` additional sequences per batch | |
| input_ids, model_kwargs = self._expand_inputs_for_generation( | |
| input_ids=input_ids, | |
| expand_size=generation_config.num_beams, | |
| is_encoder_decoder=self.config.is_encoder_decoder, | |
| **model_kwargs, | |
| ) | |
| # 13. run beam search | |
| result = self._constrained_beam_search( | |
| input_ids, | |
| constrained_beam_scorer=constrained_beam_scorer, | |
| logits_processor=prepared_logits_processor, | |
| stopping_criteria=prepared_stopping_criteria, | |
| generation_config=generation_config, | |
| synced_gpus=synced_gpus, | |
| **model_kwargs, | |
| ) | |
| # Convert to legacy cache format if requested | |
| if ( | |
| generation_config.return_legacy_cache is True | |
| and hasattr(result, "past_key_values") | |
| and getattr(result.past_key_values, "to_legacy_cache") is not None | |
| ): | |
| result.past_key_values = result.past_key_values.to_legacy_cache() | |
| return result | |
| def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: | |
| """ | |
| Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is | |
| fed through `this_peer_finished`. ZeRO stage 3-friendly. | |
| """ | |
| if synced_gpus: | |
| # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. | |
| # The following logic allows an early break if all peers finished generating their sequence | |
| this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device) | |
| # send 0.0 if we finished, 1.0 otherwise | |
| dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) | |
| # did all peers finish? the reduced sum will be 0.0 then | |
| if this_peer_finished_flag.item() == 0.0: | |
| return False | |
| elif this_peer_finished: | |
| return False | |
| return True | |
def heal_tokens(
    self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
) -> torch.LongTensor:
    r"""
    Applies token healing to a batch of prompts.

    The last token of each prompt is removed and a single new token is generated in its place. The generation is
    biased towards vocabulary tokens that extend the removed token, with the original token slightly favored. This
    lets a prompt that ends mid-token be completed with a better-tokenized continuation.

    Note that the prompts are first decoded (skipping special tokens), stripped and re-tokenized, and any BOS token
    is replaced by the pad token. The returned ids therefore need not match `input_ids` position-for-position.

    Parameters:
        input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.
        tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.

    Return:
        `torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension.

    Raises:
        ValueError: if `tokenizer` is `None`.
    """
    if tokenizer is None:
        raise ValueError(
            " When generating with token healing, you must pass the model's tokenizer to the `tokenizer` "
            "argument of `generate`."
        )

    bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id
    vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
    # exactly one new token is generated per prompt: the replacement for the removed tail token
    generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id)

    # assumption: leading/trailing whitespace is not meaningful, so the prompts are
    # stripped before re-tokenizing to desensitize generation to whitespace artefacts
    prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
    input_ids = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
    ).input_ids.to(input_ids.device)

    # replace bos with pad to not condition healing on it
    input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)

    # the code below indexes the last column, so an empty tensor is returned as-is
    if input_ids.numel() == 0:
        return input_ids

    tail_ids = input_ids[:, -1].tolist()
    space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
    # tail tokens are used for a prefix search, thus, whitespaces are replaced with
    # their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
    tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)

    for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
        batch_ids = input_ids[batch_idx]
        if torch.all(batch_ids == pad_token_id).item():
            continue  # skip empty sequences (all pad ids)

        # apply bias for alternatives (extensions) to the tail token; `sequence_bias` keys are tuples of
        # token ids, so the string tokens returned by the trie are converted back to ids
        seq_bias = {
            (tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
        }
        if len(seq_bias) <= 1:
            continue  # skip if there are no token alternatives to heal with

        # slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'.
        # The tail id is not guaranteed to appear among the extensions (e.g. if decoding and re-tokenizing the
        # tail does not round-trip), so it is seeded with the base bias rather than assuming the key exists.
        seq_bias[(tail_id,)] = seq_bias.get((tail_id,), 10.0) + 1.0
        generation_config.update(sequence_bias=seq_bias)

        trimmed_ids = batch_ids[:-1]
        # with a single column there is nothing left to condition on once the tail is removed
        if trimmed_ids.numel() == 0:
            continue

        # if the prompt is a single (non-pad) token, regenerate from bos
        if len(batch_ids[batch_ids != pad_token_id]) == 1:
            trimmed_ids[-1] = bos_token_id

        input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config)

    return input_ids
| def _dola_decoding( | |
| self, | |
| input_ids: torch.LongTensor, | |
| dola_layers: Union[str, list[int]], | |
| logits_processor: LogitsProcessorList, | |
| stopping_criteria: StoppingCriteriaList, | |
| generation_config: GenerationConfig, | |
| synced_gpus: bool, | |
| streamer: "BaseStreamer", | |
| **model_kwargs, | |
| ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: | |
| r""" | |
| Generates sequences of token ids for models with a language modeling head using **dola decoding** and can be | |
| used for decoder-only text models. | |
| The method is based on the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language | |
| Models" (https://huggingface.co/papers/2309.03883) in ICLR 2024. | |
| Parameters: | |
| input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | |
| The sequence used as a prompt for the generation. | |
| dola_layers (`Union[str, list[int]]`): | |
| The candidate layers used in contrasting layers of DoLa. It can be either 1) 'low' or 'high', which | |
| means the lower part or higher part of the model layers, respectively, or 2) a list of layer indices | |
| to be used for candidate layers. The 0-th layer is the word embedding layer of the model. | |
| logits_processor (`LogitsProcessorList`): | |
| An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] | |
| used to modify the prediction scores of the language modeling head applied at each generation step. | |
| stopping_criteria (`StoppingCriteriaList`, *optional*): | |
| An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] | |
| used to tell if the generation loop should stop. | |
| generation_config ([`~generation.GenerationConfig`]): | |
| The generation configuration to be used as parametrization of the decoding method. | |
| synced_gpus (`bool`): | |
| Whether to continue running the while loop until max_length (needed to avoid deadlocking with | |
| `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). | |
| streamer (`BaseStreamer`, *optional*): | |
| Streamer object that will be used to stream the generated sequences. Generated tokens are passed | |
| through `streamer.put(token_ids)` and the streamer is responsible for any further processing. | |
| model_kwargs: | |
| Additional model specific keyword arguments will be forwarded to the `forward` function of the model. | |
| If model is an encoder-decoder model the kwargs should include `encoder_outputs`. | |
| Return: | |
| [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] | |
| or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a | |
| [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and | |
| `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if | |
| `model.config.is_encoder_decoder=True`. | |
| """ | |
| if self.config.is_encoder_decoder: | |
| raise ValueError("DoLa decoding is only available for decoder-only models.") | |
| # init values | |
| pad_token_id = generation_config._pad_token_tensor | |
| output_attentions = generation_config.output_attentions | |
| output_hidden_states = generation_config.output_hidden_states | |
| output_scores = generation_config.output_scores | |
| output_logits = generation_config.output_logits | |
| return_dict_in_generate = generation_config.return_dict_in_generate | |
| has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) | |
| do_sample = generation_config.do_sample | |
| # init attention / hidden states / scores tuples | |
| scores = () if (return_dict_in_generate and output_scores) else None | |
| raw_logits = () if (return_dict_in_generate and output_logits) else None | |
| decoder_attentions = () if (return_dict_in_generate and output_attentions) else None | |
| cross_attentions = () if (return_dict_in_generate and output_attentions) else None | |
| decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None | |
| # keep track of which sequences are already finished | |
| batch_size, cur_length = input_ids.shape[:2] | |
| unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) | |
| model_kwargs = self._get_initial_cache_position(cur_length, input_ids.device, model_kwargs) | |
| this_peer_finished = False | |
| # prepare layers for DoLa decoding | |
| final_layer = self.config.get_text_config().num_hidden_layers | |
| # if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer, | |
| # as the early exit from word embeddings will become identity function | |
| # if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th | |
| # layer otherwise. Notice that DoLa does not help shallow models much. | |
| if not self.config.tie_word_embeddings: | |
| start_layer = 0 | |
| elif final_layer > 2: | |
| start_layer = 2 | |
| elif final_layer == 2: | |
| start_layer = 1 | |
| else: | |
| start_layer = 0 | |
| # For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)` | |
| # are used for `'low'` and `'high'` layers, respectively. | |
| # For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for | |
| # `'low'` and `'high'` layers, respectively. | |
| if isinstance(dola_layers, str) and dola_layers == "low": | |
| if start_layer == final_layer // 2: | |
| candidate_premature_layers = [start_layer] | |
| else: | |
| candidate_premature_layers = ( | |
| list(range(start_layer, final_layer // 2, 2)) | |
| if final_layer <= 40 | |
| else list(range(start_layer, 20, 2)) | |
| ) | |
| elif isinstance(dola_layers, str) and dola_layers == "high": | |
| candidate_premature_layers = ( | |
| list(range(final_layer // 2, final_layer, 2)) | |
| if final_layer <= 40 | |
| else list(range(final_layer - 20, final_layer, 2)) | |
| ) | |
| # Set the `dola_layers` to a list of integers for layer indices to contrast manually specified layers. | |
| elif isinstance(dola_layers, list): | |
| candidate_premature_layers = [i for i in dola_layers if i < final_layer] | |
| else: | |
| raise ValueError("dola_layers must be either 'low', 'high' or a list of integers.") | |
| lm_head = self.get_output_embeddings() | |
| if lm_head is None: | |
| raise ValueError("DoLa is not supported for models that don't have output embeddings.") | |
| while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): | |
| # prepare model inputs | |
| model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) | |
| # forward pass to get next token | |
| outputs = self( | |
| **model_inputs, | |
| return_dict=True, | |
| output_attentions=output_attentions, | |
| output_hidden_states=True, | |
| ) | |
| # .float() is needed to retain precision for later logits manipulations | |
| final_layer_next_token_logits = outputs.logits[:, -1, :].detach().to(copy=True, dtype=torch.float32) | |
| final_logits = outputs.logits[:, -1, :].float() | |
| candidate_premature_logits = {} | |
| for candidate_premature_layer in candidate_premature_layers: | |
| candidate_premature_logits[candidate_premature_layer] = lm_head( | |
| outputs.hidden_states[candidate_premature_layer][:, -1, :] | |
| ).to(final_logits.device) | |
| # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping | |
| model_kwargs = self._update_model_kwargs_for_generation( | |
| outputs, | |
| model_kwargs, | |
| is_encoder_decoder=self.config.is_encoder_decoder, | |
| ) | |
| if synced_gpus and this_peer_finished: | |
| continue | |
| next_token_logits = _dola_select_contrast( | |
| candidate_premature_layers, candidate_premature_logits, final_logits | |
| ) | |
| next_token_logits = next_token_logits.to(input_ids.device) | |
| # pre-process distribution | |
| next_token_scores = logits_processor(input_ids, next_token_logits) | |
| # Store scores, attentions and hidden_states when required | |
| if return_dict_in_generate: | |
| if output_scores: | |
| scores += (next_token_scores,) | |
| if output_logits: | |
| raw_logits += (final_layer_next_token_logits,) | |
| if output_attentions: | |
| decoder_attentions += ( | |
| (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) | |
| ) | |
| if self.config.is_encoder_decoder: | |
| cross_attentions += (outputs.cross_attentions,) | |
| if output_hidden_states: | |
| decoder_hidden_states += ( | |
| (outputs.decoder_hidden_states,) | |
| if self.config.is_encoder_decoder | |
| else (outputs.hidden_states,) | |
| ) | |
| if do_sample: # sample | |
| probs = nn.functional.softmax(next_token_scores, dim=-1) | |
| next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) | |
| else: # argmax | |
| next_tokens = torch.argmax(next_token_scores, dim=-1) | |
| # finished sentences should have their next token be a padding token | |
| if has_eos_stopping_criteria: | |
| next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) | |
| # update generated ids, model inputs, and length for next step | |
| input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) | |
| if streamer is not None: | |
| streamer.put(next_tokens.cpu()) | |
| # stop when each sentence is finished | |
| unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) | |
| this_peer_finished = unfinished_sequences.max() == 0 | |
| if streamer is not None: | |
| streamer.end() | |
| if return_dict_in_generate: | |
| return GenerateDecoderOnlyOutput( | |
| sequences=input_ids, | |
| scores=scores, | |
| logits=raw_logits, | |
| attentions=decoder_attentions, | |
| hidden_states=decoder_hidden_states, | |
| past_key_values=model_kwargs.get("past_key_values"), | |
| ) | |
| else: | |
| return input_ids | |
@torch.no_grad()
def _contrastive_search(
    self,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **contrastive search** and can
    be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Every decoding step has two phases:
      1. candidate recall: the `generation_config.top_k` most probable tokens are taken from the processed
         next-token distribution;
      2. candidate re-rank: each candidate is scored by `_ranking_fast` from the model confidence and a
         degeneration penalty (cosine similarity of the candidate's last hidden state with the previous ones),
         controlled by `generation_config.penalty_alpha`.
    Because the candidates are run through the model before one is chosen, the loop "looks ahead" by one token.
    That extra position is cropped from the returned cache at the end.

    When `generation_config.low_memory` is set, the `top_k` candidates are evaluated one at a time ("sequential"
    mode) instead of as one expanded batch of `batch_size * top_k` rows.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        model_kwargs:
            Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
            If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`]
        or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.

    Raises:
        ValueError: if, after the prefix forward pass, the model did not produce `past_key_values`, or
            `past_key_values[0]` is not a tuple/tensor whose first entry has a leading dimension of `batch_size`.
    """
    # init values
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    top_k = generation_config.top_k
    penalty_alpha = generation_config.penalty_alpha
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    # `low_memory` trades speed for memory: candidates are forwarded one by one instead of as a top_k-wide batch
    sequential = generation_config.low_memory

    # init attention / hidden states / scores tuples
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape[:2]
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

    # Create cosine_matrix_mask based on the attention_mask. It masks padded positions out of the degeneration
    # penalty and is repeated so that every one of the `top_k` candidate rows gets its own copy.
    cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
    if self.config.is_encoder_decoder:
        if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None:
            cosine_matrix_mask = model_kwargs["decoder_attention_mask"]
    else:
        cosine_matrix_mask = model_kwargs["attention_mask"]
    cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0)

    this_peer_finished = False

    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
        # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
        if model_kwargs.get("past_key_values") is None or (
            isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache))
            and model_kwargs["past_key_values"].get_seq_length() == 0
        ):
            # prepare inputs
            model_kwargs["use_cache"] = True
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
            # the `encoder_outputs`
            outputs = self(
                **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
            )

            # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
            # previous tokens)
            if self.config.is_encoder_decoder:
                last_hidden_states = outputs.decoder_hidden_states[-1]
            else:
                last_hidden_states = outputs.hidden_states[-1]

            # next logit for contrastive search to select top-k candidate tokens
            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
            # (the clone itself is always small)
            # torch.float32 is needed to retain precision for later logits manipulations
            logit_for_next_step = outputs.logits[:, -1, :].to(
                copy=True, dtype=torch.float32, device=input_ids.device
            )

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            if not sequential:
                # Expands model inputs top_k times, for batched forward passes (akin to beam search).
                # input_ids is required for expanding visual inputs in qwen2vl
                _, model_kwargs = self._expand_inputs_for_generation(
                    input_ids=input_ids,
                    expand_size=top_k,
                    is_encoder_decoder=self.config.is_encoder_decoder,
                    **model_kwargs,
                )

            past_key_values = model_kwargs.get("past_key_values")
            if past_key_values is None:
                raise ValueError(
                    f"{self.__class__.__name__} does not support caching and therefore **can't** be used "
                    "for contrastive search."
                )
            elif (
                not isinstance(past_key_values[0], (tuple, torch.Tensor))
                or past_key_values[0][0].shape[0] != batch_size
            ):
                raise ValueError(
                    f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
                    "used for contrastive search without further modifications."
                )

        # contrastive_search main logic start:
        # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
        # degeneration penalty
        processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
        next_probs = nn.functional.softmax(processed_logit_for_next_step, dim=-1)

        # top_k_probs / top_k_ids: (batch_size, top_k)
        top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_logits:
                raw_logits += (logit_for_next_step,)
            if output_scores:
                scores += (processed_logit_for_next_step,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # This is needed to properly delete outputs.logits which may be very large for this first iteration
        # Otherwise a reference to outputs.logits is kept all along until after the next call to self.forward()
        del outputs

        if not sequential:
            # Replicates the new past_key_values to match the `top_k` candidates
            past = model_kwargs["past_key_values"]
            # If it is a static cache, modify it in-place layer after layer to save memory
            if isinstance(past, DynamicCache) or (
                isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
            ):
                past.batch_repeat_interleave(top_k)
            else:
                new_key_values = []
                for layer in past:
                    items = []
                    # item is either the key or the value matrix
                    for item in layer:
                        items.append(item.repeat_interleave(top_k, dim=0))
                    new_key_values.append(tuple(items))

                past = tuple(new_key_values)

            model_kwargs["past_key_values"] = past

        if sequential:
            all_outputs = []
            for i in range(top_k):
                # compute the candidate tokens by the language model and collect their hidden_states
                next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)

                outputs = self(
                    **next_model_inputs,
                    return_dict=True,
                    output_hidden_states=True,
                    output_attentions=output_attentions,
                )
                if isinstance(outputs["past_key_values"], DynamicCache) or (
                    isinstance(outputs["past_key_values"], EncoderDecoderCache)
                    and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
                ):
                    # Remove past K-V from output since we don't need to stack later
                    outputs["past_key_values"] = None
                    # Remove last token from past K-V since we don't want to append it at this point
                    model_kwargs["past_key_values"].crop(-1)

                all_outputs.append(outputs)
            outputs = stack_model_outputs(all_outputs, self.config.get_text_config())

        else:
            # compute the candidate tokens by the language model and collect their hidden_states
            # assembles top_k_ids into batch of size k
            next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)

            outputs = self(
                **next_model_inputs,
                return_dict=True,
                output_hidden_states=True,
                output_attentions=output_attentions,
            )

        # This is essential to avoid having a last reference to the big past K-V and double the necessary memory
        # in the next loop
        del next_model_inputs

        # name is different for encoder-decoder and decoder-only models
        if self.config.is_encoder_decoder:
            next_hidden = outputs.decoder_hidden_states[-1]
            full_hidden_states = outputs.decoder_hidden_states
        else:
            next_hidden = outputs.hidden_states[-1]
            full_hidden_states = outputs.hidden_states

        # .float() is needed to retain precision for later logits manipulations
        logits = outputs.logits[:, -1, :].float()
        context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

        # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
        # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't
        # introduce (noticeable) slowdowns on single-device runs.
        selected_idx = _ranking_fast(
            context_hidden, next_hidden, top_k_probs, cosine_matrix_mask, penalty_alpha, top_k
        )
        # the chosen candidate becomes part of the context, so the mask grows by one (unmasked) column
        cosine_matrix_mask = torch.cat(
            [cosine_matrix_mask, cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1))], dim=-1
        )
        selected_idx = selected_idx.to("cpu")

        # Row indices of the selected candidates inside the flattened (batch_size * top_k) batch.
        # This will be used instead of the previous inefficient torch.stack(torch.split())
        augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)])

        # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
        # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
        # (model confidence minus degeneration penalty); (6) decoder hidden_states
        # next_tokens: one chosen token per batch row, shape (batch_size,)
        next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
        next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
        next_hidden = next_hidden[range(batch_size), selected_idx, :]
        last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)

        next_decoder_hidden_states = ()
        for layer in full_hidden_states:
            layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
            next_decoder_hidden_states += (layer,)

        # generate past_key_values cache of only the selected token
        if sequential:
            # Feed exactly one selected token per batch row. `top_k_ids[:, selected_idx]` would build a
            # (batch_size, batch_size) matrix and feed batch_size**2 tokens against a cache holding batch_size
            # rows, which is wrong as soon as batch_size > 1; `next_tokens` already holds the per-row choice.
            next_model_input = self.prepare_inputs_for_generation(next_tokens.view(-1, 1), **model_kwargs)

            selected_outputs = self(
                **next_model_input,
                return_dict=True,
                output_hidden_states=False,
                output_attentions=False,
            )
            next_past_key_values = selected_outputs["past_key_values"]

        else:
            # the cache may be exposed under different attribute names depending on the output class
            next_past_key_values = None
            for possible_cache_name in ALL_CACHE_NAMES:
                next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
            # Do it in-place layer per layer to save memory
            if isinstance(next_past_key_values, DynamicCache) or (
                isinstance(next_past_key_values, EncoderDecoderCache)
                and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
            ):
                next_past_key_values.batch_select_indices(augmented_idx)
            else:
                new_key_values = []
                for layer in next_past_key_values:
                    items = []
                    # item is either the key or the value matrix
                    for item in layer:
                        items.append(item[augmented_idx, ...])
                    new_key_values.append(tuple(items))

                next_past_key_values = tuple(new_key_values)

        logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
        logit_for_next_step = logit_for_next_step.to(input_ids.device)

        # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
        if self.config.is_encoder_decoder:
            next_step_cross_attentions = ()
            next_step_decoder_attentions = ()
            if output_attentions:
                for layer in outputs.cross_attentions:
                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                    next_step_cross_attentions += (layer,)
                for layer in outputs.decoder_attentions:
                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                    next_step_decoder_attentions += (layer,)
            outputs = Seq2SeqLMOutput(
                past_key_values=next_past_key_values,
                decoder_hidden_states=next_decoder_hidden_states,
                decoder_attentions=next_step_decoder_attentions or None,
                cross_attentions=next_step_cross_attentions or None,
            )
        else:
            next_step_attentions = ()
            if output_attentions:
                for layer in outputs.attentions:
                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                    next_step_attentions += (layer,)
            outputs = CausalLMOutputWithPast(
                past_key_values=next_past_key_values,
                hidden_states=next_decoder_hidden_states,
                attentions=next_step_attentions or None,
            )
        # contrastive_search main logic end

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            continue

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        # stop when each sentence is finished
        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        # Contrastive search works by forward looking at the next token, so we need to exclude it from
        # `past_key_values` to be consistent with the other decoding methods
        if model_kwargs.get("past_key_values") is not None:
            if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
                isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
                and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
            ):
                model_kwargs["past_key_values"].crop(-1)
            else:
                past_key_values = []
                for layer in model_kwargs["past_key_values"]:
                    layer_past_key_values = []
                    for item in layer:
                        layer_past_key_values.append(item[..., :-1, :])
                    past_key_values.append(tuple(layer_past_key_values))
                model_kwargs["past_key_values"] = tuple(past_key_values)

        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids
def _sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
    can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Selection mode follows `generation_config.do_sample`:
      - `True`: the next token is drawn with `torch.multinomial` from the softmax of the processed scores;
      - `False`: the next token is the argmax of the processed scores (greedy).

    Execution details:
      - When `_valid_auto_compile_criteria` allows it, decoding steps after the first use a compiled forward from
        `get_compiled_call`. The prefill step always runs eagerly unless it was already consumed by chunked
        prefill (`generation_config.prefill_chunk_size`).
      - Once a row is finished, it keeps receiving `pad_token_id` if any stopping criterion exposes
        `eos_token_id`.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
        A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.

    Side effects:
        When the compiled path is taken, sets the process-wide `TOKENIZERS_PARALLELISM` environment variable to
        `"0"`. It may also set or modify `generation_config.compile_config` (Flash Attention 2 + compileable cache).
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape[:2]
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

    model_forward = self.__call__
    compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
    if compile_forward:
        os.environ["TOKENIZERS_PARALLELISM"] = "0"
        # If we use FA2 and a static cache, we cannot compile with fullgraph
        if self.config._attn_implementation == "flash_attention_2" and getattr(
            model_kwargs.get("past_key_values"), "is_compileable", False
        ):
            if generation_config.compile_config is None:
                generation_config.compile_config = CompileConfig(fullgraph=False)
            # only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user)
            elif generation_config.compile_config.fullgraph:
                logger.warning_once(
                    "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
                    "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
                )
                generation_config.compile_config.fullgraph = False
        model_forward = self.get_compiled_call(generation_config.compile_config)

    # With chunked prefill the prompt has already been consumed here, so every loop iteration is a decode step.
    if generation_config.prefill_chunk_size is not None:
        model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
        is_prefill = False
    else:
        is_prefill = True

    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        # the (variable-length) prefill always runs eagerly; only subsequent decode steps use `model_forward`
        if is_prefill:
            outputs = self(**model_inputs, return_dict=True)
            is_prefill = False
        else:
            outputs = model_forward(**model_inputs, return_dict=True)

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            continue

        # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
        # (the clone itself is always small)
        next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # token selection
        if do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0

        # This is needed to properly delete outputs.logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        del outputs

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids
| # Auxiliary functions for beam search | |
| def _temporary_reorder_cache(self, past_key_values, beam_idx): | |
| """ | |
| Temporary function to handle the different types of cache reordering processes while we roll out `Cache`. | |
| TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need | |
| for this function, with `Cache.reorder_cache` being the sole remaining code path | |
| """ | |
| model_class = self.__class__.__name__.lower() | |
| # Exception 1: code path for models using the legacy cache format | |
| if isinstance(past_key_values, (tuple, list)): | |
| past_key_values = self._reorder_cache(past_key_values, beam_idx) | |
| # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their | |
| # cache format is standardized, to avoid adding complexity to the codebase. | |
| elif "gptbigcode" in model_class: | |
| if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)): | |
| raise ValueError( | |
| f"Using an unsupported cache format with {model_class}. Currently, it only supports the " | |
| "legacy tuple format or `DynamicCache`" | |
| ) | |
| past_key_values = self._reorder_cache(past_key_values, beam_idx) | |
| past_key_values = DynamicCache.from_legacy_cache(past_key_values) | |
| # Standard code path: use the `Cache.reorder_cache` | |
| else: | |
| past_key_values.reorder_cache(beam_idx) | |
| return past_key_values | |
| @staticmethod | |
| def _flatten_beam_dim(tensor: torch.Tensor) -> torch.Tensor: | |
| """[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]""" | |
| shape = list(tensor.shape) | |
| return torch.reshape(tensor, [shape[0] * shape[1]] + shape[2:]) | |
| @staticmethod | |
| def _unflatten_beam_dim(tensor: torch.Tensor, batch_size: int, num_beams: int) -> torch.Tensor: | |
| """[batch_size * num_beams, ...] -> [batch_size, num_beams, ...]""" | |
| shape = list(tensor.shape) | |
| return torch.reshape(tensor, [batch_size, num_beams] + shape[1:]) | |
| @staticmethod | |
| def _gather_beams(tensor: torch.Tensor, beam_indices: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Gathers the beam slices indexed by beam_indices into new beam array. | |
| Args: | |
| tensor (`torch.Tensor`): A tensor containing data to be gathered. The tensor is a 2D or a 3D tensor | |
| with the two first dimensions depicting the batch and the beam dimensions. | |
| beam_indices (`torch.Tensor` of shape `(batch_size, num_beams_to_select)`): The indices of the beams to | |
| select . | |
| Returns: | |
| A tensor with the selected beams | |
| """ | |
| # `take_along_dim` requires its indices arg to have the same number of dims as `input` | |
| while len(beam_indices.shape) < len(tensor.shape): | |
| beam_indices = beam_indices.unsqueeze(-1) | |
| gathered_tensor = torch.take_along_dim(input=tensor, indices=beam_indices, dim=1) | |
| return gathered_tensor | |
| @staticmethod | |
| def _beam_search_has_unfinished_sequences( | |
| running_beam_scores: torch.Tensor, | |
| beam_scores: torch.Tensor, | |
| is_sent_finished: torch.Tensor, | |
| next_token_hits_stopping_criteria: torch.Tensor, | |
| cur_len: int, | |
| max_length: int, | |
| decoder_prompt_len: int, | |
| early_stopping: Union[bool, str], | |
| length_penalty: float, | |
| ): | |
| """ | |
| Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False | |
| """ | |
| # a. Can the open beams improve the top completed scores? | |
| # early_stopping == False -> apply heuristic = always get the best score from | |
| # `cur_len - decoder_prompt_len`. See the discussion below for more details. | |
| # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565 | |
| # early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the | |
| # sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use | |
| # `max_length` there. | |
| if early_stopping == "never" and length_penalty > 0.0: | |
| best_hypothetical_length = max_length - decoder_prompt_len | |
| else: | |
| best_hypothetical_length = cur_len - decoder_prompt_len | |
| best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty) | |
| worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9) | |
| improvement_possible = torch.any(best_possible_running_score > worst_finished_score) | |
| # b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is | |
| # enabled, where we want to finish as soon as all beams have a completed sequence. | |
| exists_open_beam = ~(torch.all(is_sent_finished) & (early_stopping is True)) | |
| # c. Have we hit a stopping criteria with all running sequences and have no way to continue? e.g. we have | |
| # reached `max_length`` | |
| valid_continuations = ~torch.all(next_token_hits_stopping_criteria) | |
| return improvement_possible & exists_open_beam & valid_continuations | |
| def _get_top_k_continuations( | |
| self, | |
| accumulated_log_probs: torch.Tensor, | |
| running_sequences: torch.Tensor, | |
| running_beam_indices: torch.Tensor, | |
| cur_len: int, | |
| decoder_prompt_len: int, | |
| do_sample: bool, | |
| beams_to_keep: int, | |
| num_beams: int, | |
| vocab_size: int, | |
| batch_size: int, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """ | |
| Get top-K continuations given the accumulated log probs on the next token. | |
| A few notes to understand what's going on: | |
| 1. Each item in batch has `num_beams` * `vocab_size` candidate continuations. For each item, get the | |
| top K [K = (number of EOS tokens + 1) * `num_beams`] candidates with the highest accumulated | |
| log-probabilities, or sample them without replacement using the accumulated scores | |
| 2. We gather the top K (as opposed to `num_beams`, or any number lower than K) here so that we have at | |
| least `num_beams` sequences remaining to continue the live beam search. | |
| 3. Note that other stopping criteria might result in impossible to continue beams, i.e. all continuations | |
| selected in this step hit the stopping criteria. | |
| """ | |
| # TODO (joao): This function should take an optional beam scorer function, to manipulate the scores after | |
| # token selection. The function should be an argument exposed, so that custom scoring functions can be | |
| # defined. | |
| # Gather the top K scores from _all_ beams. | |
| if do_sample: | |
| topk_indices = torch.multinomial( | |
| nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep | |
| ) | |
| topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices) | |
| else: | |
| topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep) | |
| # Gather K top beams, recover the beam index by floor division and token id by modulo division | |
| topk_current_beam_indices = topk_indices // vocab_size | |
| topk_running_beam_indices = self._gather_beams(running_beam_indices, topk_current_beam_indices) | |
| topk_running_sequences = self._gather_beams(running_sequences, topk_current_beam_indices) | |
| topk_ids = topk_indices % vocab_size | |
| # Update sequences for the K top-k new sequences. | |
| topk_running_sequences[:, :, cur_len] = topk_ids | |
| # we want to store the beam indices with batch information -> real beam index = beam index % num beams | |
| batch_offset = torch.arange(batch_size, device=topk_ids.device).view(-1, 1) * num_beams | |
| batch_modified_indices = topk_current_beam_indices + batch_offset | |
| topk_running_beam_indices[:, :, cur_len - decoder_prompt_len] = batch_modified_indices | |
| return topk_log_probs, topk_running_sequences, topk_running_beam_indices | |
| def _get_running_beams_for_next_iteration( | |
| self, | |
| topk_log_probs: torch.Tensor, | |
| topk_running_sequences: torch.Tensor, | |
| topk_running_beam_indices: torch.Tensor, | |
| next_token_hits_stopping_criteria: torch.Tensor, | |
| num_beams: int, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """ | |
| Given the top-K continuations, their scores, and whether they hit a stopping criteria, select the | |
| best non-finished beams to continue beam search in the next iteration. | |
| """ | |
| # To prevent these just finished sequences from being used in subsequent iterations, set their log probs | |
| # to a very large negative value | |
| topk_running_log_probs = topk_log_probs + next_token_hits_stopping_criteria.to(torch.float32) * -1.0e9 | |
| next_topk_indices = torch.topk(topk_running_log_probs, k=num_beams)[1] | |
| running_sequences = self._gather_beams(topk_running_sequences, next_topk_indices) | |
| running_beam_scores = self._gather_beams(topk_running_log_probs, next_topk_indices) | |
| running_beam_indices = self._gather_beams(topk_running_beam_indices, next_topk_indices) | |
| return running_sequences, running_beam_scores, running_beam_indices | |
| def _update_finished_beams( | |
| self, | |
| sequences: torch.Tensor, | |
| topk_running_sequences: torch.Tensor, | |
| beam_scores: torch.Tensor, | |
| topk_log_probs: torch.Tensor, | |
| beam_indices: torch.Tensor, | |
| topk_running_beam_indices: torch.Tensor, | |
| is_sent_finished: torch.Tensor, | |
| next_token_hits_stopping_criteria: torch.Tensor, | |
| top_num_beam_mask: torch.Tensor, | |
| num_beams: int, | |
| cur_len: int, | |
| decoder_prompt_len: int, | |
| length_penalty: float, | |
| early_stopping: Union[bool, str], | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """ | |
| Updates the finished beams if (and only if) there are new completed sequences that have a higher score than | |
| the current finished sequences. | |
| """ | |
| # Only the top `num_beam` sequences can be considered for the final returned sequences. Remember: the | |
| # remaining sequences only exist as a backup to ensure that we have at least `num_beams` sequences to | |
| # continue. | |
| did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :] | |
| # Further process topk logits for the finished beams | |
| # - add length penalty | |
| topk_log_probs = topk_log_probs / ((cur_len + 1 - decoder_prompt_len) ** length_penalty) | |
| # - make sure no scores can be added anymore if beam is full and early stopping is on | |
| beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True) | |
| topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9 | |
| # - make sure still running sequences cannot be chosen as finalized beam | |
| topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9 | |
| # Get finalized `num_beam` sequences for the next generation step -- combine the previous finalized | |
| # data with the new finalized sequences (if any, non-finalized sequences have a very large negative score | |
| # in this step), and keep the best `num_beams` sequences. | |
| merged_sequences = torch.cat((sequences, topk_running_sequences), dim=1) | |
| merged_scores = torch.cat((beam_scores, topk_log_probs), dim=1) | |
| merged_beam_indices = torch.cat((beam_indices, topk_running_beam_indices), dim=1) | |
| merged_is_sent_finished = torch.cat((is_sent_finished, did_top_num_beams_just_finished), dim=1) | |
| topk_merged_indices = torch.topk(merged_scores, k=num_beams)[1] | |
| sequences = self._gather_beams(merged_sequences, topk_merged_indices) | |
| beam_scores = self._gather_beams(merged_scores, topk_merged_indices) | |
| beam_indices = self._gather_beams(merged_beam_indices, topk_merged_indices) | |
| is_sent_finished = self._gather_beams(merged_is_sent_finished, topk_merged_indices) | |
| return sequences, beam_scores, beam_indices, is_sent_finished | |
| # end of auxiliary functions for beam search | |
def _beam_search(
    self,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    **model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
    can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    If it's the first time you're diving into Beam Search, we recommend you read the following blog post:
    https://huggingface.co/blog/how-to-generate (especially the beam search section).

    You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function
    (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores)

    Returned sequences shorter than the longest one are padded with `pad_token_id`. If no pad token is defined,
    the first EOS token is used; if neither is defined, -1 is used.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
        `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.

    Raises:
        ValueError: if `generation_config.low_memory` is set (not supported by this implementation).
    """
    # 1. init beam_search values
    pad_token_id = generation_config._pad_token_tensor
    eos_token_id = generation_config._eos_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    do_sample = generation_config.do_sample
    early_stopping = generation_config.early_stopping
    length_penalty = generation_config.length_penalty
    max_length = generation_config.max_length
    num_beams = generation_config.num_beams
    num_return_sequences = generation_config.num_return_sequences

    batch_size_unflattened, cur_len = input_ids.shape[:2]
    batch_size = batch_size_unflattened // num_beams
    # TODO (joao): standardize special cases
    if self.__class__.__name__ == "MoshiDepthDecoder":
        vocab_size = self.config.audio_vocab_size
    elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
        vocab_size = self.get_output_embeddings().out_features
    else:
        vocab_size = self.config.get_text_config().vocab_size
    decoder_prompt_len = cur_len
    this_peer_finished = False

    # At each beam search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates
    # with the highest log-probabilities, or sample K continuations without replacement. We gather the top K
    # (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences
    # non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token.
    n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
    beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams
    # True for the first `num_beams` of the K candidates: only those may become final outputs.
    top_num_beam_mask = torch.cat(
        (torch.ones((num_beams), dtype=torch.bool), torch.zeros((beams_to_keep - num_beams), dtype=torch.bool)),
        dim=0,
    ).to(input_ids.device)

    model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

    # (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there
    # are newer low-memory alternatives like the offloaded cache)
    sequential = generation_config.low_memory
    if sequential:
        raise ValueError(
            "`low_memory=True` is not supported after the beam search refactor. Please check the discussion in "
            "#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered."
        )

    # 2. init output tuples (`beam_indices` is a static-shaped tensor initialised in step 3)
    all_scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # 3. init running tensors and static-shaped placeholders

    # Fill value for not-yet-generated positions; shorter returned sequences end up padded with it.
    # Explicit `is not None` checks:
    # - a pad id of 0 is valid, but a 0-d tensor holding 0 is falsy, so `or` would skip it;
    # - a defined pad id must be used even when no EOS token exists.
    if pad_token_id is not None:
        output_fill_value = pad_token_id
    elif eos_token_id is not None:
        output_fill_value = eos_token_id[0]
    else:
        output_fill_value = -1

    # per batch, beam-item holding current token in loop and completed sequences
    running_sequences = torch.full(
        (batch_size, num_beams, max_length),
        fill_value=output_fill_value,
        dtype=torch.int64,
        device=input_ids.device,
    )
    running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams)
    sequences = running_sequences.detach().clone()

    # per batch, beam-item score, logprobs
    # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
    # of the first beam are considered to avoid sampling the exact same tokens across all beams.
    running_beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    running_beam_scores[:, 1:] = -1e9
    beam_scores = torch.full((batch_size, num_beams), fill_value=-1e9, dtype=torch.float, device=input_ids.device)

    # per batch, beam-item state bit indicating if sentence has finished.
    is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)

    # per batch, beam-item state bit indicating if there are valid continuations.
    next_token_hits_stopping_criteria = torch.zeros(
        (batch_size, num_beams), dtype=torch.bool, device=input_ids.device
    )

    # per batch selected beam indices (-1 marks positions that were never generated)
    running_beam_indices = torch.full(
        (batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=torch.int32, device=input_ids.device
    )
    beam_indices = running_beam_indices.detach().clone()

    # 4. run the generation loop
    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        # a. Forward current tokens, obtain the logits
        flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
        model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)

        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        model_outputs = self(**model_inputs, return_dict=True)

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = self._update_model_kwargs_for_generation(
            model_outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            continue

        # Copy is needed to avoid keeping a hanging ref
        logits = model_outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)

        # b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
        # `temperature`, ...), and add new logprobs to existing running logprobs scores.
        log_probs = nn.functional.log_softmax(logits, dim=-1)
        log_probs = logits_processor(flat_running_sequences, log_probs)

        # Store logits, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_logits:
                raw_logits += (logits.clone(),)
            if output_scores:
                all_scores += (log_probs.clone(),)

            if output_attentions:
                decoder_attentions += (
                    (model_outputs.decoder_attentions,)
                    if self.config.is_encoder_decoder
                    else (model_outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (model_outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (model_outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (model_outputs.hidden_states,)
                )

        # This is needed to properly delete logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        del model_outputs

        log_probs = self._unflatten_beam_dim(log_probs, batch_size, num_beams)
        log_probs = log_probs + running_beam_scores[:, :, None]
        log_probs = torch.reshape(log_probs, (batch_size, num_beams * vocab_size))

        # c. Retrieve top-K continuations, i.e. select the next token (greedy or sampling) and then keep the best
        # continuations among all beams based on the accumulated scores.
        topk_log_probs, topk_running_sequences, topk_running_beam_indices = self._get_top_k_continuations(
            accumulated_log_probs=log_probs,
            running_sequences=running_sequences,
            running_beam_indices=running_beam_indices,
            cur_len=cur_len,
            decoder_prompt_len=decoder_prompt_len,
            do_sample=do_sample,
            beams_to_keep=beams_to_keep,
            num_beams=num_beams,
            vocab_size=vocab_size,
            batch_size=batch_size,
        )

        # d. Check which running sequences have finished
        next_token_hits_stopping_criteria = stopping_criteria(
            self._flatten_beam_dim(topk_running_sequences[:, :, : cur_len + 1]),  # remove unfilled token indexes
            all_scores,
        )
        next_token_hits_stopping_criteria = self._unflatten_beam_dim(
            next_token_hits_stopping_criteria, batch_size, beams_to_keep
        )

        # e. Get the non-finished running `num_beams` sequences for the next generation step
        running_sequences, running_beam_scores, running_beam_indices = self._get_running_beams_for_next_iteration(
            topk_log_probs=topk_log_probs,
            topk_running_sequences=topk_running_sequences,
            topk_running_beam_indices=topk_running_beam_indices,
            next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
            num_beams=num_beams,
        )

        # f. Update the completed beams if a new high score in a finished sequence is found
        sequences, beam_scores, beam_indices, is_sent_finished = self._update_finished_beams(
            sequences=sequences,
            topk_running_sequences=topk_running_sequences,
            beam_scores=beam_scores,
            topk_log_probs=topk_log_probs,
            beam_indices=beam_indices,
            topk_running_beam_indices=topk_running_beam_indices,
            is_sent_finished=is_sent_finished,
            next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
            top_num_beam_mask=top_num_beam_mask,
            num_beams=num_beams,
            cur_len=cur_len,
            decoder_prompt_len=decoder_prompt_len,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
        )

        # g. Prepare remaining data for the next iteration, including computing the stopping condition for
        # beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`)

        # pluck the cache from the beam indices that will be used in the next iteration
        if model_kwargs.get("past_key_values", None) is not None:
            model_kwargs["past_key_values"] = self._temporary_reorder_cache(
                past_key_values=model_kwargs["past_key_values"],
                beam_idx=self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]),
            )

        cur_len = cur_len + 1
        this_peer_finished = not self._beam_search_has_unfinished_sequences(
            running_beam_scores,
            beam_scores,
            is_sent_finished,
            next_token_hits_stopping_criteria,
            cur_len,
            max_length,
            decoder_prompt_len,
            early_stopping,
            length_penalty,
        )

    # 5. prepare outputs
    # Take best beams for each batch (the score is sorted in descending order)
    sequences = self._flatten_beam_dim(sequences[:, :num_return_sequences, :])
    beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
    beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])

    # Crop the static-shaped tensors to the actual size.
    # `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
    # step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
    # previous decoding iteration)
    max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
    output_length = decoder_prompt_len + max_generated_length
    sequences = sequences[:, :output_length]
    beam_indices = beam_indices[:, :max_generated_length]

    if return_dict_in_generate:
        if not output_scores:
            beam_scores = None

        if self.config.is_encoder_decoder:
            return GenerateBeamEncoderDecoderOutput(
                sequences=sequences,
                sequences_scores=beam_scores,
                scores=all_scores,
                logits=raw_logits,
                beam_indices=beam_indices,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateBeamDecoderOnlyOutput(
                sequences=sequences,
                sequences_scores=beam_scores,
                scores=all_scores,
                logits=raw_logits,
                beam_indices=beam_indices,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return sequences
| def _group_beam_search( | |
| self, | |
| input_ids: torch.LongTensor, | |
| beam_scorer: BeamScorer, | |
| logits_processor: LogitsProcessorList, | |
| stopping_criteria: StoppingCriteriaList, | |
| generation_config: GenerationConfig, | |
| synced_gpus: bool, | |
| **model_kwargs, | |
| ): | |
| r""" | |
| Generates sequences of token ids for models with a language modeling head using **diverse beam search | |
| decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. | |
| Parameters: | |
| input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`): | |
| The sequence used as a prompt for the generation. | |
| beam_scorer (`BeamScorer`): | |
| An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and | |
| sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. | |
| logits_processor (`LogitsProcessorList`): | |
| An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] | |
| used to modify the prediction scores of the language modeling head applied at each generation step. | |
| stopping_criteria (`StoppingCriteriaList`): | |
| An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] | |
| used to tell if the generation loop should stop. | |
| generation_config ([`~generation.GenerationConfig`]): | |
| The generation configuration to be used as parametrization of the decoding method. | |
| synced_gpus (`bool`): | |
| Whether to continue running the while loop until max_length (needed to avoid deadlocking with | |
| `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). | |
| model_kwargs: | |
| Additional model specific kwargs that will be forwarded to the `forward` function of the model. If | |
| model is an encoder-decoder model the kwargs should include `encoder_outputs`. | |
| Return: | |
| [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or | |
| `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a | |
| [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and | |
| `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if | |
| `model.config.is_encoder_decoder=True`. | |
| """ | |
| # init values | |
| pad_token_id = generation_config._pad_token_tensor | |
| eos_token_id = generation_config._eos_token_tensor | |
| output_attentions = generation_config.output_attentions | |
| output_hidden_states = generation_config.output_hidden_states | |
| output_scores = generation_config.output_scores | |
| output_logits = generation_config.output_logits | |
| return_dict_in_generate = generation_config.return_dict_in_generate | |
| num_beams = beam_scorer.num_beams | |
| num_beam_groups = beam_scorer.num_beam_groups | |
| num_sub_beams = num_beams // num_beam_groups | |
| batch_size = len(beam_scorer._beam_hyps) // num_beam_groups | |
| device = input_ids.device | |
| batch_beam_size, cur_len = input_ids.shape | |
| model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) | |
| if return_dict_in_generate and output_scores: | |
| beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] | |
| else: | |
| beam_indices = None | |
| if num_beams * batch_size != batch_beam_size: | |
| raise ValueError( | |
| f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." | |
| ) | |
| # init attention / hidden states / scores tuples | |
| scores = () if (return_dict_in_generate and output_scores) else None | |
| raw_logits = () if (return_dict_in_generate and output_logits) else None | |
| decoder_attentions = () if (return_dict_in_generate and output_attentions) else None | |
| cross_attentions = () if (return_dict_in_generate and output_attentions) else None | |
| decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None | |
| # if model is an encoder-decoder, retrieve encoder attention weights and hidden states | |
| if return_dict_in_generate and self.config.is_encoder_decoder: | |
| encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None | |
| encoder_hidden_states = ( | |
| model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None | |
| ) | |
| # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in | |
| # the same group don't produce same tokens every time. | |
| beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) | |
| beam_scores[:, ::num_sub_beams] = 0 | |
| beam_scores = beam_scores.view((batch_size * num_beams,)) | |
| this_peer_finished = False | |
| decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder | |
| while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): | |
| # predicted tokens in cur_len step | |
| current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) | |
| # indices which will form the beams in the next time step | |
| reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) | |
| # do one decoder step on all beams of all sentences in batch | |
| model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) | |
| # prepare variable output controls (note: some models won't accept all output controls) | |
| model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) | |
| model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) | |
| outputs = self(**model_inputs, return_dict=True) | |
| # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping | |
| model_kwargs = self._update_model_kwargs_for_generation( | |
| outputs, | |
| model_kwargs, | |
| is_encoder_decoder=self.config.is_encoder_decoder, | |
| ) | |
| if synced_gpus and this_peer_finished: | |
| cur_len = cur_len + 1 | |
| continue | |
| if output_scores: | |
| processed_score = torch.zeros_like(outputs.logits[:, -1, :]) | |
| if output_logits: | |
| # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration | |
| # (the clone itself is always small) | |
| raw_logit_score = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device) | |
| for beam_group_idx in range(num_beam_groups): | |
| group_start_idx = beam_group_idx * num_sub_beams | |
| group_end_idx = min(group_start_idx + num_sub_beams, num_beams) | |
| group_size = group_end_idx - group_start_idx | |
| # indices of beams of current group among all sentences in batch | |
| batch_group_indices = [] | |
| for batch_idx in range(batch_size): | |
| batch_group_indices.extend( | |
| [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] | |
| ) | |
| group_input_ids = input_ids[batch_group_indices] | |
| # select outputs of beams of current group only | |
| # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop | |
| # .float() is needed to retain precision for later logits manipulations | |
| next_token_logits = outputs.logits[batch_group_indices, -1, :].to( | |
| dtype=torch.float32, device=input_ids.device | |
| ) | |
| next_token_scores = nn.functional.log_softmax( | |
| next_token_logits, dim=-1 | |
| ) # (batch_size * group_size, vocab_size) | |
| vocab_size = next_token_scores.shape[-1] | |
| next_token_scores_processed = logits_processor( | |
| group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx | |
| ) | |
| next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) | |
| next_token_scores = next_token_scores.expand_as(next_token_scores_processed) | |
| if output_scores: | |
| processed_score[batch_group_indices] = next_token_scores_processed | |
| # reshape for beam search | |
| next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) | |
| # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. | |
| n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 | |
| next_token_scores, next_tokens = torch.topk( | |
| next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True | |
| ) | |
| next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") | |
| next_tokens = next_tokens % vocab_size | |
| # stateless | |
| process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None | |
| beam_outputs = beam_scorer.process( | |
| group_input_ids, | |
| next_token_scores, | |
| next_tokens, | |
| next_indices, | |
| pad_token_id=pad_token_id, | |
| eos_token_id=eos_token_id, | |
| beam_indices=process_beam_indices, | |
| group_index=beam_group_idx, | |
| decoder_prompt_len=decoder_prompt_len, | |
| ) | |
| beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] | |
| beam_next_tokens = beam_outputs["next_beam_tokens"] | |
| beam_idx = beam_outputs["next_beam_indices"] | |
| if return_dict_in_generate and output_scores: | |
| beam_indices[beam_group_idx] = tuple( | |
| beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) | |
| ) | |
| input_ids[batch_group_indices] = group_input_ids[beam_idx] | |
| group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) | |
| current_tokens[batch_group_indices] = group_input_ids[:, -1] | |
| # (beam_idx // group_size) -> batch_idx | |
| # (beam_idx % group_size) -> offset of idx inside the group | |
| reordering_indices[batch_group_indices] = ( | |
| num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") | |
| + group_start_idx | |
| + (beam_idx % group_size) | |
| ) | |
| # Store scores, attentions and hidden_states when required | |
| if return_dict_in_generate: | |
| if output_scores: | |
| scores += (processed_score,) | |
| if output_logits: | |
| raw_logits += (raw_logit_score,) | |
| if output_attentions: | |
| decoder_attentions += ( | |
| (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) | |
| ) | |
| if self.config.is_encoder_decoder: | |
| cross_attentions += (outputs.cross_attentions,) | |
| if output_hidden_states: | |
| decoder_hidden_states += ( | |
| (outputs.decoder_hidden_states,) | |
| if self.config.is_encoder_decoder | |
| else (outputs.hidden_states,) | |
| ) | |
| input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) | |
| # This is needed to properly delete outputs.logits which may be very large for first iteration | |
| # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration | |
| # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory | |
| # (that way the memory peak does not include outputs.logits) | |
| del outputs | |
| if model_kwargs.get("past_key_values", None) is not None: | |
| model_kwargs["past_key_values"] = self._temporary_reorder_cache( | |
| model_kwargs["past_key_values"], reordering_indices | |
| ) | |
| # increase cur_len | |
| cur_len = cur_len + 1 | |
| if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): | |
| this_peer_finished = True | |
| final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None | |
| sequence_outputs = beam_scorer.finalize( | |
| input_ids, | |
| beam_scores, | |
| next_tokens, | |
| next_indices, | |
| pad_token_id=pad_token_id, | |
| eos_token_id=eos_token_id, | |
| max_length=stopping_criteria.max_length, | |
| beam_indices=final_beam_indices, | |
| decoder_prompt_len=decoder_prompt_len, | |
| ) | |
| if return_dict_in_generate: | |
| if not output_scores: | |
| sequence_outputs["sequence_scores"] = None | |
| if self.config.is_encoder_decoder: | |
| return GenerateBeamEncoderDecoderOutput( | |
| sequences=sequence_outputs["sequences"], | |
| sequences_scores=sequence_outputs["sequence_scores"], | |
| scores=scores, | |
| logits=raw_logits, | |
| beam_indices=sequence_outputs["beam_indices"], | |
| encoder_attentions=encoder_attentions, | |
| encoder_hidden_states=encoder_hidden_states, | |
| decoder_attentions=decoder_attentions, | |
| cross_attentions=cross_attentions, | |
| decoder_hidden_states=decoder_hidden_states, | |
| past_key_values=model_kwargs.get("past_key_values"), | |
| ) | |
| else: | |
| return GenerateBeamDecoderOnlyOutput( | |
| sequences=sequence_outputs["sequences"], | |
| sequences_scores=sequence_outputs["sequence_scores"], | |
| scores=scores, | |
| logits=raw_logits, | |
| beam_indices=sequence_outputs["beam_indices"], | |
| attentions=decoder_attentions, | |
| hidden_states=decoder_hidden_states, | |
| past_key_values=model_kwargs.get("past_key_values"), | |
| ) | |
| else: | |
| return sequence_outputs["sequences"] | |
def _constrained_beam_search(
    self,
    input_ids: torch.LongTensor,
    constrained_beam_scorer: ConstrainedBeamSearchScorer,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    **model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **constrained beam search
    decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
            The sequence used as a prompt for the generation.
        constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
            A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
            sorted during generation, while satisfying a list of positive constraints. For more information, the
            documentation of [`ConstrainedBeamSearchScorer`] should be read.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        model_kwargs:
            Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
            model is an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
        `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.

    Raises:
        ValueError: If the batch dimension of `input_ids` differs from `num_beams * batch_size`, where
            `batch_size` is the number of beam hypotheses tracked by `constrained_beam_scorer`.
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    eos_token_id = generation_config._eos_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate

    batch_size = len(constrained_beam_scorer._beam_hyps)
    num_beams = constrained_beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape[:2]
    model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    # one (initially empty) tuple of historical beam indices per row of `input_ids`
    beam_indices = (
        tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
    )
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
    # of the first beam are considered to avoid sampling the exact same tokens across all beams.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False

    decoder_prompt_len = input_ids.shape[1]  # record the prompt length of decoder
    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        outputs = self(**model_inputs, return_dict=True)

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue

        # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
        # (the clone itself is always small)
        # .float() is needed to retain precision for later logits manipulations
        next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * num_beams, vocab_size)

        next_token_scores_processed = logits_processor(input_ids, next_token_scores)

        # accumulate each beam's running score onto every candidate continuation of that beam
        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
            next_token_scores_processed
        )

        # the constrained scorer needs the full per-vocab scores (before top-k) to search for constraint tokens
        scores_for_all_vocab = next_token_scores.clone()

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # reshape for beam search
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
        n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
        )

        # `next_tokens` are flat indices into `num_beams * vocab_size`; split them into (beam, token).
        # Integer floor division is used (as in group beam search) instead of float true-division: casting the
        # int64 indices to float32 loses exactness above 2**24 and can yield a wrong beam index for large
        # `num_beams * vocab_size`.
        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = constrained_beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            scores_for_all_vocab,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            beam_indices=beam_indices,
            decoder_prompt_len=decoder_prompt_len,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        # reorder rows to follow the selected beams, then append the chosen token of each
        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

        # This is needed to properly delete outputs.logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
        # (that way the memory peak does not include outputs.logits)
        del outputs

        if model_kwargs.get("past_key_values", None) is not None:
            model_kwargs["past_key_values"] = self._temporary_reorder_cache(
                model_kwargs["past_key_values"], beam_idx
            )

        if return_dict_in_generate and output_scores:
            beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))

        # increase cur_len
        cur_len = cur_len + 1

        if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
            this_peer_finished = True

    sequence_outputs = constrained_beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=beam_indices,
        decoder_prompt_len=decoder_prompt_len,
    )

    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None
        if self.config.is_encoder_decoder:
            return GenerateBeamEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                logits=raw_logits,
                beam_indices=sequence_outputs["beam_indices"],
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateBeamDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                logits=raw_logits,
                beam_indices=sequence_outputs["beam_indices"],
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return sequence_outputs["sequences"]
def _assisted_decoding(
    self,
    input_ids: torch.LongTensor,
    candidate_generator: CandidateGenerator,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
    **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a
    candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text
    models.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        candidate_generator (`CandidateGenerator`):
            A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
            more information, the documentation of [`CandidateGenerator`] should be read.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        model_kwargs:
            Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
            If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
        `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.
    """
    # init values
    do_sample = generation_config.do_sample
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape[:2]
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

    this_peer_finished = False
    is_first_iteration = True  # to preserve the same API in the output as other generation methods
    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        cur_len = input_ids.shape[1]

        # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
        candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
        candidate_input_ids = candidate_input_ids.to(self.device)
        if candidate_logits is not None:
            candidate_logits = candidate_logits.to(self.device)

        # number of tokens proposed by the candidate generator in this round
        candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
        is_done_candidate = stopping_criteria(candidate_input_ids, None)

        # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
        # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
        # we use this forward pass to also pick the subsequent logits in the original model.

        # 2.1. Prepare the model inputs
        # shallow copy so the candidate-specific mask/position updates don't leak into `model_kwargs`
        candidate_kwargs = copy.copy(model_kwargs)
        candidate_kwargs = _prepare_attention_mask(
            candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
        )
        candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
        if "cache_position" in candidate_kwargs:
            candidate_kwargs["cache_position"] = torch.cat(
                (
                    candidate_kwargs["cache_position"],
                    torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
                ),
                dim=0,
            )

        model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
        if "logits_to_keep" in model_inputs:
            # only the positions that score the candidates (plus one bonus position) are needed
            model_inputs["logits_to_keep"] = candidate_length + 1

        # 2.2. Run a forward pass on the candidate sequence
        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        # `return_dict=True` (as in the other decoding loops) guarantees attribute access on `outputs` below,
        # even when the model config sets `return_dict=False`.
        outputs = self(**model_inputs, return_dict=True)

        # 2.3. Process the new logits
        # .float() is needed to retain precision for later logits manipulations
        new_logits = outputs.logits[:, -candidate_length - 1 :].to(
            dtype=torch.float32, device=input_ids.device
        )  # excludes the input prompt if present
        # keep the unprocessed logits for `output_logits`; `new_logits` is processed in place below
        next_token_logits = new_logits.clone()
        if len(logits_processor) > 0:
            for i in range(candidate_length + 1):
                new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

        # 3. Select the accepted tokens. There are two possible cases:
        # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
        # -> Apply algorithm 1 from the speculative decoding paper (https://huggingface.co/papers/2211.17192).
        if do_sample and candidate_logits is not None:
            valid_tokens, n_matches = _speculative_sampling(
                candidate_input_ids,
                candidate_logits,
                candidate_length,
                new_logits,
                is_done_candidate,
            )

        # Case 2: all other cases (originally from assisted generation) -> Compare the tokens selected from the
        # original model logits with the candidate tokens. We can keep the candidate tokens until the first
        # mismatch, or until the max length is reached.
        else:
            if do_sample:
                probs = new_logits.softmax(dim=-1)
                # NOTE(review): indexes batch row 0 only, so this path assumes batch_size == 1
                selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
            else:
                selected_tokens = new_logits.argmax(dim=-1)

            candidate_new_tokens = candidate_input_ids[:, cur_len:]
            # count the leading run of positions where the candidate agrees with the model's choice
            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

            # Ensure we don't generate beyond max_len or an EOS token
            if is_done_candidate and n_matches == candidate_length:
                n_matches -= 1
            valid_tokens = selected_tokens[:, : n_matches + 1]

        # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
        # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
        # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
        # is no match.

        # 4.1. Get the valid continuation, after the matching tokens
        input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
        if streamer is not None:
            streamer.put(valid_tokens.cpu())
        new_cur_len = input_ids.shape[1]

        # 4.2. Discard past key values relative to unused assistant tokens
        new_cache_size = new_cur_len - 1
        outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)

        # 5. Update the candidate generation strategy if needed
        candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
            num_new_tokens=n_matches + 1,
        )
        if synced_gpus and this_peer_finished:
            continue

        # Store scores, attentions and hidden_states when required
        # Assistant: modified to append one tuple element per token, as in the other generation methods.
        if return_dict_in_generate:
            newly_added_length = n_matches + 1
            if output_scores:
                scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
            if output_logits:
                raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))

            # on the first round the attention/hidden-state outputs also cover the prompt positions
            newly_added_length = new_cur_len if is_first_iteration else newly_added_length
            if output_attentions:
                if self.config.is_encoder_decoder:
                    cross_attentions = _split_model_outputs(
                        cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
                    )
                    decoder_attentions = _split_model_outputs(
                        decoder_attentions,
                        outputs.decoder_attentions,
                        cur_len,
                        newly_added_length,
                        is_decoder_attention=True,
                    )
                # some (V)LLMs have hard requirement on SDPA and thus never return attn
                elif outputs.attentions[0] is not None:
                    decoder_attentions = _split_model_outputs(
                        decoder_attentions,
                        outputs.attentions,
                        cur_len,
                        newly_added_length,
                        is_decoder_attention=True,
                    )
            if output_hidden_states:
                if self.config.is_encoder_decoder:
                    decoder_hidden_states = _split_model_outputs(
                        decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
                    )
                else:
                    decoder_hidden_states = _split_model_outputs(
                        decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
                    )

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        is_first_iteration = False

    if streamer is not None:
        streamer.end()

    # persist the heuristically-adapted number of assistant tokens back onto the assistant's generation config
    if (
        hasattr(candidate_generator, "assistant_model")
        and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic"
    ):
        candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
            candidate_generator.num_assistant_tokens
        )
    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids
def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs):
    """
    Prefill the cache with the prompt in chunks of `generation_config.prefill_chunk_size` tokens.

    Every prompt token except the last one is run through the model here; the last token is left for the regular
    decoding loop, which receives the returned `model_kwargs`.

    Args:
        input_ids: prompt token ids; the last dimension is the sequence dimension.
        generation_config: provides `prefill_chunk_size` and `compile_config`.
        **model_kwargs: must contain `past_key_values`; `attention_mask` is optional.

    Returns:
        The updated `model_kwargs`: `past_key_values` holds the prefilled cache, `attention_mask` is the full mask
        again and `cache_position` points at the last prompt token.

    Raises:
        ValueError: if `past_key_values` is not in `model_kwargs`.
    """
    if "past_key_values" not in model_kwargs:
        raise ValueError("Cannot use prefill chunking without a cache")
    # Only the tokens before the last one are prefilled here, so a one-token prompt has nothing to prefill.
    # (`torch.split` on the empty slice would still return one zero-width chunk, which would be run through
    # the model and leave `cache_position` empty.)
    if input_ids.shape[-1] <= 1:
        return model_kwargs

    # Even if we are not compiling the forward, flex is always compiled when used. With chunk prefill, we may
    # end up needing just a bit more graphs than the default (which is 8). Doing this avoids very cryptic warnings.
    # The setting is process-global, so only ever raise it: never lower a larger value configured by the user.
    dynamo_config = torch._dynamo.config
    dynamo_config.cache_size_limit = max(dynamo_config.cache_size_limit, 64)

    chunk_size = generation_config.prefill_chunk_size
    # Only chunk up the token just before last, so that decoding is completely performed outside this function
    # (here we simply prefill the cache)
    input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)

    model_forward = self.forward
    compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
    if compile_forward:
        model_forward = self.get_compiled_call(generation_config.compile_config)

    attention_mask = model_kwargs.pop("attention_mask", None)

    past_length = 0
    for input_chunk in input_chunks:
        current_length = past_length + input_chunk.shape[-1]
        # Prepare inputs
        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask[:, :current_length]
        model_kwargs["cache_position"] = torch.arange(
            past_length, current_length, dtype=torch.long, device=input_chunk.device
        )
        model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
        model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)

        outputs = model_forward(**model_inputs, return_dict=True)

        model_kwargs["past_key_values"] = outputs.past_key_values
        past_length = current_length

    # Restore the full mask and advance cache_position to the last prompt token, which the decoding loop consumes.
    model_kwargs["attention_mask"] = attention_mask
    model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
    _ = model_kwargs.pop("position_ids", None)

    return model_kwargs
| def _speculative_sampling( | |
| candidate_input_ids, | |
| candidate_logits, | |
| candidate_length, | |
| new_logits, | |
| is_done_candidate, | |
| ): | |
| """ | |
| Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns | |
| the selected tokens, as well as the number of candidate matches. | |
| NOTE: Unless otherwise stated, the variable names match those in the paper. | |
| """ | |
| new_candidate_input_ids = candidate_input_ids[:, -candidate_length:] | |
| # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens | |
| # selected by the assistant, respectively. | |
| q = candidate_logits.softmax(dim=-1) | |
| q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1) | |
| p = new_logits.softmax(dim=-1) | |
| p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1) | |
| probability_ratio = p_i / q_i | |
| # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller | |
| # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio | |
| # (= keep with p = probability_ratio). Keep all the tokens until the first rejection | |
| r_i = torch.rand_like(probability_ratio) | |
| is_accepted = r_i <= probability_ratio | |
| n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1 | |
| # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior) | |
| if is_done_candidate and n_matches == candidate_length: | |
| # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model | |
| # due to acceptance on EOS we fix `n_matches` | |
| n_matches -= 1 | |
| valid_tokens = new_candidate_input_ids[:, : n_matches + 1] | |
| else: | |
| # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling. | |
| gamma = candidate_logits.shape[1] | |
| p_n_plus_1 = p[:, n_matches, :] | |
| if n_matches < gamma: | |
| q_n_plus_1 = q[:, n_matches, :] | |
| p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0) | |
| p_prime.div_(p_prime.sum()) | |
| else: | |
| p_prime = p_n_plus_1 | |
| t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :] | |
| # The selected tokens include the matches (if any) plus the next sampled tokens | |
| if n_matches > 0: | |
| valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1) | |
| else: | |
| valid_tokens = t | |
| return valid_tokens, n_matches | |
| def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False): | |
| """ | |
| Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple | |
| where each member corresponds to a single generated token. | |
| """ | |
| # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the | |
| # prompt. | |
| if len(outputs) == 0: | |
| new_tuple = () | |
| for layer in new_outputs: | |
| last_dim_size = cur_len if is_decoder_attention else layer.shape[-1] | |
| new_tuple += (layer[..., :cur_len, :last_dim_size],) | |
| outputs += (new_tuple,) | |
| # The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly | |
| cur_len += 1 | |
| added_len -= cur_len | |
| for i in range(added_len): | |
| new_tuple = () | |
| for layer in new_outputs: | |
| last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1] | |
| new_tuple += (layer[..., i : i + 1, :last_dim_size],) | |
| outputs += (new_tuple,) | |
| return outputs | |
| def _ranking_fast( | |
| context_hidden: torch.FloatTensor, | |
| next_hidden: torch.FloatTensor, | |
| next_top_k_probs: torch.FloatTensor, | |
| cosine_matrix_mask: torch.LongTensor, | |
| alpha: float, | |
| beam_width: int, | |
| ) -> torch.FloatTensor: | |
| """ | |
| Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described | |
| in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each | |
| row in the batch. | |
| """ | |
| norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) | |
| norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) | |
| cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] | |
| # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions) | |
| # Using a large negative value for masked positions | |
| cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype) | |
| cosine_matrix_mask = (1 - cosine_matrix_mask) * torch.finfo(cosine_matrix.dtype).min | |
| cosine_matrix = cosine_matrix + cosine_matrix_mask | |
| degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K] | |
| next_top_k_probs = next_top_k_probs.view(-1) # [B*K] | |
| contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty | |
| contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] | |
| _, selected_idx = contrastive_score.max(dim=-1) # [B] | |
| return selected_idx | |
| def _split(data, full_batch_size: int, split_size: int): | |
| """ | |
| Takes care of three cases: | |
| 1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim | |
| 2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and | |
| return a list of tuples | |
| 3. data is a tuple of tuples, e.g. past_key_values. Keep the tuple as it is and split each tuple in it and | |
| return a list of tuples of tuples | |
| (see documentation of ModelOutput) | |
| """ | |
| if data is None: | |
| return [None] * (full_batch_size // split_size) | |
| if isinstance(data, torch.Tensor): | |
| return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)] | |
| # New cache format | |
| elif isinstance(data, DynamicCache) or ( | |
| isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache) | |
| ): | |
| return data.batch_split(full_batch_size, split_size) | |
| elif isinstance(data, tuple): | |
| # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) | |
| if isinstance(data[0], tuple): | |
| return [ | |
| tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data) | |
| for i in range(0, full_batch_size, split_size) | |
| ] | |
| else: | |
| return [ | |
| tuple(sub_tensor[i : i + split_size] for sub_tensor in data) | |
| for i in range(0, full_batch_size, split_size) | |
| ] | |
| else: | |
| raise TypeError(f"Unexpected attribute type: {type(data)}") | |
def _split_model_inputs(
    model_input: Union[ModelOutput, dict], split_size: int, full_batch_size: int, config: PretrainedConfig
) -> list[Union[ModelOutput, dict]]:
    """
    Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split
    size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from
    previous forward pass.

    Split and replicated values:
    - Ordinary values are split along the batch dimension with `_split`.
    - bool values and `cache_position` are replicated into every chunk.
    - `logits_to_keep` is replicated into every chunk.
    - A nested `encoder_outputs` object is split recursively.

    Args:
        model_input: the dict / ModelOutput to split, or `None` (a list of `None`s is returned).
        split_size: number of batch rows per chunk.
        full_batch_size: total number of batch rows in `model_input`.
        config: model config; `config.get_text_config()` is passed on when splitting `encoder_outputs`.

    Returns:
        A list of `full_batch_size // split_size` objects of the same class as `model_input`.

    Raises:
        ValueError: if `split_size` is larger than `full_batch_size`, or does not divide it.
    """
    # Edge case: if model_input is None, return a list of Nones
    # this happens with Whisper where encoder_outputs is None
    if model_input is None:
        return [model_input] * (full_batch_size // split_size)

    # Check the size relation first: an oversized split never divides the batch either, and would otherwise be
    # reported with the less specific divisibility message.
    if split_size > full_batch_size:
        raise ValueError("`split_size` must be smaller or equal to `full_batch_size`")
    if (full_batch_size % split_size) != 0:
        raise ValueError("`full_batch_size` must be divisible by `split_size`")

    # Infer the class from the object
    model_output_cls = type(model_input)
    num_splits = full_batch_size // split_size

    # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them
    keys = (
        model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys()
    )
    # We only keep keys that are in the model_input
    keys = [k for k in keys if k in model_input]
    # Values are tensors, tuples of tensors, caches, booleans, or (encoder_outputs) a nested ModelOutput.
    # bool values (and cache_position) are not split but replicated for each split
    bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
    keys_to_ignore = ["cache_position", "encoder_outputs", "logits_to_keep"]
    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]

    # Split each value once, then regroup the pieces per chunk (splitting inside the per-chunk loop would redo
    # every split `num_splits` times).
    split_values = {k: _split(model_input[k], full_batch_size, split_size) for k in non_bool_keys}
    data_split_list = [{k: split_values[k][i] for k in non_bool_keys} for i in range(num_splits)]

    # bool values are the same and replicated for each split
    bool_data = {k: model_input[k] for k in bool_keys}

    # encoder_outputs is a ModelOutput object and should be split by its own
    if "encoder_outputs" in model_input:
        encoder_outputs_split = _split_model_inputs(
            model_input["encoder_outputs"], split_size, full_batch_size, config.get_text_config()
        )
        data_split_list = [
            {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
        ]

    # logits_to_keep should be replicated for each split, similar to bool values
    if "logits_to_keep" in model_input:
        data_split_list = [
            {**data_split, "logits_to_keep": model_input["logits_to_keep"]} for data_split in data_split_list
        ]

    # Convert each dictionary in the list to an object of the inferred class
    split_model_inputs: list[Union[ModelOutput, dict]] = [
        model_output_cls(**data_split, **bool_data) for data_split in data_split_list
    ]
    return split_model_inputs
def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput:
    """
    Concatenate ModelOutput objects (all of the same subclass) along the batch dimension; the inverse of splitting
    with `_split_model_inputs`.

    Each dataclass field is gathered from every output and combined as follows:
    - tensors are concatenated on dim 0;
    - `DynamicCache` / `EncoderDecoderCache` values are merged with `from_batch_splits`;
    - tuples, and tuples of tuples, are concatenated element-wise;
    - Python numbers become a tensor.
    A field that is `None` in any output stays `None`.

    Raises:
        ValueError: if `model_outputs` is empty or mixes types.
        TypeError: for a field whose type cannot be combined.
    """
    if not model_outputs:
        raise ValueError("Input list is empty.")

    # Infer the class from the first object in the list
    output_cls = type(model_outputs[0])
    if any(not isinstance(item, output_cls) for item in model_outputs):
        raise ValueError("All elements in the list should be of the same type.")

    def _concat(values):
        """Reverse of `_split` for one field gathered from every output."""
        if any(value is None for value in values):
            return None
        first = values[0]
        if isinstance(first, torch.Tensor):
            return torch.cat(values, dim=0)
        # New cache format
        if isinstance(first, DynamicCache):
            return DynamicCache.from_batch_splits(values)
        if isinstance(first, EncoderDecoderCache):
            return EncoderDecoderCache.from_batch_splits(values)
        if isinstance(first, tuple):
            # Tuples of tuples (e.g. legacy past_key_values) are concatenated at the innermost level.
            if isinstance(first[0], tuple):
                return tuple(
                    tuple(torch.cat([value[i][j] for value in values], dim=0) for j in range(len(first[0])))
                    for i in range(len(first))
                )
            return tuple(torch.cat([value[i] for value in values], dim=0) for i in range(len(first)))
        if isinstance(first, (int, float)):
            # If the elements are integers or floats, return a tensor
            return torch.tensor(values)
        raise TypeError(f"Unexpected attribute type: {type(first)}")

    combined = {
        field: _concat([getattr(output, field) for output in model_outputs])
        for field in output_cls.__dataclass_fields__
    }
    return output_cls(**combined)
| def _relative_top_filter( | |
| scores: torch.FloatTensor, | |
| baseline_scores: torch.FloatTensor, | |
| relative_top: float = 0.1, | |
| filter_value: float = -float("Inf"), | |
| base_filter_value=-1e-3, | |
| min_tokens_to_keep: int = 1, | |
| ) -> torch.FloatTensor: | |
| """ | |
| Reference: https://github.com/XiangLi1999/ContrastiveDecoding/blob/170e9142e92159c1237d731e240f5eb14aabf428/transformers/src/transformers/generation_logits_process.py#L235 | |
| Apply filtering to only keep tokens with a probability above a certain threshold. The threshold is defined as `relative_top` * max probability in the distribution. | |
| """ | |
| scores_normalized = scores.log_softmax(dim=-1) | |
| baseline_scores_normalized = baseline_scores.log_softmax(dim=-1) | |
| sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True) | |
| min_thresh = sorted_logits[..., min_tokens_to_keep - 1] | |
| probs_max = torch.max(scores_normalized, dim=-1).values | |
| probs_thresh = probs_max + np.log(relative_top) | |
| probs_thresh = torch.min(min_thresh, probs_thresh) | |
| probs_thresh = probs_thresh.unsqueeze(-1) | |
| baseline_scores_normalized[scores_normalized < probs_thresh] = base_filter_value | |
| scores_normalized[scores_normalized < probs_thresh] = filter_value | |
| return scores_normalized, baseline_scores_normalized | |
def _dola_select_contrast(
    candidate_premature_layers: list[int],
    candidate_premature_logits: dict[int, torch.FloatTensor],
    final_logits: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    DoLa contrast between the final layer and one premature layer.

    With a single candidate layer, that layer is used directly. Otherwise the function picks the layer whose softmax
    distribution has the largest Jensen-Shannon divergence from the final layer's, averaged over the batch. It then
    returns `final - premature`, both passed through `_relative_top_filter` first.
    """
    if len(candidate_premature_layers) == 1:
        premature_layer = candidate_premature_layers[0]
    else:
        # shape: (num_premature_layers, batch_size, vocab_size)
        stacked = torch.stack([candidate_premature_logits[i] for i in candidate_premature_layers], dim=0)

        mature_probs = F.softmax(final_logits, dim=-1)  # (batch_size, vocab_size)
        premature_probs = F.softmax(stacked, dim=-1)  # (num_premature_layers, batch_size, vocab_size)
        midpoint = 0.5 * (mature_probs[None, :, :] + premature_probs)

        mature_log_probs = F.log_softmax(final_logits, dim=-1)
        premature_log_probs = F.log_softmax(stacked, dim=-1)

        # KL of each distribution to the midpoint, averaged over the vocabulary: (num_premature_layers, batch_size)
        kl_mature = F.kl_div(mature_log_probs[None, :, :], midpoint, reduction="none").mean(-1)
        kl_premature = F.kl_div(premature_log_probs, midpoint, reduction="none").mean(-1)
        # JS divergence averaged over the batch: (num_premature_layers,)
        js_divs = (0.5 * (kl_mature + kl_premature)).mean(-1)
        premature_layer = candidate_premature_layers[int(js_divs.argmax().item())]

    base_logits = candidate_premature_logits[premature_layer]
    final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
    return final_logits - base_logits
| </script> | |
<body class="bg-gray-900 text-gray-200 flex flex-col h-screen">
  <header class="bg-gray-800/50 backdrop-blur-sm border-b border-gray-700 p-4 shadow-lg">
    <h1 class="text-2xl font-bold text-center text-white">Bloatedness Visualizer</h1>
    <p class="text-center text-gray-400 mt-1">Paste your model in the 'Main' tab and add dependencies in other tabs.</p>
  </header>
  <main class="flex-grow flex flex-col md:flex-row gap-4 p-4 overflow-hidden">
    <!-- Left Panel: Code Input & Controls -->
    <div class="md:w-1/3 flex flex-col h-full">
      <div class="flex-grow flex flex-col bg-gray-800 rounded-lg shadow-2xl border border-gray-700">
        <div class="p-4 border-b border-gray-700 flex justify-between items-center">
          <h2 class="text-lg font-semibold">Code Input</h2>
          <button id="visualize-btn" type="button" class="bg-blue-600 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded-md transition-colors duration-300">
            Visualize
          </button>
        </div>
        <div class="border-b border-gray-700 bg-gray-900/50 px-2 pt-2 flex items-center gap-2">
          <div id="tab-bar" class="flex gap-1">
            <!-- Tabs will be dynamically inserted here -->
          </div>
          <button id="add-tab-btn" type="button" class="ml-auto bg-gray-600 hover:bg-gray-500 text-white font-bold h-8 w-8 rounded-full transition-colors duration-200" aria-label="Add dependency tab">+</button>
        </div>
        <div id="code-inputs-container" class="flex-grow relative">
          <!-- Textareas will be dynamically inserted here -->
        </div>
      </div>
    </div>
    <!-- Right Panel: Visualization & Details -->
    <div class="md:w-2/3 flex flex-col gap-4 h-full">
      <div class="flex-grow bg-gray-800 rounded-lg shadow-2xl border border-gray-700 relative overflow-hidden">
        <!-- The graph-container class applies the stylesheet's min-height and grab/grabbing cursor rules. -->
        <div id="graph-container" class="graph-container w-full h-full"></div>
        <div id="loading-spinner" class="absolute inset-0 bg-gray-800/50 flex items-center justify-center hidden z-10" role="status" aria-label="Loading">
          <svg class="animate-spin h-10 w-10 text-white" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" aria-hidden="true">
            <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
            <path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
          </svg>
        </div>
      </div>
      <div id="details-panel" class="h-40 bg-gray-800 rounded-lg shadow-2xl border border-gray-700 p-4 flex flex-col">
        <h2 class="text-lg font-semibold border-b border-gray-700 pb-2 mb-2">Details</h2>
        <div id="details-content" class="text-gray-400 fira-code overflow-y-auto" aria-live="polite">
          <p>Click on a node in the graph to see its details. Scroll to zoom, drag to pan.</p>
        </div>
      </div>
    </div>
  </main>
| <script> | |
// --- DOM Element References ---
const visualizeBtn = document.getElementById('visualize-btn');
const graphContainer = document.getElementById('graph-container');
const detailsContent = document.getElementById('details-content');
const loadingSpinner = document.getElementById('loading-spinner');
const tabBar = document.getElementById('tab-bar');
const codeInputsContainer = document.getElementById('code-inputs-container');
const addTabBtn = document.getElementById('add-tab-btn');

// --- Example Code ---
/**
 * Reads the text of an example source embedded in the page as a <script type="text/plain"> block.
 * A missing block yields '' instead of throwing a TypeError that would abort the rest of this
 * script (tab setup, listeners and the initial render).
 * @param {string} elementId - id of the embedded <script> element.
 * @returns {string}
 */
function readEmbeddedCode(elementId) {
  const element = document.getElementById(elementId);
  return element ? element.textContent : '';
}

// The examples come from the embedded blocks. (The former fetch() calls for main_code.py and
// dependencies.py were removed: their results went into block-scoped constants that were
// immediately discarded, so they only cost two network requests.)
const exampleCodeMain = readEmbeddedCode('main-code');
const exampleCodeDeps = readEmbeddedCode('dependencies');
// --- Tab Management ---
// Monotonic counter used to build unique tab/textarea ids.
let tabCounter = 0;

/**
 * Creates a tab button in #tab-bar and its backing textarea in #code-inputs-container.
 * @param {string} name - Label shown on the tab button.
 * @param {string} [content=''] - Initial contents of the textarea.
 * @param {boolean} [isActive=false] - Whether to select the new tab immediately.
 */
function addTab(name, content = '', isActive = false) {
  tabCounter++;
  const tabId = `tab-${tabCounter}`;
  const textareaId = `textarea-${tabCounter}`;

  const tabButton = document.createElement('button');
  tabButton.type = 'button';
  tabButton.id = tabId;
  tabButton.className = 'tab px-4 py-2 text-sm font-medium rounded-t-md transition-colors duration-200';
  tabButton.textContent = name;
  tabButton.dataset.textareaId = textareaId;
  tabBar.appendChild(tabButton);

  const textarea = document.createElement('textarea');
  textarea.id = textareaId;
  textarea.className = 'fira-code w-full h-full p-4 bg-gray-900 text-gray-300 resize-none focus:outline-none absolute top-0 left-0';
  textarea.placeholder = `Paste dependency code here...`;
  // The textarea has no visible <label>; give it an accessible name matching its tab.
  textarea.setAttribute('aria-label', `${name} code`);
  textarea.value = content;
  codeInputsContainer.appendChild(textarea);

  tabButton.addEventListener('click', () => switchTab(tabId));

  if (isActive) {
    switchTab(tabId);
  } else {
    textarea.classList.add('hidden');
  }
}

/**
 * Marks the given tab as active and shows only its textarea.
 * @param {string} tabId - id of the tab button to activate.
 */
function switchTab(tabId) {
  // Resolve the target textarea once instead of once per textarea.
  const activeTextareaId = document.getElementById(tabId).dataset.textareaId;
  document.querySelectorAll('.tab').forEach(tab => {
    tab.classList.toggle('active', tab.id === tabId);
  });
  document.querySelectorAll('#code-inputs-container textarea').forEach(area => {
    area.classList.toggle('hidden', area.id !== activeTextareaId);
  });
}

// Select newly added tabs right away. Otherwise the new textarea stays hidden and the click
// appears to do nothing until the user finds and clicks the new tab.
addTabBtn.addEventListener('click', () => addTab(`Dep ${tabCounter}`, '', true));
// --- Core Logic: Parser ---
/**
 * Extracts classes, their base classes and their methods from Python source with a line-based scan.
 *
 * Produces `{ nodes, links }` for renderGraph:
 *  - class nodes `{ id, type: 'class', parents }`, plus `isExternal: true` for a base class that is
 *    referenced but never defined in the scanned code;
 *  - method nodes `{ id: 'Class.method', name, type: 'method', parentClass, signature }`;
 *  - links of type 'inheritance' (class -> base) and 'method' (class -> method), without duplicates.
 *
 * Supported syntax:
 *  - dotted and generic bases (`nn.Module`, `Generic[T, U]`);
 *  - keyword arguments in a base list (`metaclass=...`), which are ignored;
 *  - `async def` methods;
 *  - nested classes;
 *  - headers and signatures that span several lines.
 * This is a heuristic, not a Python parser: text inside multi-line strings is scanned too.
 *
 * @param {string} code - Python source.
 * @returns {{nodes: Object[], links: Object[]}}
 */
function parsePythonCode(code) {
  const nodes = [];
  const links = [];
  const nodeById = new Map();
  const linkKeys = new Set();
  // Enclosing classes, innermost last: { name, indentation }.
  const classStack = [];
  const lines = code.split('\n');

  // Returns the text between the '(' at lines[startLine][openCol] and its matching ')', following
  // brackets across lines, plus the index of the line where it closes. Brackets inside quotes and
  // text after '#' are ignored.
  function readParenthesized(startLine, openCol) {
    const MAX_HEADER_LINES = 200;
    const lastLine = Math.min(lines.length - 1, startLine + MAX_HEADER_LINES);
    const pieces = [];
    let depth = 1;
    for (let ln = startLine; ln <= lastLine; ln++) {
      const text = lines[ln];
      const from = ln === startLine ? openCol + 1 : 0;
      let end = text.length;
      let quote = null;
      for (let col = from; col < text.length; col++) {
        const ch = text[col];
        if (quote) {
          if (ch === '\\') col++;
          else if (ch === quote) quote = null;
          continue;
        }
        if (ch === '"' || ch === "'") {
          quote = ch;
        } else if (ch === '#') {
          end = col;
          break;
        } else if (ch === '(' || ch === '[' || ch === '{') {
          depth++;
        } else if (ch === ')' || ch === ']' || ch === '}') {
          depth--;
          if (depth === 0) {
            pieces.push(text.slice(from, col));
            return { inner: pieces.join(' ').replace(/\s+/g, ' ').trim(), endLine: ln };
          }
        }
      }
      pieces.push(text.slice(from, end));
    }
    // Unbalanced brackets: use the rest of the opening line and consume no further lines.
    return { inner: lines[startLine].slice(openCol + 1).trim(), endLine: startLine };
  }

  // Splits a base-class list on top-level commas. Keyword arguments (e.g. `metaclass=ABCMeta`) and
  // star-expansions are not base classes and are dropped.
  function parseBaseList(text) {
    const parts = [];
    let depth = 0;
    let start = 0;
    for (let k = 0; k < text.length; k++) {
      const ch = text[k];
      if (ch === '(' || ch === '[' || ch === '{') depth++;
      else if (ch === ')' || ch === ']' || ch === '}') depth--;
      else if (ch === ',' && depth === 0) {
        parts.push(text.slice(start, k));
        start = k + 1;
      }
    }
    parts.push(text.slice(start));
    return parts.map(p => p.trim()).filter(p => p && !p.includes('=') && !p.startsWith('*'));
  }

  // Adds a link unless an identical one already exists (e.g. a class pasted in two tabs).
  function addLink(source, target, type) {
    const key = JSON.stringify([type, source, target]);
    if (!linkKeys.has(key)) {
      linkKeys.add(key);
      links.push({ source, target, type });
    }
  }

  for (let i = 0; i < lines.length; i++) {
    const line = lines[i];
    const trimmed = line.trim();
    // Blank and comment-only lines neither open nor close a scope.
    if (trimmed.length === 0 || trimmed.startsWith('#')) continue;
    const indentation = line.match(/^\s*/)[0].length;

    // A statement at or left of a class's own indentation ends that class's body.
    while (classStack.length > 0 && indentation <= classStack[classStack.length - 1].indentation) {
      classStack.pop();
    }

    const classMatch = /^\s*class\s+(\w+)\s*([(:])/.exec(line);
    if (classMatch) {
      const className = classMatch[1];
      let parents = [];
      if (classMatch[2] === '(') {
        const header = readParenthesized(i, classMatch[0].length - 1);
        parents = parseBaseList(header.inner);
        i = header.endLine; // skip continuation lines of a multi-line base list
      }

      const existing = nodeById.get(className);
      if (!existing) {
        const classNode = { id: className, type: 'class', parents: parents };
        nodes.push(classNode);
        nodeById.set(className, classNode);
      } else if (existing.isExternal) {
        // First seen as another class's base; now we have its definition.
        existing.isExternal = false;
        existing.parents = parents;
      }

      parents.forEach(parent => {
        if (!nodeById.has(parent)) {
          const placeholder = { id: parent, type: 'class', isExternal: true, parents: [] };
          nodes.push(placeholder);
          nodeById.set(parent, placeholder);
        }
        addLink(className, parent, 'inheritance');
      });

      classStack.push({ name: className, indentation: indentation });
      continue;
    }

    const owner = classStack[classStack.length - 1];
    const methodMatch = owner ? /^\s+(?:async\s+)?def\s+(\w+)\s*\(/.exec(line) : null;
    if (methodMatch) {
      const header = readParenthesized(i, methodMatch[0].length - 1);
      i = header.endLine; // skip continuation lines of a multi-line signature
      const methodName = methodMatch[1];
      const methodId = `${owner.name}.${methodName}`;
      if (!nodeById.has(methodId)) {
        const methodNode = {
          id: methodId,
          name: methodName,
          type: 'method',
          parentClass: owner.name,
          signature: `(${header.inner})`,
        };
        nodes.push(methodNode);
        nodeById.set(methodId, methodNode);
        addLink(owner.name, methodId, 'method');
      }
    }
  }

  return { nodes, links };
}
// --- Core Logic: D3 Visualization ---
// The running force simulation, kept at module level so a re-render can stop the previous one.
let simulation;

/**
 * Draws the class/method graph into #graph-container, replacing any previous drawing.
 * Node appearance:
 *  - classes are radius-15 circles, blue if defined in the code and pink if `isExternal`;
 *  - methods are radius-8 grey circles.
 * @param {{nodes: Object[], links: Object[]}} data - Graph produced by parsePythonCode.
 */
function renderGraph(data) {
  // Stop the previous layout before discarding its DOM so its tick handler stops running.
  if (simulation) {
    simulation.stop();
  }
  graphContainer.innerHTML = '';

  if (data.nodes.length === 0) {
    graphContainer.innerHTML = '<p class="p-4 text-center text-gray-400">No classes found in the provided code.</p>';
    return;
  }

  // A collapsed container reports 0x0, which would give an empty viewBox; fall back to a usable canvas.
  const width = graphContainer.clientWidth || 800;
  const height = graphContainer.clientHeight || 500;

  const svg = d3.select(graphContainer).append("svg")
    .attr("viewBox", [-width / 2, -height / 2, width, height]);
  const container = svg.append("g");

  // Add zoom capabilities
  const zoom = d3.zoom()
    .scaleExtent([0.1, 4])
    .on("zoom", (event) => {
      container.attr("transform", event.transform);
    });
  svg.call(zoom);

  simulation = d3.forceSimulation(data.nodes)
    .force("link", d3.forceLink(data.links).id(d => d.id).distance(d => d.type === 'inheritance' ? 150 : 60).strength(0.5))
    .force("charge", d3.forceManyBody().strength(-400))
    .force("center", d3.forceCenter(0, 0))
    .force("x", d3.forceX())
    .force("y", d3.forceY());

  const link = container.append("g")
    .selectAll("line")
    .data(data.links)
    .join("line")
    .attr("class", d => `link ${d.type}`);

  const node = container.append("g")
    .selectAll("g")
    .data(data.nodes)
    .join("g")
    .attr("class", "node")
    .call(drag(simulation));

  node.append("circle")
    .attr("r", d => d.type === 'class' ? 15 : 8)
    .attr("fill", d => {
      if (d.type !== 'class') return '#9ca3af';
      return d.isExternal ? '#be185d' : '#2563eb';
    });

  node.append("text")
    .text(d => d.type === 'class' ? d.id : d.name)
    .attr("x", d => d.type === 'class' ? 18 : 12)
    .attr("y", 3)
    .attr("fill", "#e5e7eb");

  node.on("click", (event, d) => {
    event.stopPropagation(); // keep the background click handler below from clearing the selection
    updateDetailsPanel(d);
    node.classed("selected", n => n.id === d.id);
  });

  // Clicking empty canvas clears the highlight.
  svg.on("click", () => node.classed("selected", false));

  simulation.on("tick", () => {
    link.attr("x1", d => d.source.x).attr("y1", d => d.source.y)
      .attr("x2", d => d.target.x).attr("y2", d => d.target.y);
    node.attr("transform", d => `translate(${d.x},${d.y})`);
  });
}
// --- Interactivity ---
/**
 * Builds a d3 drag behaviour that pins a node to the pointer while it is dragged and releases it
 * when the drag ends, re-heating the simulation for the duration of the gesture.
 * @param {Object} simulation - The d3 force simulation the dragged nodes belong to.
 * @returns {Object} A configured d3.drag() behaviour.
 */
function drag(simulation) {
  // Fixes (or, with nulls, releases) the node's position in the layout.
  const pin = (d, x, y) => {
    d.fx = x;
    d.fy = y;
  };

  const onStart = (event, d) => {
    if (!event.active) simulation.alphaTarget(0.3).restart();
    pin(d, d.x, d.y);
  };
  const onDrag = (event, d) => pin(d, event.x, event.y);
  const onEnd = (event, d) => {
    if (!event.active) simulation.alphaTarget(0);
    pin(d, null, null);
  };

  return d3.drag()
    .on("start", onStart)
    .on("drag", onDrag)
    .on("end", onEnd);
}
// --- UI Updates ---
/**
 * Escapes a value for safe interpolation into HTML markup.
 * @param {*} value - Converted with String() before escaping.
 * @returns {string}
 */
function escapeHtml(value) {
  return String(value)
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;')
    .replace(/'/g, '&#39;');
}

/**
 * Shows the clicked node's details in #details-content.
 * Names, base classes and signatures come straight from user-pasted code, so every one of them is
 * escaped before going into innerHTML. Otherwise a default such as `bos_token="<s>"` would be
 * parsed as markup, and crafted input could inject script.
 * @param {Object} d - A class or method node produced by parsePythonCode.
 */
function updateDetailsPanel(d) {
  let content = '';
  if (d.type === 'class') {
    const parents = d.parents && d.parents.length > 0 ? d.parents.join(', ') : 'None';
    content = `
      <p><span class="text-gray-100 font-semibold">Name:</span> ${escapeHtml(d.id)}</p>
      <p><span class="text-gray-100 font-semibold">Type:</span> ${d.isExternal ? 'External Class' : 'Class'}</p>
      <p><span class="text-gray-100 font-semibold">Inherits from:</span> ${escapeHtml(parents)}</p>
      ${d.isExternal ? '<p class="text-pink-400 mt-1">Note: This class was not defined in the provided code.</p>' : ''}
    `;
  } else if (d.type === 'method') {
    content = `
      <p><span class="text-gray-100 font-semibold">Name:</span> ${escapeHtml(d.name)}</p>
      <p><span class="text-gray-100 font-semibold">Type:</span> Method</p>
      <p><span class="text-gray-100 font-semibold">Belongs to:</span> ${escapeHtml(d.parentClass)}</p>
      <p><span class="text-gray-100 font-semibold">Signature:</span> ${escapeHtml(d.name + d.signature)}</p>
    `;
  }
  detailsContent.innerHTML = content;
}
/**
 * Parses the code from every tab and redraws the graph, showing the loading spinner meanwhile.
 * The work is deferred by 50 ms so the spinner can be painted before parsing starts.
 */
function handleVisualize() {
  loadingSpinner.classList.remove('hidden');

  setTimeout(() => {
    try {
      // Concatenate every tab's code, each followed by a newline.
      const textareas = document.querySelectorAll('#code-inputs-container textarea');
      const allCode = Array.from(textareas, area => area.value + '\n').join('');

      if (allCode.trim()) {
        renderGraph(parsePythonCode(allCode));
      } else {
        graphContainer.innerHTML = '<p class="p-4 text-center text-gray-400">Please paste some code to visualize.</p>';
      }
    } catch (error) {
      console.error("Failed to visualize code:", error);
      graphContainer.innerHTML = `<p class="p-4 text-center text-red-400">An error occurred during parsing. Check the console for details.</p>`;
    } finally {
      loadingSpinner.classList.add('hidden');
    }
  }, 50);
}
// --- Event Listeners ---
visualizeBtn.addEventListener('click', handleVisualize);

// --- Initial Load ---
window.addEventListener('load', () => {
  addTab('Main', exampleCodeMain, true);
  addTab('Deps', exampleCodeDeps);
  handleVisualize();
});

// Re-render only after the window has stopped resizing. Resize fires continuously while a
// window edge is dragged, and each handleVisualize() call re-parses all code and starts a new
// force layout, so reacting to every event wastes work and keeps restarting the layout.
const RESIZE_DEBOUNCE_MS = 200;
let resizeTimer = null;
window.addEventListener('resize', () => {
  clearTimeout(resizeTimer);
  resizeTimer = setTimeout(() => {
    const hasCode = Array.from(document.querySelectorAll('#code-inputs-container textarea'))
      .some(area => area.value.trim().length > 0);
    if (hasCode) {
      handleVisualize();
    }
  }, RESIZE_DEBOUNCE_MS);
});
| </script> | |
| </body> | |
| </html> | |