| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| from collections.abc import Callable |
| from dataclasses import dataclass |
| from functools import cached_property |
| from typing import Optional |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from transformers import initialization as init |
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.configuration_utils import PreTrainedConfig |
| from transformers.generation import GenerationMixin |
| from transformers.integrations import use_experts_implementation, use_kernelized_func |
| from transformers.masking_utils import ( |
| create_bidirectional_mask, |
| create_causal_mask, |
| create_masks_for_generate, |
| create_sliding_window_causal_mask, |
| ) |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| from transformers.processing_utils import Unpack |
| from transformers.utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check |
| from transformers.utils.generic import maybe_autocast, merge_with_config_defaults |
| from transformers.utils.output_capturing import OutputRecorder, capture_outputs |
| from transformers.models.auto.modeling_auto import AutoModel |
| from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig |
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Gemma4 outputs, with hidden states and attentions.
    """
)
class Gemma4ModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    audio_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
    """

    # Multimodal additions on top of the fields inherited from `BaseModelOutputWithPast`.
    # Both default to None, so they are only populated when the corresponding modality is present.
    image_hidden_states: torch.FloatTensor | None = None

    # NOTE(review): the docstring entry for `audio_hidden_states` reuses the image shape text
    # (`num_images`); this looks like a copy-paste - confirm the real audio shape before relying on it.
    audio_hidden_states: torch.FloatTensor | None = None
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Gemma4 causal language model (or autoregressive) outputs.
    """
)
class Gemma4CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
    audio_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
    """

    # Every field is optional and defaults to None; the field order below is part of the
    # dataclass definition and must not be rearranged.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: torch.FloatTensor | None = None

    # NOTE(review): the docstring entry for `audio_hidden_states` reuses the image shape text
    # (`num_images`); this looks like a copy-paste - confirm the real audio shape before relying on it.
    audio_hidden_states: torch.FloatTensor | None = None
|
|
|
|
@dataclass
@auto_docstring
class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
    r"""
    attention_mask (`torch.BoolTensor`, *optional*):
        A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
    """

    # Extra field on top of the ones inherited from `BaseModelOutputWithPooling`; None when no mask is returned.
    attention_mask: torch.BoolTensor | None = None
|
|
|
|
class Gemma4ClippableLinear(nn.Module):
    """Bias-free linear projection with optional clamping of its input and its output.

    When `config.use_clipped_linears` is truthy, four scalar buffers (`input_min`, `input_max`, `output_min`,
    `output_max`) are registered. They are initialised to -inf / +inf, so the clamps are no-ops until other
    values are written into the buffers.
    """

    def __init__(
        self,
        config: Gemma4VisionConfig | Gemma4AudioConfig,
        in_features: int,
        out_features: int,
    ) -> None:
        """
        Args:
            config: Configuration object; only `use_clipped_linears` is read here.
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output.
        """
        super().__init__()
        self.use_clipped_linears = config.use_clipped_linears
        self.linear = nn.Linear(in_features, out_features, bias=False)

        if not self.use_clipped_linears:
            return

        # Registration order is kept stable (input bounds first, then output bounds).
        unbounded = float("inf")
        for buffer_name, initial_bound in (
            ("input_min", -unbounded),
            ("input_max", unbounded),
            ("output_min", -unbounded),
            ("output_max", unbounded),
        ):
            self.register_buffer(buffer_name, torch.tensor(initial_bound))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, clamping before and after the projection when clipping is enabled."""
        if not self.use_clipped_linears:
            return self.linear(hidden_states)

        clamped_input = torch.clamp(hidden_states, self.input_min, self.input_max)
        projected = self.linear(clamped_input)
        return torch.clamp(projected, self.output_min, self.output_max)
|
|
|
|
class Gemma4RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension, with an optional learned scale.

    The computation runs in float32 and the result is cast back to the dtype of the input.
    """

    def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
        """
        Args:
            dim: Size of the normalised (last) dimension; used to size the scale parameter.
            eps: Value added to the mean of squares before taking the inverse square root.
            with_scale: When True, a learnable `weight` of ones with shape `(dim,)` is created and applied.
        """
        super().__init__()
        self.eps = eps
        self.with_scale = with_scale

        if with_scale:
            self.weight = nn.Parameter(torch.ones(dim), requires_grad=True)

    def _norm(self, hidden_states: torch.Tensor):
        """Scale `hidden_states` by `(mean(x^2) + eps) ** -0.5` computed over the last dimension."""
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # Kept as an explicit pow(-0.5) rather than rsqrt so the arithmetic matches the original exactly.
        inverse_rms = torch.pow(variance + self.eps, -0.5)
        return hidden_states * inverse_rms

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise in float32, optionally multiply by `weight`, and return in the input's dtype."""
        result = self._norm(hidden_states.float())
        if self.with_scale:
            result = result * self.weight.float()
        return result.type_as(hidden_states)
|
|
|
|
class Gemma4AudioRelPositionalEncoding(nn.Module):
    """Sinusoidal relative positional encoding for the audio encoder.

    Produces position embeddings of shape `[1, max_past_horizon + max_future_horizon + 1, 2 * (hidden_size // 2)]`
    with concatenated `[sin..., cos...]` layout along the last dimension. Relative positions run from
    `max_past_horizon` down to `-max_future_horizon`, where `max_past_horizon = attention_context_left - 1` and
    `max_future_horizon = attention_context_right`.
    """

    inv_timescales: torch.Tensor

    def __init__(self, config: Gemma4AudioConfig):
        """
        Args:
            config: Audio configuration; reads `hidden_size`, `attention_chunk_size`, `attention_context_left`
                and `attention_context_right`.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Same horizon definitions the audio attention module in this file derives from the config.
        self.max_past_horizon = config.attention_context_left - 1
        self.max_future_horizon = config.attention_context_right
        self.context_size = config.attention_chunk_size + self.max_past_horizon + self.max_future_horizon

        min_timescale = 1.0
        max_timescale = 10000.0
        num_timescales = self.hidden_size // 2
        # max(..., 1) guards against a division by zero when there is a single timescale.
        log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
        inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
        # Stored as (1, 1, num_timescales) so it broadcasts against (num_positions, 1) position ids.
        self.register_buffer("inv_timescales", inv_timescales.unsqueeze(0).unsqueeze(0), persistent=False)

    @torch.no_grad()
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the relative position table; `hidden_states` only supplies the target device and dtype."""
        # Previously hard-coded as arange(12, -1, -1), which is what this evaluates to for
        # attention_context_left=13 and attention_context_right=0. Deriving it from the config keeps the
        # table in sync with the attention window for other settings.
        position_ids = torch.arange(
            self.max_past_horizon,
            -self.max_future_horizon - 1,
            -1,
            device=hidden_states.device,
        )
        position_ids = position_ids[..., None]
        scaled_time = position_ids * self.inv_timescales.to(device=hidden_states.device)
        pos_embed = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
        return pos_embed.to(dtype=hidden_states.dtype)
|
|
|
|
class Gemma4AudioAttention(nn.Module):
    """Chunked local attention with relative position bias

    Queries are split into non-overlapping blocks of `chunk_size` frames. Each block attends to a window of
    `context_size = chunk_size + max_past_horizon + max_future_horizon` key/value frames. Logits are the sum of a
    content term (query @ key) and a relative-position term (query @ projected position embeddings), soft-capped
    with `softcap * tanh(logits / softcap)` before masking and softmax.
    """

    def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_logits_soft_cap = config.attention_logit_cap
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_heads = config.num_attention_heads

        # softplus(0) == ln(2), so with `per_dim_scale` at its zero initialisation the effective query scale
        # `q_scale * softplus(per_dim_scale)` equals head_dim ** -0.5.
        self.q_scale = (self.head_dim**-0.5) / math.log(2)
        # Constant key scale: ln(1 + e) / ln(2).
        self.k_scale = math.log(1 + math.e) / math.log(2)

        self.chunk_size = config.attention_chunk_size
        self.max_past_horizon = config.attention_context_left - 1
        self.max_future_horizon = config.attention_context_right
        self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon

        self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
        self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
        self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
        self.post = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)

        # Projects the relative position embeddings into per-head key space.
        self.relative_k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        # Learned per-channel query scale, passed through softplus in `forward`.
        self.per_dim_scale = nn.Parameter(torch.zeros(self.head_dim))

        # Non-persistent buffer: follows the module across devices but is not saved in the state dict.
        self.register_buffer("softcap", torch.tensor(self.attention_logits_soft_cap), persistent=False)

    def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Splits a `(batch_size, seq_len, num_heads, head_dim)` tensor into non-overlapping blocks of `chunk_size` along the sequence dim."""
        batch_size, seq_len, num_heads, head_dim = hidden_states.shape
        # Ceil-divide, then zero-pad the end of the sequence dim up to a whole number of blocks.
        num_blocks = (seq_len + self.chunk_size - 1) // self.chunk_size
        pad = num_blocks * self.chunk_size - seq_len
        # F.pad pairs run from the last dim backwards: (head_dim, num_heads, seq_len).
        hidden_states = F.pad(hidden_states, (0, 0, 0, 0, 0, pad))
        return hidden_states.reshape(batch_size, num_blocks, self.chunk_size, num_heads, head_dim).contiguous()

    def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Extracts overlapping context windows of `context_size` for every block, strided by `chunk_size`."""
        batch_size, seq_len, num_heads, head_dim = hidden_states.shape
        # Left-pad by the past horizon; right-pad by the future horizon plus `chunk_size - 1` so that a window
        # exists for the last (possibly partial) block.
        hidden_states = F.pad(
            hidden_states, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1)
        )
        # unfold appends the window dim last: (batch, num_blocks, num_heads, head_dim, context_size).
        hidden_states = hidden_states.unfold(1, self.context_size, self.chunk_size)
        # -> (batch, num_blocks, context_size, num_heads, head_dim)
        hidden_states = torch.movedim(hidden_states, -1, 2)
        return hidden_states.contiguous()

    def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Relative position shift for blocked attention. See appendix B of https://huggingface.co/papers/1901.02860."""
        batch_size, num_heads, num_blocks, block_size, position_length = x.shape
        context_size = self.context_size
        # Pad every row to `context_size + 1`, flatten the (block_size, context_size + 1) grid, drop the tail and
        # re-view with rows of `context_size`: each successive row ends up shifted by one position.
        x = F.pad(x, (0, context_size + 1 - position_length))
        x = x.view(batch_size, num_heads, num_blocks, block_size * (context_size + 1))
        x = x[..., : block_size * context_size]
        return x.view(batch_size, num_heads, num_blocks, block_size, context_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: torch.BoolTensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run blocked local attention.

        Args:
            hidden_states: Tensor of shape `(batch_size, seq_length, hidden)`.
            position_embeddings: Relative position embeddings fed to `relative_k_proj`.
            attention_mask: Optional boolean mask; positions where it is False are filled with
                `config.attention_invalid_logits_value`. It must broadcast against the logits of shape
                `(batch_size, num_heads, num_blocks, chunk_size, context_size)` - TODO confirm the exact shape
                produced by the caller.

        Returns:
            A tuple `(attn_output, attn_weights)`: the output of the `post` projection truncated to
            `seq_length` frames, and the post-softmax weights of shape
            `(batch_size, num_heads, num_blocks, chunk_size, context_size)`.
        """
        batch_size, seq_length, _ = hidden_states.shape
        hidden_shape = (batch_size, seq_length, self.num_heads, self.head_dim)

        # Attention maths runs in float32 regardless of the projection dtype.
        query_states = self.q_proj(hidden_states).float().view(hidden_shape)
        key_states = self.k_proj(hidden_states).float().view(hidden_shape)
        value_states = self.v_proj(hidden_states).float().view(hidden_shape)

        query_states = query_states * self.q_scale * F.softplus(self.per_dim_scale)
        key_states = key_states * self.k_scale

        # Queries: (batch, num_blocks, chunk_size, heads, head_dim).
        # Keys / values: (batch, num_blocks, context_size, heads, head_dim).
        query_states = self._convert_to_block(query_states)
        key_states = self._extract_block_context(key_states)
        value_states = self._extract_block_context(value_states)
        num_blocks = query_states.shape[1]

        # (num_positions, heads, head_dim)
        relative_key_states = self.relative_k_proj(position_embeddings)
        relative_key_states = relative_key_states.view(-1, self.num_heads, self.head_dim)
        relative_key_states = relative_key_states.to(dtype=query_states.dtype)

        # Content term: (batch, heads, num_blocks, chunk_size, context_size).
        queries = query_states.permute(0, 3, 1, 2, 4)
        matrix_ac = queries @ key_states.permute(0, 3, 1, 4, 2)

        # Position term: computed over flattened (block, chunk) queries, then shifted into context coordinates.
        queries_flat = queries.reshape(batch_size, self.num_heads, -1, self.head_dim)
        matrix_bd = queries_flat @ relative_key_states.permute(1, 2, 0)
        matrix_bd = matrix_bd.reshape(batch_size, self.num_heads, num_blocks, self.chunk_size, -1)
        matrix_bd = self._rel_shift(matrix_bd)

        # Soft-cap the logits: softcap * tanh(logits / softcap).
        attn_weights = matrix_ac + matrix_bd
        attn_weights = attn_weights / self.softcap
        attn_weights = torch.tanh(attn_weights)
        attn_weights = attn_weights * self.softcap

        if attention_mask is not None:
            attn_weights = attn_weights.masked_fill(
                attention_mask.logical_not(), self.config.attention_invalid_logits_value
            )

        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
        attn_output = attn_weights @ value_states.permute(0, 3, 1, 2, 4)
        # Back to (batch, num_blocks * chunk_size, heads * head_dim), then drop the block padding.
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, num_blocks * self.chunk_size, -1)
        attn_output = attn_output[:, :seq_length].contiguous()
        # Cast back from float32 to the dtype of the output projection weights.
        attn_output = self.post(attn_output.to(dtype=self.post.linear.weight.dtype))

        return attn_output, attn_weights
|
|
|
|
| class Gemma4AudioSubSampleConvProjectionLayer(nn.Module): |
| def __init__(self, in_channels, out_channels, norm_eps): |
| super().__init__() |
| self.conv = nn.Conv2d( |
| in_channels=in_channels, |
| out_channels=out_channels, |
| kernel_size=(3, 3), |
| stride=(2, 2), |
| padding=1, |
| bias=False, |
| ) |
| self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False) |
| self.act = nn.ReLU() |
|
|
| def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor | None = None): |
| if mask is not None: |
| mask = mask.to(device=hidden_states.device) |
| hidden_states = hidden_states * mask[:, None, :, None] |
|
|
| hidden_states = self.conv(hidden_states.to(self.conv.weight.dtype)) |
| hidden_states = self.act(self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()) |
|
|
| if mask is not None: |
| mask = mask[:, ::2] |
|
|
| return hidden_states, mask |
|
|
|
|
class Gemma4AudioSubSampleConvProjection(nn.Module):
    """Two stacked sub-sampling conv stages followed by a bias-free linear projection to `hidden_size`."""

    def __init__(self, config: Gemma4AudioConfig):
        """
        Args:
            config: Audio configuration; reads `subsampling_conv_channels`, `rms_norm_eps` and `hidden_size`.
        """
        super().__init__()
        first_channels, second_channels = (
            config.subsampling_conv_channels[0],
            config.subsampling_conv_channels[1],
        )
        self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
            in_channels=1,
            out_channels=first_channels,
            norm_eps=config.rms_norm_eps,
        )
        self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
            in_channels=first_channels,
            out_channels=second_channels,
            norm_eps=config.rms_norm_eps,
        )
        # NOTE(review): the flattened width is derived from `subsampling_conv_channels[0] // 4`, i.e. it assumes
        # the input feature dimension equals the first channel count - confirm against the feature extractor.
        proj_input_dim = (first_channels // 4) * second_channels
        self.input_proj_linear = nn.Linear(proj_input_dim, config.hidden_size, bias=False)

    def forward(
        self,
        input_features: torch.Tensor,
        input_features_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Sub-sample `input_features` and project them.

        Returns:
            The projected features and the mask returned by the second conv stage.
        """
        # Add a singleton channel dim for the 2D convolutions.
        features = input_features.unsqueeze(1)
        features, mask = self.layer0(features, input_features_mask)
        features, mask = self.layer1(features, mask)

        batch_size, _, seq_len, _ = features.shape
        # (batch, channels, time, freq) -> (batch, time, freq * channels)
        flattened = features.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1)
        projected = self.input_proj_linear(flattened)
        return projected, mask
|
|
|
|
class Gemma4AudioFeedForward(nn.Module):
    """Residual feed-forward block: clamp -> norm -> linear (4x) -> activation -> linear -> clamp -> norm -> scale."""

    def __init__(self, config: Gemma4AudioConfig):
        """
        Args:
            config: Audio configuration; reads `hidden_size`, `hidden_act`, `gradient_clipping` and
                `residual_weight`.
        """
        super().__init__()
        self.config = config

        expanded_size = config.hidden_size * 4
        self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, expanded_size)
        self.ffw_layer_2 = Gemma4ClippableLinear(config, expanded_size, config.hidden_size)

        self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size)
        self.post_layer_norm = Gemma4RMSNorm(config.hidden_size)
        self.act_fn = ACT2FN[config.hidden_act]

        self.gradient_clipping = config.gradient_clipping
        self.post_layer_scale = config.residual_weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `hidden_states + post_layer_scale * block(hidden_states)`."""
        # Cap the clamp bound at the largest finite value of the weight dtype.
        weight_dtype = self.ffw_layer_1.linear.weight.dtype
        clip_limit = min(self.gradient_clipping, torch.finfo(weight_dtype).max)

        skip_connection = hidden_states

        branch = torch.clamp(hidden_states, -clip_limit, clip_limit)
        branch = self.pre_layer_norm(branch)
        branch = self.ffw_layer_2(self.act_fn(self.ffw_layer_1(branch)))

        branch = torch.clamp(branch, -clip_limit, clip_limit)
        branch = self.post_layer_norm(branch)
        branch *= self.post_layer_scale
        branch += skip_connection
        return branch
|
|
|
|
| |
class Gemma4AudioCausalConv1d(nn.Conv1d):
    """`nn.Conv1d` that left-pads its input so an output frame never depends on later input frames.

    The amount of padding is `(kernel_size - 1) * dilation + 1 - stride`, applied only on the left side of the
    last dimension before delegating to the regular convolution.
    """

    @cached_property
    def left_pad(self):
        """Number of zero frames prepended to the input; computed once and cached on the instance."""
        kernel, dilation, stride = self.kernel_size[0], self.dilation[0], self.stride[0]
        receptive_field = (kernel - 1) * dilation + 1
        return receptive_field - stride

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Left-pad the last dimension of `x` by `left_pad` and apply the convolution."""
        padded = nn.functional.pad(x, (self.left_pad, 0))
        return super().forward(padded)
|
|
|
|
class Gemma4AudioLightConv1d(nn.Module):
    """Residual convolution block: norm -> linear (2x) -> GLU -> depthwise causal conv -> clamp -> norm -> act -> linear."""

    def __init__(self, config: Gemma4AudioConfig):
        """
        Args:
            config: Audio configuration; reads `hidden_size`, `conv_kernel_size`, `rms_norm_eps`, `hidden_act`
                and `gradient_clipping`.
        """
        super().__init__()
        self.config = config
        width = config.hidden_size

        # `linear_start` doubles the width because the GLU that follows halves it again.
        self.linear_start = Gemma4ClippableLinear(config, width, width * 2)
        self.linear_end = Gemma4ClippableLinear(config, width, width)
        # groups == channels makes the convolution depthwise.
        self.depthwise_conv1d = Gemma4AudioCausalConv1d(
            in_channels=width,
            out_channels=width,
            kernel_size=config.conv_kernel_size,
            groups=width,
            bias=False,
        )

        self.pre_layer_norm = Gemma4RMSNorm(width, eps=config.rms_norm_eps, with_scale=True)
        self.conv_norm = Gemma4RMSNorm(width, eps=config.rms_norm_eps, with_scale=True)
        self.act_fn = ACT2FN[config.hidden_act]

        self.gradient_clipping = config.gradient_clipping

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `hidden_states` plus the output of the convolution branch."""
        skip_connection = hidden_states

        gated = self.linear_start(self.pre_layer_norm(hidden_states))
        gated = nn.functional.glu(gated, dim=-1)

        # The conv expects channels in dim 1, so swap dims 1 and 2 around the call.
        conv_out = self.depthwise_conv1d(gated.transpose(1, 2)).transpose(1, 2)

        # Cap the clamp bound at the largest finite value of the weight dtype.
        weight_dtype = self.linear_start.linear.weight.dtype
        clip_limit = min(self.gradient_clipping, torch.finfo(weight_dtype).max)
        conv_out = torch.clamp(conv_out, -clip_limit, clip_limit)

        branch = self.linear_end(self.act_fn(self.conv_norm(conv_out)))
        branch += skip_connection
        return branch
|
|
|
|
class Gemma4AudioLayer(nn.Module):
    """Audio encoder layer: feed-forward -> self-attention (residual) -> light conv -> feed-forward -> norm.

    The hidden states are clamped to `+-gradient_clipping` before each of the three layer-level norms.
    """

    def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
        """
        Args:
            config: Audio configuration; reads `hidden_size` and `gradient_clipping` here and is forwarded to
                the sub-modules.
            layer_idx: Index of this layer, forwarded to the attention module.
        """
        super().__init__()
        self.config = config

        self.feed_forward1 = Gemma4AudioFeedForward(config)
        self.feed_forward2 = Gemma4AudioFeedForward(config)
        self.self_attn = Gemma4AudioAttention(config, layer_idx)
        self.lconv1d = Gemma4AudioLightConv1d(config)

        width = config.hidden_size
        self.norm_pre_attn = Gemma4RMSNorm(width)
        self.norm_post_attn = Gemma4RMSNorm(width)
        self.norm_out = Gemma4RMSNorm(width)

        self.gradient_clipping = config.gradient_clipping

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.BoolTensor | None,
        position_embeddings: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run the layer and return the normalised hidden states.

        `attention_mask` and `position_embeddings` are passed straight to the attention module; extra keyword
        arguments are accepted but not used by this method.
        """
        # Cap the clamp bound at the largest finite value of the norm weight dtype.
        clip_limit = min(self.gradient_clipping, torch.finfo(self.norm_pre_attn.weight.dtype).max)

        def _clip(tensor: torch.Tensor) -> torch.Tensor:
            return torch.clamp(tensor, -clip_limit, clip_limit)

        hidden_states = self.feed_forward1(hidden_states)
        skip_connection = hidden_states

        # Attention sub-block; the attention weights are discarded.
        attn_input = self.norm_pre_attn(_clip(hidden_states))
        attn_output, _ = self.self_attn(
            hidden_states=attn_input,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
        )
        attn_output = self.norm_post_attn(_clip(attn_output))
        attn_output += skip_connection

        conv_output = self.lconv1d(attn_output)
        ffw_output = self.feed_forward2(conv_output)

        return self.norm_out(_clip(ffw_output))
|
|
|
|
| |
|
|
|
|
class Gemma4VisionPatchEmbedder(nn.Module):
    """Projects flattened RGB patches to `hidden_size` and adds learned per-axis position embeddings."""

    def __init__(self, config: Gemma4VisionConfig):
        """
        Args:
            config: Vision configuration; reads `hidden_size`, `patch_size` and `position_embedding_size`.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.patch_size = config.patch_size
        self.position_embedding_size = config.position_embedding_size

        flattened_patch_dim = 3 * self.patch_size**2
        self.input_proj = nn.Linear(flattened_patch_dim, self.hidden_size, bias=False)
        # One embedding table per position axis: (2, position_embedding_size, hidden_size).
        self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))

    def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
        """Look up and sum the per-axis position embeddings, zeroing them at padding positions.

        The lookup is done as a one-hot matmul against `position_embedding_table`. Negative position ids are
        clamped to 0 first so they are valid one-hot indices; `padding_positions` decides which rows are zeroed.
        """
        table = self.position_embedding_table
        non_negative_ids = pixel_position_ids.clamp(min=0)

        # (batch, patches, 2, size) -> (batch, 2, patches, size), cast to the table's dtype/device.
        selectors = F.one_hot(non_negative_ids, num_classes=self.position_embedding_size)
        selectors = selectors.permute(0, 2, 1, 3).to(table)

        # Per-axis lookup, then sum the two axes: (batch, patches, hidden_size).
        summed = (selectors @ table).sum(dim=1)

        return torch.where(padding_positions.unsqueeze(-1), 0.0, summed)

    def forward(
        self, pixel_values: torch.Tensor, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor
    ) -> torch.Tensor:
        """Embed patches: rescale pixels with `2 * (x - 0.5)`, project, and add position embeddings."""
        rescaled = 2 * (pixel_values - 0.5)
        patch_embeddings = self.input_proj(rescaled.to(self.input_proj.weight.dtype))
        return patch_embeddings + self._position_embeddings(pixel_position_ids, padding_positions)
|
|
|
|
class Gemma4VisionPooler(nn.Module):
    """Scaling and optional spatial pooling for vision encodings"""

    def __init__(self, config: Gemma4VisionConfig):
        """
        Args:
            config: Vision configuration; only `hidden_size` is read.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.root_hidden_size = self.hidden_size**0.5

    def _avg_pool_by_positions(
        self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        2D spatial pooling according to patch positions.
        Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
        input and output lengths

        Returns:
            The pooled tensor (cast back to the dtype of `hidden_states`) and a boolean mask over the `length`
            output slots that is True for slots at least one input patch maps to.

        Raises:
            ValueError: If the input length is not `k^2 * length` for an integer `k`.
        """
        input_seq_len = hidden_states.shape[1]
        k = int((input_seq_len // length) ** 0.5)
        k_squared = k**2
        if k_squared * length != input_seq_len:
            raise ValueError(
                f"Cannot pool {hidden_states.shape} to {length}: {k=}^2 times {length=} must be {input_seq_len}."
            )

        # Negative position ids are clamped to 0 so they produce valid bucket indices.
        clamped_positions = pixel_position_ids.clamp(min=0)
        # Width of the patch grid along the first position axis, per batch element.
        max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1
        # Map each patch position to its k x k bucket, then flatten (bucket_x, bucket_y) to a single index.
        kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor")
        kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1]
        # Averaging is expressed as a matmul with one-hot assignment weights of value 1 / k^2.
        weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared
        output = weights.transpose(1, 2) @ hidden_states.float()
        mask = torch.logical_not((weights == 0).all(dim=1))
        return output.to(hidden_states.dtype), mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        pixel_position_ids: torch.Tensor,
        padding_positions: torch.Tensor,
        output_length: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Zero out padding patches, optionally pool to `output_length` tokens, and scale by sqrt(hidden_size).

        Args:
            hidden_states: Patch encodings with the patch dimension at dim 1.
            pixel_position_ids: Per-patch positions, used only when pooling.
            padding_positions: Boolean tensor; patches where it is True are zeroed.
            output_length: Number of output tokens. `None` (the default) keeps the input length, i.e. no
                pooling is performed.

        Returns:
            The scaled hidden states and a boolean tensor. NOTE(review): without pooling the second element is
            the unchanged `padding_positions`; with pooling it is the mask from `_avg_pool_by_positions`, which
            is True for populated slots. The two paths appear to use opposite conventions - confirm against
            the caller.

        Raises:
            ValueError: If `output_length` exceeds the number of input patches, or pooling is not possible.
        """
        num_patches = hidden_states.shape[1]
        # The signature advertises None as a valid default, but comparing None with an int raises TypeError.
        # Treat None as "keep every patch".
        if output_length is None:
            output_length = num_patches

        if output_length > num_patches:
            raise ValueError(
                f"Cannot output more soft tokens (requested {output_length}) than there are patches"
                f" ({hidden_states.shape[1]}). Change the value of `num_soft_tokens` when processing."
            )

        hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0.0)

        if num_patches != output_length:
            hidden_states, padding_positions = self._avg_pool_by_positions(
                hidden_states, pixel_position_ids, output_length
            )

        hidden_states = hidden_states * self.root_hidden_size
        return hidden_states, padding_positions
|
|
|
|
class Gemma4VisionMLP(nn.Module):
    """Gated MLP: `down_proj(act_fn(gate_proj(x)) * up_proj(x))` built from clippable linear layers."""

    def __init__(self, config: Gemma4VisionConfig):
        """
        Args:
            config: Vision configuration; reads `hidden_size`, `intermediate_size` and `hidden_activation`.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        narrow, wide = self.hidden_size, self.intermediate_size
        self.gate_proj = Gemma4ClippableLinear(config, narrow, wide)
        self.up_proj = Gemma4ClippableLinear(config, narrow, wide)
        self.down_proj = Gemma4ClippableLinear(config, wide, narrow)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        """Apply the gated MLP to `x` and return a tensor with `hidden_size` features."""
        gate = self.act_fn(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
|
|
|
|
class Gemma4VisionRotaryEmbedding(nn.Module):
    """Rotary embedding for 2D patch positions.

    `forward` computes cos/sin tables separately for each of the two position axes and concatenates them along
    the last dimension.
    """

    inv_freq: torch.Tensor

    def __init__(self, config: Gemma4VisionConfig, device=None):
        """
        Args:
            config: Vision configuration; reads `max_position_embeddings` and `rope_parameters`.
            device: Device passed to the RoPE initialisation function.
        """
        super().__init__()
        self.max_seq_len_cached = self.original_max_seq_len = config.max_position_embeddings
        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        # "default" uses the local static initialiser; any other type is looked up in the shared registry.
        rope_init_fn: Callable = (
            self.compute_default_rope_parameters
            if self.rope_type == "default"
            else ROPE_INIT_FUNCTIONS[self.rope_type]
        )
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Neither buffer is saved in the state dict; a pristine copy is kept under `original_inv_freq`.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: Gemma4VisionConfig | None = None,
        device: torch.device | None = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies for the default RoPE variant.

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (always 1.0 here).
        """
        theta = config.rope_parameters["rope_theta"]
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # Each of the two position axes gets half of the head dimension.
        per_axis_dim = head_dim // 2

        exponents = torch.arange(0, per_axis_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (theta ** (exponents / per_axis_dim))
        return inv_freq, 1.0

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        `position_ids[:, :, 0]` and `position_ids[:, :, 1]` are treated as the two position axes.
        """
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)

        raw_device_type = x.device.type
        if isinstance(raw_device_type, str) and raw_device_type != "mps":
            device_type = raw_device_type
        else:
            device_type = "cpu"

        per_axis_cos, per_axis_sin = [], []
        for axis in range(2):
            axis_positions = position_ids[:, :, axis][:, None, :].float()

            # Autocast is disabled so the angle computation stays in float32.
            with maybe_autocast(device_type=device_type, enabled=False):
                angles = (inv_freq_expanded.float() @ axis_positions.float()).transpose(1, 2)
                duplicated = torch.cat((angles, angles), dim=-1)
                per_axis_cos.append(duplicated.cos() * self.attention_scaling)
                per_axis_sin.append(duplicated.sin() * self.attention_scaling)

        cos = torch.cat(per_axis_cos, dim=-1).to(dtype=x.dtype)
        sin = torch.cat(per_axis_sin, dim=-1).to(dtype=x.dtype)
        return cos, sin
|
|
|
|
def rotate_half(x):
    """Return `cat(-second_half, first_half)` of `x` along its last dimension.

    The split point is `x.shape[-1] // 2`, so for an odd last dimension the second half is one element longer.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
    """Applies Rotary Position Embedding to a single tensor.

    Args:
        x (`torch.Tensor`): The tensor to embed.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so they broadcast against `x`. For example, if
            `cos`/`sin` have shape [batch_size, seq_len, head_dim] and `x` has shape
            [batch_size, heads, seq_len, head_dim], use `unsqueeze_dim=1`; if `x` has shape
            [batch_size, seq_len, heads, head_dim], use `unsqueeze_dim=2`.
    Returns:
        `torch.Tensor`: `x * cos + rotate_half(x) * sin`.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    unrotated_part = x * cos_b
    rotated_part = rotate_half(x) * sin_b
    return unrotated_part + rotated_part
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times; equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the input is returned as is.
    """
    # Unpacking first keeps the 4D requirement even on the n_rep == 1 fast path.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # Insert a repeat axis after the head axis, broadcast it, then fold it into the head axis.
    with_repeat_axis = hidden_states[:, :, None, :, :]
    broadcast = with_repeat_axis.expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return broadcast.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float | int = 0.0,
    scaling: float | None = None,
    softcap: float | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Plain matmul/softmax attention with optional logit soft-capping.

    Args:
        module: Attention module; `num_key_value_groups` and `training` are read, and `head_dim` when `scaling`
            is not given.
        query: Query tensor.
        key: Key tensor; its heads are repeated `module.num_key_value_groups` times.
        value: Value tensor; repeated the same way as `key`.
        attention_mask: Optional additive mask that is summed onto the logits.
        dropout: Dropout probability applied to the attention weights (active only when `module.training`).
        scaling: Logit scale; defaults to `module.head_dim ** -0.5`.
        softcap: When given, logits become `softcap * tanh(logits / softcap)` before the mask is added.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped relative to the matmul
        result and is contiguous.
    """
    scale = module.head_dim**-0.5 if scaling is None else scaling

    groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, groups)
    expanded_values = repeat_kv(value, groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scale

    if softcap is not None:
        scores = scores / softcap
        scores = torch.tanh(scores)
        scores = scores * softcap
    if attention_mask is not None:
        scores = scores + attention_mask

    # Softmax runs in float32 and is cast back to the query dtype afterwards.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()
    return context, probs
|
|
|
|
def apply_multidimensional_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor,
    unsqueeze_dim: int = 2,
) -> torch.Tensor:
    """Applies multidimensional RoPE to inputs.

    The last dimension of `x`, `cos` and `sin` is cut into `position_ids.shape[-1]` equally sized channel groups.
    Each group is rotated independently with `apply_rotary_pos_emb()` and the rotated groups are concatenated back
    along the last dimension.

    Args:
        x (`torch.Tensor`): The tensor to embed.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            Only the size of its last dimension is read here; it is taken as the number of positional dimensions
            (i.e. the number of channel groups).
        unsqueeze_dim (`int`, *optional*, defaults to 2):
            Forwarded to `apply_rotary_pos_emb()`: the dimension along which `cos` and `sin` are unsqueezed so that
            they broadcast against `x`. Use 1 when `x` has the shape [batch_size, heads, seq_len, head_dim] and 2 when
            it has the shape [batch_size, seq_len, heads, head_dim].

    Returns:
        `torch.Tensor`: the rotated channel groups concatenated along the last dimension.

    Raises:
        ValueError: If fewer than two channels are available per positional dimension.
    """
    num_position_dims = position_ids.shape[-1]
    num_channels = x.shape[-1]
    # Every positional dimension gets an even number of channels.
    channels_per_dim = 2 * (num_channels // (2 * num_position_dims))

    if channels_per_dim <= 0:
        raise ValueError(
            "Invalid configuration: num_rotated_channels_per_dim must be > 0, got"
            f" {channels_per_dim} (num_input_channels={num_channels},"
            f" ndim={num_position_dims})"
        )

    # NOTE(review): the explicit sizes are handed to torch.split as-is, so the last dimension of x / cos / sin is
    # expected to be exactly num_position_dims * channels_per_dim wide.
    group_sizes = [channels_per_dim] * num_position_dims
    x_groups = torch.split(x, group_sizes, dim=-1)
    cos_groups = torch.split(cos, group_sizes, dim=-1)
    sin_groups = torch.split(sin, group_sizes, dim=-1)

    rotated = [
        apply_rotary_pos_emb(x=x_group, cos=cos_group, sin=sin_group, unsqueeze_dim=unsqueeze_dim)
        for x_group, cos_group, sin_group in zip(x_groups, cos_groups, sin_groups)
    ]
    return torch.cat(rotated, dim=-1)
|
|
|
|
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma4VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Non-causal self-attention used by the vision encoder layers. Queries and keys are RMS-normalized per head and
    rotated with `apply_multidimensional_rope()`; values are RMS-normalized without a learned scale. The score
    scaling factor is fixed to 1.0.
    """

    def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
        super().__init__()
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        # Falls back to hidden_size // num_attention_heads when the config does not define `head_dim`.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = 1.0
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = False

        query_width = config.num_attention_heads * self.head_dim
        key_value_width = config.num_key_value_heads * self.head_dim
        self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, query_width)
        self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, key_value_width)
        self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, key_value_width)
        self.o_proj = Gemma4ClippableLinear(config, query_width, config.hidden_size)

        # All three norms act on the per-head channel axis, so they are sized from the resolved `self.head_dim`
        # (not `config.head_dim`, which may be absent) to stay consistent with the projections above.
        self.q_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Args:
            hidden_states (`torch.Tensor`): Input features; the last dimension is projected to heads.
            position_embeddings (`tuple(torch.Tensor, torch.Tensor)`):
                The `(cos, sin)` pair consumed by the rotary embedding. Although the default is `None`, it is
                unpacked unconditionally, so a pair must be provided.
            attention_mask (`torch.Tensor`, *optional*): Mask handed to the attention backend.
            position_ids (`torch.LongTensor`, *optional*):
                Passed to `apply_multidimensional_rope()`, which reads the size of its last dimension.
            **kwargs: Forwarded to the attention backend.

        Returns:
            `tuple`: the output-projected attention result and the attention weights as returned by the backend.
        """
        input_shape = hidden_states.shape[:-1]
        # (*leading dims, num_heads, head_dim); the head axis is moved forward by the transposes below.
        hidden_shape = (*input_shape, -1, self.head_dim)

        cos, sin = position_embeddings

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
        query_states = apply_multidimensional_rope(query_states, cos, sin, position_ids)
        query_states = query_states.transpose(1, 2)

        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))
        key_states = apply_multidimensional_rope(key_states, cos, sin, position_ids)
        key_states = key_states.transpose(1, 2)

        value_states = self.v_norm(self.v_proj(hidden_states).view(hidden_shape))
        value_states = value_states.transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the head and head_dim axes back before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
class Gemma4VisionEncoderLayer(GradientCheckpointingLayer):
    """One vision encoder block: self-attention followed by an MLP.

    Each sub-block is wrapped in an RMSNorm before and after it, and its normalized output is added back onto
    its input (residual connection).
    """

    def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma4VisionMLP(config)
        self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): Input to the layer.
            position_embeddings: Forwarded unchanged to the attention module.
            attention_mask (`torch.Tensor`, *optional*): Forwarded unchanged to the attention module.
            position_ids (`torch.LongTensor`, *optional*): Forwarded unchanged to the attention module.
            **kwargs: Forwarded unchanged to the attention module.

        Returns:
            `torch.Tensor`: the layer output. The attention weights are discarded.
        """
        # --- attention sub-block ---
        residual = hidden_states
        attn_input = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states=attn_input,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
        hidden_states = residual + self.post_attention_layernorm(attn_output)

        # --- feed-forward sub-block ---
        residual = hidden_states
        mlp_output = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        hidden_states = residual + self.post_feedforward_layernorm(mlp_output)

        return hidden_states
|
|
|
|
class Gemma4VisionEncoder(nn.Module):
    """Stack of `Gemma4VisionEncoderLayer` modules sharing one rotary embedding."""

    def __init__(self, config: Gemma4VisionConfig):
        super().__init__()
        self.config = config
        self.num_layers = config.num_hidden_layers
        self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
        encoder_layers = [Gemma4VisionEncoderLayer(config=config, layer_idx=idx) for idx in range(self.num_layers)]
        self.layers = nn.ModuleList(encoder_layers)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        r"""
        pixel_position_ids (torch.Tensor):
            Patch positions as (x, y) coordinates in the image as [batch, num_patches, 2].
        """
        # Turn the caller-provided mask into the bidirectional mask used by every layer.
        bidirectional_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        # The rotary embeddings are computed once and shared by all layers.
        position_embeddings = self.rotary_emb(inputs_embeds, pixel_position_ids)

        hidden_states = inputs_embeds
        active_layers = self.layers[: self.config.num_hidden_layers]
        for encoder_layer in active_layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=bidirectional_mask,
                position_embeddings=position_embeddings,
                position_ids=pixel_position_ids,
                **kwargs,
            )

        return BaseModelOutputWithPast(last_hidden_state=hidden_states)
|
|
|
|
class Gemma4TextMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.

    The intermediate width is doubled for KV-shared layers when `config.use_double_wide_mlp` is set.
    """

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        # Layers at or after this index are the KV-shared ones; an index of 0 or below means there are none.
        first_shared_idx = config.num_hidden_layers - config.num_kv_shared_layers
        shares_kv = first_shared_idx > 0 and layer_idx >= first_shared_idx
        width_multiplier = 2 if (config.use_double_wide_mlp and shares_kv) else 1

        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size * width_multiplier
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        """Apply the gated MLP to `x` and return the projected result."""
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
|
|
|
|
class Gemma4TextRotaryEmbedding(nn.Module):
    """Rotary position embedding that keeps a separate set of inverse frequencies per layer type.

    For every layer type with RoPE parameters, three attributes are created: the non-persistent buffers
    `{layer_type}_inv_freq` and `{layer_type}_original_inv_freq`, and the float `{layer_type}_attention_scaling`.
    """

    inv_freq: torch.Tensor

    def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
        # `layer_type` is accepted but not read by this constructor; all of `config.layer_types` are set up.
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.layer_types = set(config.layer_types)
        self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {}
        self.rope_type: dict[str, str] = {}

        for current_type in self.layer_types:
            rope_params = self.config.rope_parameters[current_type]
            if rope_params is None:
                # No RoPE configured for this layer type: nothing is registered for it.
                continue

            rope_type = rope_params["rope_type"]
            if rope_type == "default":
                rope_init_fn = self.compute_default_rope_parameters
            else:
                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]

            self.rope_init_fns[current_type] = rope_init_fn
            self.rope_type[current_type] = rope_type

            init_kwargs = {"device": device, "layer_type": current_type}
            if current_type == "full_attention" and rope_type == "proportional":
                # Proportional RoPE on full-attention layers reads its head dim from `global_head_dim`.
                init_kwargs["head_dim_key"] = "global_head_dim"

            inv_freq, attention_scaling = rope_init_fn(self.config, **init_kwargs)
            self.register_buffer(f"{current_type}_inv_freq", inv_freq, persistent=False)
            self.register_buffer(f"{current_type}_original_inv_freq", inv_freq.clone(), persistent=False)
            setattr(self, f"{current_type}_attention_scaling", attention_scaling)

    @staticmethod
    def compute_default_rope_parameters(
        config: Gemma4TextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
        layer_type: str | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
            layer_type (`str`, *optional*):
                The layer type whose entry of `config.rope_parameters` provides `rope_theta`.

        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (always 1.0 for this type of RoPE).
        """
        rope_theta = config.rope_parameters[layer_type]["rope_theta"]
        # Prefer an explicit (truthy) `head_dim`; otherwise derive it from the hidden size.
        rotary_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # Even channel indices 0, 2, 4, ... normalized by the rotary dimension.
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / rotary_dim
        inv_freq = 1.0 / (rope_theta**exponents)
        return inv_freq, 1.0

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids, layer_type=None):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`, using the buffers of `layer_type`."""
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")

        batch_size = position_ids.shape[0]
        inv_freq_expanded = inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # "mps" (and any non-string device type) is mapped to "cpu" for the autocast context below.
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        # Autocast is disabled so the angles are computed in float32.
        with maybe_autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * attention_scaling
            sin = emb.sin() * attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma4TextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Text attention with per-head RMS-normalized queries/keys, rotary embeddings, optional sliding window, an
    optional "keys double as values" mode (`config.attention_k_eq_v` on non-sliding layers) and KV sharing: the
    trailing `num_kv_shared_layers` layers reuse the key/value states stored by the last earlier layer of the same
    layer type.
    """

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        self.is_sliding = self.layer_type == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_sliding else None

        # Non-sliding layers use `global_head_dim` when it is set; everything else uses `head_dim`.
        if self.is_sliding or not config.global_head_dim:
            self.head_dim = config.head_dim
        else:
            self.head_dim = config.global_head_dim

        self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
        if self.use_alternative_attention:
            num_key_value_heads = config.num_global_key_value_heads
        else:
            num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // num_key_value_heads
        self.scaling = 1.0
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = config.use_bidirectional_attention != "all"

        # Layers at or after `first_shared_idx` are KV-shared (none if that index is not positive).
        first_shared_idx = self.config.num_hidden_layers - getattr(self.config, "num_kv_shared_layers", 0)
        self.is_kv_shared_layer = first_shared_idx > 0 and layer_idx >= first_shared_idx
        non_shared_types = config.layer_types[:first_shared_idx]
        # Index of the last non-shared layer that has the same layer type as this layer.
        last_same_type_idx = len(non_shared_types) - 1 - non_shared_types[::-1].index(config.layer_types[layer_idx])
        if self.is_kv_shared_layer:
            # Shared layers read the key/value states stored by that layer.
            self.kv_shared_layer_index = last_same_type_idx
            self.store_full_length_kv = False
        else:
            self.kv_shared_layer_index = None
            # A non-shared layer stores its key/value states only if it is that last layer of its type.
            self.store_full_length_kv = layer_idx == last_same_type_idx

        self.q_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)

        query_width = config.num_attention_heads * self.head_dim
        key_value_width = num_key_value_heads * self.head_dim
        self.k_proj = nn.Linear(config.hidden_size, key_value_width, bias=config.attention_bias)
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=config.attention_bias)
        if not self.use_alternative_attention:
            self.v_proj = nn.Linear(config.hidden_size, key_value_width, bias=config.attention_bias)
        else:
            # Keys double as values in this mode, so no separate value projection exists.
            self.v_proj = None
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=config.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Args:
            hidden_states (`torch.Tensor`): Input to the attention block.
            position_embeddings: The `(cos, sin)` pair used for the rotary embedding.
            attention_mask (`torch.Tensor`, *optional*): Mask handed to the attention backend.
            past_key_values (`Cache`, *optional*):
                Cache updated with this layer's key/value states. It also carries the `shared_layers` dict through
                which KV-shared layers obtain their key/value states.
            **kwargs: Forwarded to the attention backend.

        Returns:
            `tuple`: the output-projected attention result and the attention weights as returned by the backend.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        cos, sin = position_embeddings

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)

        if self.is_kv_shared_layer and past_key_values is not None:
            # Reuse the states stored by the donor layer and move them next to the queries.
            key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            # Without a value projection the raw key projection (before norm and RoPE) serves as the values.
            if self.v_proj is not None:
                value_states = self.v_proj(hidden_states).view(hidden_shape)
            else:
                value_states = key_states

            key_states = self.k_norm(key_states)
            key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
            key_states = key_states.transpose(1, 2)

            value_states = self.v_norm(value_states).transpose(1, 2)

        if past_key_values is not None:
            if not self.is_kv_shared_layer:
                key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
            if self.store_full_length_kv:
                # The sharing dict is created lazily on the cache object the first time it is needed.
                if not hasattr(past_key_values, "shared_layers"):
                    past_key_values.shared_layers = {}
                past_key_values.shared_layers[self.layer_idx] = key_states, value_states

        attn_implementation = self.config._attn_implementation
        attention_interface: Callable = (
            eager_attention_forward
            if attn_implementation == "eager"
            else ALL_ATTENTION_FUNCTIONS[attn_implementation]
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        # Merge the head and head_dim axes back before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
| |
@use_experts_implementation
class Gemma4TextExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors."""

    def __init__(self, config: Gemma4TextConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size
        # The gate and up projections are fused along the output axis, hence the factor of 2.
        gate_up_shape = (self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)
        down_shape = (self.num_experts, self.hidden_dim, self.intermediate_dim)
        self.gate_up_proj = nn.Parameter(torch.empty(*gate_up_shape))
        self.down_proj = nn.Parameter(torch.empty(*down_shape))
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run every token through its selected experts and sum the weighted expert outputs.

        Args:
            hidden_states (`torch.Tensor`): Token features, indexed along dim 0 by token.
            top_k_index (`torch.Tensor`): Expert ids chosen for each token.
            top_k_weights (`torch.Tensor`): Weight of each chosen expert, indexed as `[token, top_k_position]`.

        Returns:
            `torch.Tensor`: tensor shaped like `hidden_states` holding the accumulated expert outputs.
        """
        output = torch.zeros_like(hidden_states)

        # The routing bookkeeping needs no gradients.
        with torch.no_grad():
            one_hot_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            # Reorder to (expert, top_k_position, token) so that indexing by expert yields its assignments.
            expert_mask = one_hot_mask.permute(2, 1, 0)
            tokens_per_expert = expert_mask.sum(dim=(-1, -2))
            expert_hit = torch.greater(tokens_per_expert, 0).nonzero()

        # Only experts that received at least one token are visited.
        for hit in expert_hit:
            expert_idx = hit[0]
            if expert_idx == self.num_experts:
                continue
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            expert_input = hidden_states[token_idx]
            gate, up = nn.functional.linear(expert_input, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            expert_output = nn.functional.linear(self.act_fn(gate) * up, self.down_proj[expert_idx])
            expert_output = expert_output * top_k_weights[token_idx, top_k_pos, None]
            output.index_add_(0, token_idx, expert_output.to(output.dtype))

        return output
|
|
|
|
| |
class Gemma4TextRouter(nn.Module):
    """Top-k expert router.

    The input is RMS-normalized (without learned scale), multiplied by a learned per-channel `scale` and by
    `hidden_size**-0.5`, then projected to one score per expert. The top-k softmax probabilities are renormalized
    to sum to one and finally multiplied by a learned per-expert scale.
    """

    def __init__(self, config: Gemma4TextConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.scalar_root_size = self.hidden_size**-0.5
        self.eps = config.rms_norm_eps

        self.norm = Gemma4RMSNorm(self.hidden_size, eps=self.eps, with_scale=False)
        self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.scale = nn.Parameter(torch.ones(self.hidden_size))
        self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states (`torch.Tensor`): Token features; the last dimension is `hidden_size`.

        Returns:
            `tuple(torch.Tensor, torch.Tensor, torch.Tensor)`:
            - the softmax probabilities over all experts,
            - the weights of the selected experts (renormalized, then scaled per expert),
            - the indices of the selected experts.
        """
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states * self.scale * self.scalar_root_size

        expert_scores = self.proj(hidden_states)
        router_probabilities = nn.functional.softmax(expert_scores, dim=-1)

        # Keep the `top_k_experts` most probable experts per token.
        top_k_weights, top_k_index = torch.topk(
            router_probabilities,
            k=self.config.top_k_experts,
            dim=-1,
        )

        # Renormalize so the kept probabilities sum to one for every token.
        top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)

        # Apply the learned scale of each selected expert.
        top_k_weights = top_k_weights * self.per_expert_scale[top_k_index]

        return router_probabilities, top_k_weights, top_k_index
|
|
| |
class Gemma4TextDecoderLayer(GradientCheckpointingLayer):
    """One text decoder block.

    Self-attention and a dense MLP, each wrapped in pre/post RMSNorm with a residual connection. Depending on the
    config the layer additionally runs a routed mixture-of-experts branch in parallel to the dense MLP
    (`enable_moe_block`) and a gated per-layer-input branch (`hidden_size_per_layer_input`). The final output is
    multiplied by the `layer_scalar` buffer.
    """

    router_class = Gemma4TextRouter
    experts_class = Gemma4TextExperts

    def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        def hidden_norm() -> nn.Module:
            # Every norm owned by this layer acts on the full hidden size with the same epsilon.
            return Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)

        self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma4TextMLP(config, layer_idx)
        self.input_layernorm = hidden_norm()
        self.post_attention_layernorm = hidden_norm()
        self.pre_feedforward_layernorm = hidden_norm()
        self.post_feedforward_layernorm = hidden_norm()
        self.register_buffer("layer_scalar", torch.ones(1))

        # Optional per-layer-input branch (its modules only exist when the size is truthy).
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        if self.hidden_size_per_layer_input:
            self.act_fn = ACT2FN[config.hidden_activation]
            self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
            self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
            self.post_per_layer_input_norm = hidden_norm()

        # Optional mixture-of-experts branch.
        self.enable_moe_block = config.enable_moe_block
        if self.enable_moe_block:
            self.router = self.router_class(config)
            self.experts = self.experts_class(config)
            self.post_feedforward_layernorm_1 = hidden_norm()
            self.post_feedforward_layernorm_2 = hidden_norm()
            self.pre_feedforward_layernorm_2 = hidden_norm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        per_layer_input: torch.Tensor = None,
        position_embeddings: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): Input to the layer.
            per_layer_input (`torch.Tensor`, *optional*):
                Multiplied onto the gated activations of the per-layer-input branch; only read when
                `hidden_size_per_layer_input` is truthy.
            position_embeddings, attention_mask, position_ids, past_key_values, **kwargs:
                Forwarded unchanged to the attention module.

        Returns:
            `torch.Tensor`: the layer output, scaled by `layer_scalar`.
        """
        # --- attention sub-block ---
        residual = hidden_states
        attn_input = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states=attn_input,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )
        hidden_states = residual + self.post_attention_layernorm(attn_output)

        # --- feed-forward sub-block ---
        residual = hidden_states
        hidden_states = self.mlp(self.pre_feedforward_layernorm(hidden_states))

        if self.enable_moe_block:
            dense_branch = self.post_feedforward_layernorm_1(hidden_states)

            # The routed branch starts from the un-normalized residual stream, flattened to (tokens, hidden).
            flat_residual = residual.reshape(-1, residual.shape[-1])
            _, top_k_weights, top_k_index = self.router(flat_residual)
            expert_branch = self.pre_feedforward_layernorm_2(flat_residual)
            expert_branch = self.experts(expert_branch, top_k_index, top_k_weights)
            expert_branch = expert_branch.reshape(residual.shape)
            expert_branch = self.post_feedforward_layernorm_2(expert_branch)

            # Dense and routed outputs are summed before the shared post-norm.
            hidden_states = dense_branch + expert_branch

        hidden_states = residual + self.post_feedforward_layernorm(hidden_states)

        # --- optional per-layer-input sub-block ---
        if self.hidden_size_per_layer_input:
            residual = hidden_states
            gated = self.act_fn(self.per_layer_input_gate(hidden_states))
            projected = self.per_layer_projection(gated * per_layer_input)
            hidden_states = residual + self.post_per_layer_input_norm(projected)

        # In-place on purpose: keeps the dtype of `hidden_states` regardless of the buffer's dtype.
        hidden_states *= self.layer_scalar
        return hidden_states
|
|
|
|
class Gemma4TextScaledWordEmbedding(nn.Embedding):
    """
    `nn.Embedding` whose forward pass multiplies the looked-up vectors by a constant `embed_scale`.

    The scale is kept twice: as the plain Python value `scalar_embed_scale` and as the non-persistent buffer
    `embed_scale` that is used in `forward`.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.scalar_embed_scale = embed_scale
        scale_tensor = torch.tensor(embed_scale)
        self.register_buffer("embed_scale", scale_tensor, persistent=False)

    def forward(self, input_ids: torch.Tensor):
        """Look up `input_ids` and scale the result; the scale is first cast to the weight dtype."""
        scale = self.embed_scale.to(self.weight.dtype)
        embeddings = super().forward(input_ids)
        return embeddings * scale
|
|
|
|
| |
|
|
|
|
class Gemma4PreTrainedModel(PreTrainedModel):
    """Common base class of the Gemma4 models: capability flags plus Gemma4-specific weight initialization."""

    config: Gemma4Config
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    input_modalities = ("image", "text", "video", "audio")

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize `module`: first the parent-class initialization, then the parameters and buffers that are
        specific to Gemma4 modules. At most one of the branches below applies to a given module.
        """
        super()._init_weights(module)

        if isinstance(module, Gemma4VisionPatchEmbedder):
            init.ones_(module.position_embedding_table)

        elif isinstance(module, Gemma4AudioRelPositionalEncoding):
            # Geometric progression of timescales between min_timescale and max_timescale.
            min_timescale, max_timescale = 1.0, 10000.0
            num_timescales = module.hidden_size // 2
            log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
            steps = torch.arange(num_timescales)
            inv_timescales = min_timescale * torch.exp(steps * -log_timescale_increment)
            # Stored with two leading singleton axes.
            init.copy_(module.inv_timescales, inv_timescales[None, None])

        elif isinstance(module, Gemma4AudioAttention):
            init.constant_(module.softcap, module.attention_logits_soft_cap)
            init.zeros_(module.per_dim_scale)

        elif isinstance(module, Gemma4TextRotaryEmbedding):
            # Recompute the inverse frequencies of every layer type, mirroring the module constructor.
            for layer_type, rope_init_fn in module.rope_init_fns.items():
                init_kwargs = {"layer_type": layer_type}
                if layer_type == "full_attention" and module.rope_type[layer_type] == "proportional":
                    init_kwargs["head_dim_key"] = "global_head_dim"

                inv_freq, _ = rope_init_fn(module.config, **init_kwargs)
                for buffer_name in (f"{layer_type}_inv_freq", f"{layer_type}_original_inv_freq"):
                    init.copy_(getattr(module, buffer_name), inv_freq)

        elif isinstance(module, Gemma4VisionRotaryEmbedding):
            if module.rope_type == "default":
                rope_fn = module.compute_default_rope_parameters
            else:
                rope_fn = ROPE_INIT_FUNCTIONS[module.rope_type]
            inv_freq, _ = rope_fn(module.config)
            init.copy_(module.inv_freq, inv_freq)
            init.copy_(module.original_inv_freq, inv_freq)

        elif isinstance(module, Gemma4TextScaledWordEmbedding):
            # Restore the scale buffer from the Python-side copy kept on the module.
            init.constant_(module.embed_scale, module.scalar_embed_scale)

        elif isinstance(module, Gemma4TextRouter):
            init.ones_(module.scale)
            init.ones_(module.per_expert_scale)

        elif isinstance(module, Gemma4TextExperts):
            std = self.config.initializer_range
            for expert_weight in (module.gate_up_proj, module.down_proj):
                init.normal_(expert_weight, mean=0.0, std=std)

        elif isinstance(module, Gemma4TextDecoderLayer):
            init.ones_(module.layer_scalar)

        elif isinstance(module, Gemma4ClippableLinear) and module.use_clipped_linears:
            # Infinite bounds: lower bounds at -inf, upper bounds at +inf.
            positive_inf = float("inf")
            init.constant_(module.input_min, -positive_inf)
            init.constant_(module.input_max, positive_inf)
            init.constant_(module.output_min, -positive_inf)
            init.constant_(module.output_max, positive_inf)

        elif isinstance(module, Gemma4VisionModel) and module.config.standardize:
            init.zeros_(module.std_bias)
            init.ones_(module.std_scale)
|
|
|
|
| @auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.") |
| class Gemma4TextModel(Gemma4PreTrainedModel): |
| config: Gemma4TextConfig |
| decoder_layer_class = Gemma4TextDecoderLayer |
| input_modalities = ("text",) |
| _can_record_outputs = { |
| "router_logits": OutputRecorder(Gemma4TextRouter, index=0), |
| "hidden_states": Gemma4TextDecoderLayer, |
| "attentions": Gemma4TextAttention, |
| } |
|
|
| def __init__(self, config: Gemma4TextConfig): |
| super().__init__(config) |
| self.padding_idx = config.pad_token_id |
| self.vocab_size = config.vocab_size |
|
|
| |
| self.embed_tokens = Gemma4TextScaledWordEmbedding( |
| config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 |
| ) |
| self.layers = nn.ModuleList([self.decoder_layer_class(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) |
| self.norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.rotary_emb = Gemma4TextRotaryEmbedding(config) |
| self.gradient_checkpointing = False |
| self.unique_layer_types = set(self.config.layer_types) |
|
|
| self.hidden_size_per_layer_input = config.hidden_size_per_layer_input |
| if self.hidden_size_per_layer_input: |
| self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding( |
| config.vocab_size_per_layer_input, |
| config.num_hidden_layers * config.hidden_size_per_layer_input, |
| self.padding_idx, |
| embed_scale=config.hidden_size_per_layer_input**0.5, |
| ) |
| self.per_layer_input_scale = 2.0**-0.5 |
| self.per_layer_model_projection = nn.Linear( |
| config.hidden_size, |
| config.num_hidden_layers * config.hidden_size_per_layer_input, |
| bias=False, |
| ) |
| self.per_layer_model_projection_scale = config.hidden_size**-0.5 |
| self.per_layer_projection_norm = Gemma4RMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps) |
|
|
| |
| self.post_init() |
|
|
@merge_with_config_defaults
@capture_outputs
@auto_docstring
def forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    per_layer_inputs: torch.Tensor | None = None,
    use_cache: bool | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
    r"""
    per_layer_inputs (`torch.Tensor` of shape `(batch_size, sequence_length, num_hidden_layers, hidden_size_per_layer_input)`, *optional*):
        Per-layer input embeddings computed ahead of time. When given, they are used as-is instead of being
        derived from `input_ids` through `get_per_layer_inputs()`. The main consumer is the multimodal model
        (`Gemma4Model`), which computes them from the original `input_ids` *before* multimodal soft tokens are
        merged into `inputs_embeds` — after that merge the original token ids can no longer be recovered.
    """
    if use_cache is None:
        use_cache = self.config.use_cache

    # Exactly one of the two input forms must be supplied.
    ids_given = input_ids is not None
    embeds_given = inputs_embeds is not None
    if ids_given == embeds_given:
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if ids_given:
        inputs_embeds = self.embed_tokens(input_ids)

    # Per-layer inputs only exist for configs with a non-zero per-layer hidden size.
    if self.hidden_size_per_layer_input:
        if per_layer_inputs is None:
            per_layer_inputs = self.get_per_layer_inputs(input_ids, inputs_embeds)
        per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)

    if past_key_values is None and use_cache:
        past_key_values = DynamicCache(config=self.config)

    if position_ids is None:
        # Continue counting from whatever the cache already holds.
        offset = 0 if past_key_values is None else past_key_values.get_seq_length()
        steps = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
        position_ids = (steps + offset).unsqueeze(0)

    # A dict means the caller already prepared one mask per attention type.
    layer_masks = attention_mask
    if not isinstance(layer_masks, dict):
        mask_inputs = dict(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        layer_masks = {
            "full_attention": create_causal_mask(**mask_inputs),
            "sliding_attention": create_sliding_window_causal_mask(**mask_inputs),
        }

    hidden_states = inputs_embeds

    # One set of rotary embeddings per distinct layer type, shared by all layers of that type.
    rotary_by_type = {
        layer_type: self.rotary_emb(hidden_states, position_ids, layer_type)
        for layer_type in self.unique_layer_types
    }

    active_layers = self.layers[: self.config.num_hidden_layers]
    for layer_idx, decoder_layer in enumerate(active_layers):
        layer_type = self.config.layer_types[layer_idx]
        if per_layer_inputs is None:
            per_layer_input = None
        else:
            per_layer_input = per_layer_inputs[:, :, layer_idx, :]

        hidden_states = decoder_layer(
            hidden_states,
            per_layer_input,
            position_embeddings=rotary_by_type[layer_type],
            attention_mask=layer_masks[layer_type],
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )

    return BaseModelOutputWithPast(
        last_hidden_state=self.norm(hidden_states),
        past_key_values=past_key_values,
    )
|
|
| def get_per_layer_inputs(self, input_ids: torch.Tensor | None, inputs_embeds: torch.Tensor | None) -> torch.Tensor: |
| if not self.hidden_size_per_layer_input: |
| raise RuntimeError( |
| "Attempting to call get_per_layer_inputs() from a model initialized with a config that does not support" |
| f" per-layer embeddings. {self.config}" |
| ) |
|
|
| |
| |
| if input_ids is None: |
| with torch.no_grad(): |
| input_ids = ( |
| ( |
| inputs_embeds[:, :, None, :] |
| == self.embed_tokens.weight[None, None, :, :] * self.config.hidden_size**0.5 |
| ) |
| .all(dim=3) |
| .nonzero()[:, 2] |
| ) |
| try: |
| input_ids = input_ids.view(inputs_embeds.shape[:2]) |
| except RuntimeError: |
| raise RuntimeError( |
| "It seems like you tried to call `forward` from `inputs_embeds` without providing `input_ids`, and that " |
| "the `inputs_embeds` you provided do not exactly match the embedding weights. Since Gemma4 needs to reverse " |
| "the embedding to compute another embedding, make sure you provide exact `inputs_embeds`" |
| ) |
|
|
| return self.embed_tokens_per_layer(input_ids).reshape( |
| *input_ids.shape, |
| self.config.num_hidden_layers, |
| self.hidden_size_per_layer_input, |
| ) |
|
|
| def project_per_layer_inputs( |
| self, |
| inputs_embeds: torch.Tensor, |
| per_layer_inputs: torch.Tensor | None = None, |
| ) -> torch.Tensor: |
| if not self.hidden_size_per_layer_input: |
| raise RuntimeError( |
| "Attempting to call project_per_layer_inputs() from a model initialized with a config that does not" |
| f" support per-layer embeddings. {self.config}" |
| ) |
|
|
| per_layer_projection = self.per_layer_model_projection(inputs_embeds) * self.per_layer_model_projection_scale |
| per_layer_projection = per_layer_projection.reshape( |
| *inputs_embeds.shape[:-1], |
| self.config.num_hidden_layers, |
| self.hidden_size_per_layer_input, |
| ) |
| per_layer_projection = self.per_layer_projection_norm(per_layer_projection) |
|
|
| if per_layer_inputs is None: |
| return per_layer_projection |
|
|
| return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale |
|
|
|
|
@auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.")
class Gemma4ForCausalLM(Gemma4PreTrainedModel, GenerationMixin):
    # The LM head shares its weight with the token embedding table of the wrapped text model.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config: Gemma4TextConfig
    base_model_prefix = "model"
    # Class used to build `self.model`; kept as an attribute so subclasses can swap it.
    text_model_class = Gemma4TextModel

    def __init__(self, config: Gemma4TextConfig):
        super().__init__(config)
        self.model = self.text_model_class(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Must run after all submodules exist.
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma4ForCausalLM

        >>> model = Gemma4ForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        # An int keeps the trailing `logits_to_keep` positions (0 keeps all, since slice(-0, None) is the full
        # range); a tensor is used directly as an index along the sequence axis.
        if isinstance(logits_to_keep, int):
            kept_positions = slice(-logits_to_keep, None)
        else:
            kept_positions = logits_to_keep
        logits = self.lm_head(outputs.last_hidden_state[:, kept_positions, :])

        # Optional soft cap: squash logits into (-cap, cap) with tanh.
        softcap = self.config.final_logit_softcapping
        if softcap is not None:
            logits = torch.tanh(logits / softcap) * softcap

        if labels is None:
            loss = None
        else:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
def sliding_window_mask_function(sliding_window: tuple[int, int]) -> Callable:
    """
    Build a mask predicate for a (possibly two-sided) sliding attention window.

    `sliding_window` is `(left_window_size, right_window_size)`. With `dist = q_idx - kv_idx`, the returned
    predicate is true when `0 <= dist < left_window_size` (current or earlier key) or when
    `0 < -dist < right_window_size` (later key).
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        lookback, lookahead = sliding_window
        offset = q_idx - kv_idx

        # `&` / `|` rather than `and` / `or` so the predicate also works element-wise on tensors.
        within_past = (offset >= 0) & (offset < lookback)
        within_future = (offset < 0) & (-offset < lookahead)
        return within_past | within_future

    return inner_mask
|
|
|
|
class Gemma4AudioModel(Gemma4PreTrainedModel):
    """An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture."""

    config: Gemma4AudioConfig
    main_input_name = "input_features"
    base_model_prefix = "model.audio_tower"
    _can_record_outputs = {
        "hidden_states": Gemma4AudioLayer,
        "attentions": Gemma4AudioAttention,
    }

    def __init__(self, config: Gemma4AudioConfig):
        super().__init__(config)
        self.config = config

        self.subsample_conv_projection = Gemma4AudioSubSampleConvProjection(config)
        self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config)
        self.layers = nn.ModuleList(
            [Gemma4AudioLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.output_proj = nn.Linear(config.hidden_size, config.output_proj_dims, bias=True)

        self.post_init()

    def _convert_4d_mask_to_blocked_5d(self, mask_4d: torch.Tensor) -> torch.Tensor:
        """
        Convert a standard 4D attention mask `[batch_size, 1, seq_len, seq_len]` to the 5D blocked format
        `[batch_size, 1, num_blocks, chunk_size, context_size]` expected by the chunked local attention.

        `context_size` is `chunk_size + (attention_context_left - 1) + attention_context_right`: each block of
        `chunk_size` queries sees its own keys plus the configured past and future horizons. Positions added
        by padding are filled with `False`.
        """
        batch_size, _, seq_len, _ = mask_4d.shape
        device = mask_4d.device

        chunk_size = self.config.attention_chunk_size
        max_past_horizon = self.config.attention_context_left - 1
        max_future_horizon = self.config.attention_context_right

        # Round the sequence length up to a whole number of chunks.
        num_blocks = (seq_len + chunk_size - 1) // chunk_size
        padded_seq_len = num_blocks * chunk_size
        pad_amount = padded_seq_len - seq_len

        # Pad both the query and the key axis, then split the query axis into (num_blocks, chunk_size).
        mask_4d = F.pad(mask_4d, (0, pad_amount, 0, pad_amount), value=False)
        mask_5d = mask_4d.reshape(batch_size, 1, num_blocks, chunk_size, padded_seq_len)
        # Extend the key axis so every block's context window stays in bounds.
        mask_5d = F.pad(mask_5d, (max_past_horizon, max_future_horizon), value=False)

        # After the left padding, the context of block `b` starts at index `b * chunk_size` on the key axis.
        block_starts = torch.arange(num_blocks, device=device) * chunk_size
        offsets = torch.arange(chunk_size + max_past_horizon + max_future_horizon, device=device)
        kv_indices = block_starts[:, None] + offsets[None, :]
        kv_indices = kv_indices[None, None, :, None, :].expand(batch_size, 1, -1, chunk_size, -1)

        return mask_5d.gather(-1, kv_indices)

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring(custom_intro="Encodes audio features to soft tokens.")
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Gemma4AudioModelOutput:
        # The projection returns the features together with the mask that matches its output.
        hidden_states, output_mask = self.subsample_conv_projection(input_features, attention_mask)
        position_embeddings = self.rel_pos_enc(hidden_states)

        # Bidirectional mask restricted to the local window, then regrouped per attention chunk.
        attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=hidden_states,
            attention_mask=output_mask,
            and_mask_function=sliding_window_mask_function(
                (self.config.attention_context_left - 1, self.config.attention_context_right)
            ),
        )
        attention_mask = self._convert_4d_mask_to_blocked_5d(attention_mask)

        for encoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.output_proj(hidden_states)
        return Gemma4AudioModelOutput(last_hidden_state=hidden_states, attention_mask=output_mask)
|
|
|
|
class Gemma4VisionModel(Gemma4PreTrainedModel):
    """The Gemma 4 Vision Encoder."""

    # Declared as an annotation, like the sibling model classes, rather than assigned: an assignment would
    # bind the config *class* as a class attribute instead of declaring the type of `self.config`.
    config: Gemma4VisionConfig
    _can_record_outputs = {
        "hidden_states": Gemma4VisionEncoderLayer,
        "attentions": Gemma4VisionAttention,
    }

    def __init__(self, config: Gemma4VisionConfig):
        super().__init__(config)
        self.patch_embedder = Gemma4VisionPatchEmbedder(config)
        self.encoder = Gemma4VisionEncoder(config)
        self.pooler = Gemma4VisionPooler(config)

        # Standardization statistics are allocated uninitialized here; presumably they are filled when a
        # checkpoint is loaded — TODO confirm.
        if self.config.standardize:
            self.register_buffer("std_bias", torch.empty(self.config.hidden_size))
            self.register_buffer("std_scale", torch.empty(self.config.hidden_size))

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring(custom_intro="Encodes image pixels to soft tokens from patches.")
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_position_ids: torch.LongTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        r"""
        pixel_values (`torch.FloatTensor` or `list[torch.FloatTensor]`):
            The images to encode. Either a single `[batch, channels, height, width]` tensor
            (all images same size) or a list of `[1, channels, height, width]` tensors (different sizes).
        pixel_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`):
            The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1).
        """
        # NOTE(review): the output length is derived from `pixel_values.shape[-2]`, which looks like a patch
        # count rather than an image height — confirm the documented `pixel_values` layout against callers.
        pooling_kernel_size = self.config.pooling_kernel_size
        output_length = pixel_values.shape[-2] // (pooling_kernel_size * pooling_kernel_size)

        # A patch is padding when both of its coordinates are -1.
        padding_positions = (pixel_position_ids == -1).all(dim=-1)
        inputs_embeds = self.patch_embedder(pixel_values, pixel_position_ids, padding_positions)
        output = self.encoder(
            inputs_embeds=inputs_embeds,
            attention_mask=~padding_positions,
            pixel_position_ids=pixel_position_ids,
            **kwargs,
        )

        hidden_states, pooler_mask = self.pooler(
            hidden_states=output.last_hidden_state,
            pixel_position_ids=pixel_position_ids,
            padding_positions=padding_positions,
            output_length=output_length,
        )

        # Keep only the entries selected by the pooler mask.
        hidden_states = hidden_states[pooler_mask]

        if self.config.standardize:
            hidden_states = (hidden_states - self.std_bias) * self.std_scale

        return BaseModelOutputWithPast(last_hidden_state=hidden_states)
|
|
|
|
class Gemma4MultimodalEmbedder(nn.Module):
    """Maps multimodal soft tokens into the language model's embedding space."""

    def __init__(
        self,
        multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig,
        text_config: Gemma4TextConfig,
    ):
        super().__init__()

        # Use `output_proj_dims` as the input width when the config defines it, else `hidden_size`.
        source_dim = getattr(multimodal_config, "output_proj_dims", multimodal_config.hidden_size)
        target_dim = text_config.hidden_size

        self.multimodal_hidden_size = source_dim
        self.eps = multimodal_config.rms_norm_eps
        self.text_hidden_size = target_dim
        self.embedding_projection = nn.Linear(source_dim, target_dim, bias=False)
        self.embedding_pre_projection_norm = Gemma4RMSNorm(source_dim, eps=self.eps, with_scale=False)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Normalize the soft tokens, then project them to the text hidden size.

        Args:
            inputs_embeds: A torch.Tensor containing the soft tokens to embed.

        Returns:
            A torch.Tensor whose last dimension is the text model's hidden size.
        """
        normalized = self.embedding_pre_projection_norm(inputs_embeds)
        projected = self.embedding_projection(normalized)
        return projected
|
|
|
|
| |
| def token_type_ids_mask_function( |
| token_type_ids: torch.Tensor | None, |
| image_group_ids: torch.Tensor | None, |
| ) -> Callable | None: |
| """ |
| This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, |
| not start and end indices. |
| """ |
| |
| if token_type_ids is None: |
| return None |
|
|
| def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: |
| seq_length = image_group_ids.shape[-1] |
|
|
| |
| q_idx_clamped = q_idx.clamp(max=seq_length - 1) |
| kv_idx_clamped = kv_idx.clamp(max=seq_length - 1) |
|
|
| |
| q_group = image_group_ids[batch_idx, q_idx_clamped] |
| kv_group = image_group_ids[batch_idx, kv_idx_clamped] |
| q_group = torch.where(q_idx < seq_length, q_group, -1) |
| kv_group = torch.where(kv_idx < seq_length, kv_group, -1) |
| return (q_group == kv_group) & (q_group >= 0) |
|
|
| return inner_mask |
|
|
|
|
| |
def create_causal_mask_mapping(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None,
    mm_token_type_ids: torch.Tensor | None = None,
    pixel_values: torch.FloatTensor | None = None,
    is_training: bool = False,
    is_first_iteration: bool | None = None,
    **kwargs,
) -> dict:
    """
    Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
    for all kinds of forward passes. Gemma4 uses a bidirectional mask for images.

    Uses `pixel_values` as an optional input to disambiguate edge cases.

    Returns:
        A dict with a `"full_attention"` and a `"sliding_attention"` mask. Only the sliding mask receives the
        extra `or_mask_function` that lets tokens of the same vision group attend to each other.

    Raises:
        ValueError: If `is_training` is set and `mm_token_type_ids` is missing.
    """
    if is_training and mm_token_type_ids is None:
        raise ValueError("`mm_token_type_ids` is required as a model input when training")

    mask_kwargs = {
        "config": config.get_text_config(),
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "position_ids": position_ids,
    }
    sliding_mask_kwargs = mask_kwargs.copy()

    # When the caller does not say, treat the pass as the first one if there is no usable cache yet or if
    # pixel values are being supplied.
    is_first_iteration = (
        is_first_iteration
        if is_first_iteration is not None
        else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
    )
    if mm_token_type_ids is not None and is_first_iteration:
        # Move the type ids first so the derived group ids live on the same device as the embeddings.
        # Previously only the (otherwise unused) type ids handed to the mask function were moved, while
        # the group ids that are actually indexed stayed on the original device.
        mm_token_type_ids = mm_token_type_ids.to(inputs_embeds.device)

        # Types 1 and 2 mark vision tokens. Each contiguous run of vision tokens gets its own group id
        # (0, 1, 2, ...); all other tokens get -1.
        is_vision = (mm_token_type_ids == 1) | (mm_token_type_ids == 2)
        is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1)
        # `roll` wraps around, so the first position must not inherit the last one.
        is_prev_vision[..., 0] = False
        new_vision_starts = is_vision & ~is_prev_vision
        vision_group_ids = torch.cumsum(new_vision_starts.int(), dim=1) - 1
        vision_group_ids = torch.where(is_vision, vision_group_ids, -1)
        sliding_mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
            mm_token_type_ids, vision_group_ids
        )

    return {
        "full_attention": create_causal_mask(**mask_kwargs),
        "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
    }
|
|
|
|
@auto_docstring(
    custom_intro="""
    The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a
    language modeling head.
    """
)
class Gemma4Model(Gemma4PreTrainedModel):
    # Opt out of receiving loss kwargs (flag presumably read by the training loop — TODO confirm).
    accepts_loss_kwargs = False

    def __init__(self, config: Gemma4Config):
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size

        self.language_model = AutoModel.from_config(config=config.text_config)
        self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input

        # Each modality tower and its embedder exist only when the matching sub-config is present.
        self.vision_tower = AutoModel.from_config(config.vision_config) if config.vision_config is not None else None
        self.embed_vision = (
            Gemma4MultimodalEmbedder(config.vision_config, config.text_config)
            if config.vision_config is not None
            else None
        )
        self.audio_tower = AutoModel.from_config(config.audio_config) if config.audio_config is not None else None
        self.embed_audio = (
            Gemma4MultimodalEmbedder(config.audio_config, config.text_config)
            if config.audio_config is not None
            else None
        )
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        r"""
        image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
            The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1).
        """
        # Same explicit failure as `get_audio_features` instead of "'NoneType' object is not callable".
        if self.vision_tower is None:
            raise ValueError(
                "Image features were requested, but the model was initialized without a vision_config. "
                "Cannot process images without a vision tower and vision embedder."
            )

        vision_outputs = self.vision_tower(
            pixel_values=pixel_values,
            pixel_position_ids=image_position_ids,
            **kwargs,
        )
        last_hidden_state = vision_outputs.last_hidden_state
        vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state)
        return vision_outputs

    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
    ) -> tuple[torch.BoolTensor, torch.BoolTensor, torch.BoolTensor]:
        """
        Obtains masks for the multimodal placeholder positions (the ones replaced by soft tokens).

        Masks are derived from `input_ids` when given, otherwise from `inputs_embeds`. With `input_ids`, a
        position is a placeholder when its id equals `config.image_token_id`, `config.video_token_id` or
        `config.audio_token_id`. With `inputs_embeds`, a position is a placeholder when its embedding equals
        the embedding of the respective token in every component.

        Args:
            input_ids: A tensor containing the hard token IDs from the text tokenizer.
            inputs_embeds: A tensor containing the embeddings for all hard text tokens.

        Returns:
            image_mask, video_mask, audio_mask
        """
        if input_ids is not None:
            return (
                input_ids == self.config.image_token_id,
                input_ids == self.config.video_token_id,
                input_ids == self.config.audio_token_id,
            )

        def _matches_token(token_id: int) -> torch.Tensor:
            # Embed the single placeholder id and compare it against every position.
            token = torch.tensor(token_id, dtype=torch.long, device=inputs_embeds.device)
            return (inputs_embeds == self.get_input_embeddings()(token)).all(-1)

        special_image_mask = _matches_token(self.config.image_token_id)
        special_video_mask = _matches_token(self.config.video_token_id)
        special_audio_mask = _matches_token(self.config.audio_token_id)
        return special_image_mask, special_video_mask, special_audio_mask

    def _merge_modality_features(
        self,
        inputs_embeds: torch.Tensor,
        placeholder_mask: torch.Tensor,
        features: torch.Tensor,
        modality: str,
    ) -> torch.Tensor:
        """Scatter `features` into the positions of `inputs_embeds` flagged by `placeholder_mask`.

        `modality` ("Image", "Video" or "Audio") is only used in the mismatch error message. `features` is
        expected to have one row per placeholder token, so `features.shape[0]` is the reported feature count.
        """
        n_tokens = placeholder_mask.sum()
        expanded_mask = placeholder_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        torch_compilable_check(
            inputs_embeds[expanded_mask].numel() == features.numel(),
            f"{modality} features and {modality.lower()} tokens do not match, tokens: {n_tokens}, features:"
            f" {features.shape[0]}",
        )
        # `masked_scatter` needs source and destination to agree on device and dtype.
        return inputs_embeds.masked_scatter(
            expanded_mask, features.to(inputs_embeds.device, inputs_embeds.dtype)
        )

    @merge_with_config_defaults
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        input_features: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        input_features_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        mm_token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        image_position_ids: torch.LongTensor | None = None,
        video_position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Gemma4ModelOutputWithPast:
        r"""
        input_features_mask (`torch.FloatTensor` of shape `(num_images, seq_length)`):
            The attention mask for the input audio.
        image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
            2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding.
            Passed through to the vision encoder for positional embedding computation.
        video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*):
            2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding.
            Passed through to the vision encoder for positional embedding computation.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        image_mask, video_mask, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds)
        multimodal_mask = image_mask | video_mask | audio_mask

        # Embed text with placeholder ids swapped for the pad id; the placeholder positions are
        # overwritten with modality features further down.
        llm_input_ids = None
        if inputs_embeds is None:
            llm_input_ids = input_ids.clone()
            llm_input_ids[multimodal_mask] = self.config.text_config.pad_token_id
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        # Per-layer inputs must be computed before soft tokens are merged into `inputs_embeds`.
        if self.config.get_text_config().hidden_size_per_layer_input:
            pad_embedding = self.language_model.embed_tokens.weight[self.config.text_config.pad_token_id, :]
            llm_inputs_embeds = torch.where(multimodal_mask[..., None], pad_embedding.view(1, 1, -1), inputs_embeds)
            per_layer_inputs = self.language_model.get_per_layer_inputs(llm_input_ids, llm_inputs_embeds)
        else:
            per_layer_inputs = None

        # Both stay `None` unless the matching modality is actually encoded, so the return statement
        # below can never hit an unbound name.
        image_features = None
        audio_features = None

        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values, image_position_ids, return_dict=True).pooler_output
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = self._merge_modality_features(inputs_embeds, image_mask, image_features, "Image")

        if pixel_values_videos is not None:
            video_features = self.get_video_features(
                pixel_values_videos, video_position_ids, return_dict=True
            ).pooler_output
            video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = self._merge_modality_features(inputs_embeds, video_mask, video_features, "Video")

        if input_features is not None and input_features_mask is not None:
            audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True)
            # Keep only the positions the encoder's own output mask marks as valid.
            audio_features = audio_output.pooler_output[audio_output.attention_mask]
            inputs_embeds = self._merge_modality_features(inputs_embeds, audio_mask, audio_features, "Audio")

        if position_ids is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        # A dict means the caller already prepared the per-layer-type masks.
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            if self.config.get_text_config().use_bidirectional_attention == "vision":
                # Vision tokens of the same group may attend to each other in both directions.
                causal_mask_mapping = create_causal_mask_mapping(
                    self.config,
                    inputs_embeds,
                    attention_mask,
                    past_key_values,
                    position_ids,
                    mm_token_type_ids,
                    pixel_values,
                    is_training=self.training,
                )
            else:
                causal_mask_mapping = create_masks_for_generate(
                    self.config,
                    inputs_embeds,
                    attention_mask,
                    past_key_values,
                    position_ids,
                )

        outputs = self.language_model(
            per_layer_inputs=per_layer_inputs,
            attention_mask=causal_mask_mapping,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            return_dict=True,
            **kwargs,
        )

        return Gemma4ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features,
            audio_hidden_states=audio_features,
        )

    @can_return_tuple
    @auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.")
    def get_audio_features(
        self,
        input_features: torch.Tensor,
        input_features_mask: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Gemma4AudioModelOutput:
        r"""
        input_features (`torch.FloatTensor` of shape `(num_images, seq_length, num_features)`):
            The tensors corresponding to the input audio.
        input_features_mask (`torch.FloatTensor` of shape `(num_images, seq_length)`):
            The attention mask for the input audio.
        """
        if self.audio_tower is None:
            raise ValueError(
                "Audio features were requested, but the model was initialized without an audio_config. "
                "Cannot process audio without an audio tower and audio embedder."
            )

        audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True, **kwargs)
        audio_outputs.pooler_output = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state)

        return audio_outputs

    @can_return_tuple
    @auto_docstring(custom_intro="Projects the last hidden state from the vision encoder into language model space.")
    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        r"""
        video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*):
            2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding.
            Passed through to the vision encoder for positional embedding computation.
        """
        if self.vision_tower is None:
            raise ValueError(
                "Video features were requested, but the model was initialized without a vision_config. "
                "Cannot process videos without a vision tower and vision embedder."
            )

        # Fold the (videos, frames) axes together so each frame goes through the tower as one image.
        pixel_values_videos = pixel_values_videos.flatten(0, 1)
        video_position_ids = video_position_ids.flatten(0, 1)
        vision_outputs = self.vision_tower(
            pixel_values=pixel_values_videos,
            pixel_position_ids=video_position_ids,
            **kwargs,
        )
        last_hidden_state = vision_outputs.last_hidden_state
        vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state)
        return vision_outputs
|
|
|
|
| @auto_docstring( |
| custom_intro=""" |
| The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling |
| head. |
| """ |
| ) |
| class Gemma4ForConditionalGeneration(Gemma4PreTrainedModel, GenerationMixin): |
| _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} |
| base_model_prefix = "model" |
|
|
def __init__(self, config: Gemma4Config):
    """Build the multimodal base model and a bias-free LM head sized from the text config."""
    super().__init__(config)
    self.model = Gemma4Model(config)

    text_config = config.text_config
    self.lm_head = nn.Linear(
        text_config.hidden_size,
        text_config.vocab_size,
        bias=False,
    )
    self.post_init()
|
|
def get_input_embeddings(self):
    """Return the input embedding module of the wrapped `self.model`."""
    base_model = self.model
    return base_model.get_input_embeddings()
|
|
def set_input_embeddings(self, value):
    """Forward `value` to `set_input_embeddings` of the wrapped `self.model`."""
    base_model = self.model
    base_model.set_input_embeddings(value)
|
|
@auto_docstring
def get_image_features(
    self,
    pixel_values: torch.FloatTensor,
    image_position_ids: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
):
    r"""
    image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
        2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding.
        Passed through to the vision encoder for positional embedding computation.
    """
    # Thin delegation: the wrapped base model owns the vision tower and its embedder.
    base_model = self.model
    return base_model.get_image_features(
        pixel_values,
        image_position_ids,
        **kwargs,
    )
|
|
| @can_return_tuple |
| @auto_docstring |
| def forward( |
| self, |
| input_ids: torch.LongTensor | None = None, |
| pixel_values: torch.FloatTensor | None = None, |
| pixel_values_videos: torch.FloatTensor | None = None, |
| input_features: torch.FloatTensor | None = None, |
| attention_mask: torch.Tensor | None = None, |
| input_features_mask: torch.Tensor | None = None, |
| position_ids: torch.LongTensor | None = None, |
| image_position_ids: torch.LongTensor | None = None, |
| video_position_ids: torch.LongTensor | None = None, |
| past_key_values: Cache | None = None, |
| mm_token_type_ids: torch.LongTensor | None = None, |
| inputs_embeds: torch.FloatTensor | None = None, |
| labels: torch.LongTensor | None = None, |
| use_cache: bool | None = None, |
| logits_to_keep: int | torch.Tensor = 0, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> Gemma4CausalLMOutputWithPast: |
| r""" |
| input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`): |
| The attention mask for the input audio. |
| image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*): |
| 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding. |
| Passed through to the vision encoder for positional embedding computation. |
| video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*): |
| 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding. |
| Passed through to the vision encoder for positional embedding computation. |
| """ |
| outputs = self.model( |
| input_ids=input_ids, |
| pixel_values=pixel_values, |
| pixel_values_videos=pixel_values_videos, |
| input_features=input_features, |
| attention_mask=attention_mask, |
| input_features_mask=input_features_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| mm_token_type_ids=mm_token_type_ids, |
| inputs_embeds=inputs_embeds, |
| labels=labels, |
| use_cache=use_cache, |
| image_position_ids=image_position_ids, |
| video_position_ids=video_position_ids, |
| return_dict=True, |
| **kwargs, |
| ) |
|
|
| hidden_states = outputs.last_hidden_state |
| |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
| if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None: |
| logits = logits / final_logit_softcapping |
| logits = torch.tanh(logits) |
| logits = logits * final_logit_softcapping |
|
|
| loss = None |
| if labels is not None: |
| |
| logits = logits.float() |
| shift_logits = logits[..., :-1, :] |
| shift_labels = labels[..., 1:] |
| if attention_mask is not None: |
| |
| |
| shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) |
| shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() |
| shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() |
| else: |
| shift_logits = shift_logits.contiguous() |
| shift_labels = shift_labels.contiguous() |
| |
| loss_fct = nn.CrossEntropyLoss() |
|
|
| flat_logits = shift_logits.view(-1, self.config.get_text_config().vocab_size) |
| flat_labels = shift_labels.view(-1).to(shift_logits.device) |
| loss = loss_fct(flat_logits, flat_labels) |
|
|
| return Gemma4CausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| image_hidden_states=outputs.image_hidden_states, |
| audio_hidden_states=outputs.audio_hidden_states, |
| ) |
|
|
| def prepare_inputs_for_generation( |
| self, |
| input_ids, |
| past_key_values=None, |
| inputs_embeds=None, |
| position_ids=None, |
| pixel_values=None, |
| pixel_values_videos=None, |
| input_features=None, |
| attention_mask=None, |
| input_features_mask=None, |
| token_type_ids=None, |
| use_cache=True, |
| logits_to_keep=None, |
| labels=None, |
| is_first_iteration=False, |
| **kwargs, |
| ): |
| |
| model_inputs = super().prepare_inputs_for_generation( |
| input_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| use_cache=use_cache, |
| logits_to_keep=logits_to_keep, |
| token_type_ids=token_type_ids, |
| is_first_iteration=is_first_iteration, |
| **kwargs, |
| ) |
|
|
| |
| if is_first_iteration or not use_cache: |
| model_inputs["pixel_values"] = pixel_values |
| model_inputs["pixel_values_videos"] = pixel_values_videos |
| model_inputs["input_features"] = input_features |
| model_inputs["input_features_mask"] = input_features_mask |
|
|
| return model_inputs |
|
|
| @staticmethod |
| def create_masks_for_generate( |
| config: PreTrainedConfig, |
| inputs_embeds: torch.Tensor, |
| attention_mask: torch.Tensor | None, |
| past_key_values: Cache | None, |
| position_ids: torch.Tensor | None, |
| mm_token_type_ids: torch.Tensor | None = None, |
| is_first_iteration: bool | None = False, |
| **kwargs, |
| ) -> dict: |
| if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision": |
| |
| return create_causal_mask_mapping( |
| config, |
| inputs_embeds, |
| attention_mask, |
| past_key_values, |
| position_ids, |
| mm_token_type_ids, |
| is_first_iteration=is_first_iteration, |
| **{k: v for k, v in kwargs.items() if k != "pixel_values"}, |
| ) |
| else: |
| |
| return create_masks_for_generate( |
| config, inputs_embeds, attention_mask, past_key_values, position_ids, **kwargs |
| ) |
|
|
|
|
def resolve_gemma4_text_causal_lm_class(config: Gemma4TextConfig):
    """Pick the text-only causal LM class to use for `config`.

    Returns `Gemma4ForCausalLM` unless `config.use_zero_compute_optimization` is set and truthy, in
    which case `OptimizedGemma4ForCausalLM` is returned. A config without that attribute is treated
    as having it disabled.
    """
    if not getattr(config, "use_zero_compute_optimization", False):
        return Gemma4ForCausalLM

    # Imported lazily so `.gemma4_optimization` is only loaded when the optimization is requested.
    from .gemma4_optimization import OptimizedGemma4ForCausalLM

    return OptimizedGemma4ForCausalLM
|
|
|
|
def build_gemma4_text_causal_lm(config: Gemma4TextConfig):
    """Instantiate the causal LM class selected by `resolve_gemma4_text_causal_lm_class` for `config`."""
    model_cls = resolve_gemma4_text_causal_lm_class(config)
    return model_cls(config)
|
|
|
|
# Public names of this module: the model classes plus the two helpers for selecting and building
# the text-only causal LM. This list controls what `from <module> import *` exports.
__all__ = [
    "build_gemma4_text_causal_lm",
    "Gemma4AudioModel",
    "Gemma4ForCausalLM",
    "Gemma4ForConditionalGeneration",
    "Gemma4Model",
    "Gemma4PreTrainedModel",
    "Gemma4TextModel",
    "Gemma4VisionModel",
    "resolve_gemma4_text_causal_lm_class",
]
|
|