# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Text encoder implementation in PyTorch.""" import typing as t import numpy as np import sentencepiece as spm import torch from torch import nn import torch.nn.functional as F class Tokenizer(object): """A simple tokenizer using SentencePiece.""" def __init__(self, tokenizer_path: str): self.sp = spm.SentencePieceProcessor(model_file=tokenizer_path) # Match tensorflow_text.SentencepieceTokenizer(add_bos=False, add_eos=False) self.sp.SetEncodeExtraOptions("") # Explicitly disable BOS/EOS to match the reference Colab implementation. self._add_bos = False self._add_eos = False def tokenize(self, input_texts, max_len=64): if isinstance(input_texts, str): input_texts = [input_texts] batch_ids = [ self.sp.encode(t.lower(), add_bos=self._add_bos, add_eos=self._add_eos) for t in input_texts ] tokens = np.zeros((len(batch_ids), max_len), dtype=np.int64) for i, ids in enumerate(batch_ids): length = min(len(ids), max_len) tokens[i, :length] = ids[:length] is_padding = (tokens == 0).astype(np.int32) return tokens, is_padding class PositionalEmbedding(nn.Module): """Generates position embedding for a given 1-d sequence. Attributes: min_timescale: Start of the geometric index. Determines the periodicity of the added signal. max_timescale: End of the geometric index. Determines the frequency of the added signal. embedding_dim: Dimension of the embedding to be generated. """ min_timescale: int = 1 max_timescale: int = 10_000 embedding_dim: int = 0 def __init__(self, embedding_dim: int): super().__init__() self.embedding_dim = embedding_dim def __call__(self, seq_length: int = None, position: torch.tensor = None): """Generates a torch.tensor of sinusoids with different frequencies. Args: seq_length: an optional Python int defining the output sequence length. if the `position` argument is specified. position: [B, seq_length], optional position for each token in the sequence, only required when the sequence is packed. Returns: [B, seqlen, D] if `position` is specified, else [1, seqlen, D] """ if position is None: assert seq_length is not None # [1, seqlen] position = torch.arange(seq_length, dtype=torch.float32)[None, :] else: assert position.ndim == 2, position.shape num_timescales = self.embedding_dim // 2 log_timescale_increment = torch.log( torch.tensor(float(self.max_timescale) / float(self.min_timescale)) ) / torch.maximum( torch.tensor(num_timescales, dtype=torch.float32) - 1, torch.tensor(1) ) inv_timescales = self.min_timescale * torch.exp( torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment ) scaled_time = position[:, :, None] * inv_timescales[None, None, :] signal = torch.cat((torch.sin(scaled_time), torch.cos(scaled_time)), dim=2) # Force usage of `np` rather than `jnp` to compute static values at trace # time. signal = F.pad(signal, (0, self.embedding_dim % 2, 0, 0, 0, 0)) return signal class MlpBlockWithMask(nn.Module): """Transformer MLP / feed-forward block that supports masking.""" def __init__( self, mlp_dim: int, d_model: int, use_bias: bool = True, dtype: torch.dtype = torch.float32, activation_fn: nn.Module = nn.GELU, ): super().__init__() self.mlp_dim = mlp_dim self.d_model = d_model self.use_bias = use_bias self.dtype = dtype self.activation_fn = activation_fn self.c_fc = nn.Linear( in_features=self.d_model, out_features=self.mlp_dim, dtype=self.dtype, bias=self.use_bias, ) self.c_proj = nn.Linear( in_features=self.mlp_dim, out_features=self.d_model, dtype=self.dtype, bias=self.use_bias, ) def __call__( self, inputs: torch.Tensor, mlp_mask: torch.Tensor ) -> torch.Tensor: """Applies Transformer MlpBlock with mask module.""" x = self.c_fc(inputs) x = self.activation_fn()(x) x = x * mlp_mask[..., None] # First masking. x = self.c_proj(x) x = x * mlp_mask[..., None] # Second masking. return x class ResidualAttentionBlock(nn.Module): """Transformer residual attention block.""" def __init__( self, d_model: int, n_head: int, mlp_dim: int, dtype: torch.dtype = torch.float32, ): super().__init__() self.d_model = d_model self.n_head = n_head self.mlp_dim = mlp_dim self.dtype = dtype self.attn = nn.MultiheadAttention(d_model, n_head, dtype=self.dtype) self.ln_1 = nn.LayerNorm(d_model, dtype=self.dtype) self.mlp = MlpBlockWithMask( self.mlp_dim, d_model, use_bias=True, dtype=self.dtype, activation_fn=nn.ReLU, ) self.ln_2 = nn.LayerNorm(d_model, dtype=self.dtype) def attention(self, x: torch.Tensor, mask: torch.Tensor): attn_mask = ( mask[:, None, None, :] .repeat(1, self.n_head, x.shape[0], 1) .flatten(0, 1) ) attn_mask[attn_mask == 0] = float('-inf') attn_mask[attn_mask == 1] = 0 return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] def forward(self, x: torch.Tensor, mask: torch.Tensor): x = x + self.attention(self.ln_1(x), mask.permute(1, 0)) x = x + self.mlp(self.ln_2(x), mask) return x, mask class SequentialMultiInput(nn.Sequential): """Sequential module that can take multiple inputs.""" def forward(self, *inputs): for module in self._modules.values(): if isinstance(inputs, tuple): inputs = module(*inputs) else: inputs = module(inputs) return inputs class Transformer(nn.Module): """Transformer implementation.""" def __init__( self, width: int, layers: int, heads: int, mlp_dim: int, dtype: torch.dtype = torch.float32, ): super().__init__() self.width = width self.layers = layers self.heads = heads self.mlp_dim = mlp_dim self.dtype = dtype self.resblocks = SequentialMultiInput(*[ ResidualAttentionBlock(self.width, self.heads, self.mlp_dim, self.dtype) for _ in range(self.layers) ]) def forward(self, x: torch.Tensor, mask: torch.Tensor): return self.resblocks(x, mask)[0] class GlobalAvgPooling(nn.Module): """Performs a simple global pooling over the input with optional paddings. Attributes: pooling_dims: A list of dims to perform pooling over. keepdims: If True, keep dimension of inputs after pooling. """ pooling_dims: t.Sequence[int] epsilon: float = 1e-8 def __init__( self, pooling_dims: t.Sequence[int], epsilon: float = 1e-8 ): super().__init__() self.pooling_dims = pooling_dims self.epsilon = epsilon if not all([p_dims >= 0 for p_dims in self.pooling_dims]): raise ValueError('pooling_dims must be non-negative integers.') def __call__( self, inputs: torch.tensor, compatible_paddings: torch.tensor, ): """Applies global average spatial pooling to inputs. Args: inputs: An input tensor. compatible_paddings: paddings of inputs with shapes compatible with inputs, e.g. compatible_paddings with shape [B, 1] for inputs with shape [B, D]. Returns: Output tensor with global pooling applied. """ padded_value = torch.zeros_like(inputs) padded_value = torch.ones_like(inputs) * padded_value inputs = torch.where(compatible_paddings > 0, padded_value, inputs) valid_inputs = ( torch.sum( 1.0 - compatible_paddings, self.pooling_dims, keepdims=True, dtype=inputs.dtype, ) + self.epsilon ) inputs_sum = torch.sum(inputs, self.pooling_dims, keepdims=True) outputs = torch.divide(inputs_sum, valid_inputs).type(inputs.dtype) outputs = torch.squeeze(outputs, axis=self.pooling_dims) return outputs class TextEncoder(nn.Module): """Text encoder implementation.""" def __init__( self, config: t.Dict[str, int], vocab_size: int, dtype: torch.dtype = torch.float32, scale_sqrt_depth: bool = True, ): super().__init__() self.vocab_size = vocab_size self.dtype = dtype self.scale_sqrt_depth = scale_sqrt_depth # The text tower layers are fixed independent of vision tower size. self.transformer_layers = config['num_layers'] self.embedding_dim = config['hidden_size'] self.transformer_width = config['hidden_size'] self.mlp_dim = config['mlp_dim'] self.transformer_heads = config['num_heads'] self.token_embedding = nn.Embedding( self.vocab_size, self.embedding_dim, dtype=self.dtype ) self.pos_embedder = PositionalEmbedding(embedding_dim=self.embedding_dim) self.transformer = Transformer( width=self.transformer_width, layers=self.transformer_layers, heads=self.transformer_heads, mlp_dim=self.mlp_dim, dtype=self.dtype, ) self.pooling = GlobalAvgPooling(pooling_dims=[1]) self.ln_final = nn.LayerNorm(self.transformer_width, dtype=self.dtype) def __call__( self, ids: torch.tensor, paddings: torch.tensor, ): """Applies TextEncoder module.""" _, seq_length = ids.shape mask = (paddings == 0).type(torch.float32) mask = mask.permute(1, 0) # NL -> LN x = self.token_embedding(ids) if self.scale_sqrt_depth: x = x * (self.embedding_dim**0.5) x = x + self.pos_embedder(seq_length=seq_length).to(x.device) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, mask) x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) x = self.pooling(x, compatible_paddings=paddings[:, :, None]) return x