| from __future__ import annotations |
|
|
| from collections import defaultdict |
| from typing import Any, Union, TypedDict |
|
|
| import math |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import PIL.Image |
|
|
|
|
| from transformers import ( |
| AutoTokenizer, |
| BatchFeature, |
| Cache, |
| Qwen3Config, |
| Qwen3ForCausalLM, |
| Qwen3PreTrainedModel, |
| ) |
| from transformers.cache_utils import SlidingWindowCache, StaticCache |
| from transformers.generation.utils import GenerationMixin |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3Model |
| from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer |
| from transformers.processing_utils import ProcessorMixin |
| from transformers.tokenization_utils import TensorType |
| from transformers.modeling_attn_mask_utils import AttentionMaskConverter |
| import re |
|
|
| from transformers.models.siglip2.modeling_siglip2 import ( |
| Siglip2MLP, |
| ) |
| from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig |
| from perceptron.tensorstream import ( |
| Event, |
| Stream, |
| TensorStream, |
| TextType, |
| VisionType, |
| create_stream, |
| group_streams, |
| ) |
| from perceptron.tensorstream.ops import ( |
| compute_mrope_pos_tensor, |
| modality_mask, |
| reconstruct_tensor_stream_from_compact_dict, |
| slice as ts_slice, |
| tensor_stream_token_view, |
| ) |
|
|
|
|
class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
    """SigLIP2 vision configuration extended with Isaac's pixel-shuffle settings.

    On top of everything ``Siglip2VisionConfig`` accepts, this stores:

    * ``pixel_shuffle_scale_factor`` - spatial merge factor applied after the encoder.
    * ``num_patches`` - number of learned position embeddings (a square grid).

    All remaining keyword arguments are forwarded to the parent initializer unchanged.
    """

    model_type = "pixel_shuffle_siglip2"
    base_config_key = "vision_config"

    def __init__(
        self,
        pixel_shuffle_scale_factor: int = 1,
        num_patches: int = 256,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Assigned after the parent initializer so these values win over anything it set.
        self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor
        self.num_patches = num_patches
|
|
|
|
def create_cumulative_seq_lengths(seq_sizes: torch.Tensor, device: torch.device) -> tuple[torch.Tensor, int]:
    """Build the ``cu_seqlens`` boundary vector and the longest length for packed varlen attention.

    Args:
        seq_sizes: 1-D tensor with the length of each packed sequence.
        device: Device on which the boundary tensor is allocated.

    Returns:
        A pair ``(boundaries, longest)``:

        * ``boundaries`` is an int32 tensor of length ``len(seq_sizes) + 1`` that starts with 0
          and holds the running sum of ``seq_sizes``.
        * ``longest`` is the maximum entry of ``seq_sizes``, or 0 when there are no sequences.
    """
    num_seqs = len(seq_sizes)
    boundaries = torch.zeros(num_seqs + 1, dtype=torch.int32, device=device)
    boundaries[1:] = seq_sizes.cumsum(0)

    longest = 0
    if num_seqs > 0:
        longest = int(seq_sizes.max().item())
    return boundaries, longest
|
|
|
|
class Siglip2VariableSequenceEmbeddings(nn.Module):
    """Patch embeddings plus resized learned position embeddings for a packed image sequence.

    Input to ``forward`` is a tuple ``(seq_patches, seq_sizes, spatial_shapes)``:

    * ``seq_patches`` holds the flattened pixel patches of every image, concatenated.
    * ``spatial_shapes`` gives each image's ``(height, width)`` patch grid.

    The square grid of learned position embeddings is bilinearly resized to every image's grid.
    """

    def __init__(self, config: PixelShuffleSiglip2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        # Linear projection of a flattened (channels * patch * patch) pixel patch to embed_dim.
        self.patch_embedding = nn.Linear(
            in_features=config.num_channels * self.patch_size * self.patch_size,
            out_features=self.embed_dim,
        )

        # num_patches learned embeddings, later viewed as a square grid of side int(sqrt(num_patches)).
        self.num_patches = config.num_patches
        self.position_embedding_size = int(self.num_patches**0.5)
        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    def positional_embeddings(
        self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """Return position embeddings resized to each image grid, concatenated along dim 0."""
        side = self.position_embedding_size
        # (side*side, D) -> (1, D, side, side) so the table can be resized like an image.
        grid = self.position_embedding.weight.reshape(side, side, -1).permute(2, 0, 1).unsqueeze(0)

        _, _, spatial_shapes = packed_seq_patches
        per_image = []
        for height, width in spatial_shapes:
            if height > 0 and width > 0:
                resized = F.interpolate(
                    grid,
                    size=(height, width),
                    mode="bilinear",
                    align_corners=False,
                    antialias=True,
                )
                # (1, D, H, W) -> (H*W, D), row-major over the patch grid.
                per_image.append(resized.reshape(self.embed_dim, height * width).transpose(0, 1))
            else:
                # Degenerate grid: fall back to a prefix of the un-resized table.
                flat = grid.reshape(self.embed_dim, side * side).transpose(0, 1)
                per_image.append(flat[: height * width])

        return torch.cat(per_image, dim=0)

    def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        seq_patches, _, _ = packed_seq_patches

        # Match the projection's parameter dtype before the linear layer.
        weight_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(seq_patches.to(dtype=weight_dtype))
        return patch_embeds + self.positional_embeddings(packed_seq_patches)
|
|
|
|
class Siglip2VariableLengthAttention(nn.Module):
    """Multi-head self-attention over packed variable-length sequences.

    Sequences are concatenated along the token axis of a batch of size 1. Their boundaries come in
    as ``cu_seqlens``, and attention is computed with
    ``torch.ops.aten._flash_attention_forward`` (non-causal).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Creation order (k, v, q, out) is kept so parameter initialisation order is unchanged.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None):
        """Attend within each packed segment.

        Args:
            hidden_states: ``(1, total_tokens, embed_dim)`` packed activations.
            cu_seqlens: Cumulative segment boundaries passed to the flash-attention kernel.
            max_seqlen: Length of the longest segment.

        Returns:
            ``(output, None)`` where ``output`` has shape ``(1, total_tokens, embed_dim)``.

        Raises:
            ValueError: If the leading batch dimension is not 1.
        """
        batch_size, seq_len, _ = hidden_states.size()
        if batch_size != 1:
            raise ValueError("Variable-length attention expects batch_size=1 for packed sequences")

        packed = hidden_states.squeeze(0)
        input_dtype = packed.dtype

        # The flash kernel takes (tokens, heads, head_dim) layouts.
        per_head = (-1, self.num_heads, self.head_dim)
        query = self.q_proj(packed).view(*per_head)
        key = self.k_proj(packed).view(*per_head)
        value = self.v_proj(packed).view(*per_head)

        context, _, _, _, _ = torch.ops.aten._flash_attention_forward(
            query=query,
            key=key,
            value=value,
            cum_seq_q=cu_seqlens,
            cum_seq_k=cu_seqlens,
            max_q=max_seqlen,
            max_k=max_seqlen,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False,
            return_debug_mask=False,
            scale=self.scale,
            window_size_left=-1,
            window_size_right=-1,
            alibi_slopes=None,
        )

        # Merge heads back, restoring the caller's dtype if the kernel changed it.
        context = context.reshape(seq_len, self.embed_dim)
        if context.dtype != input_dtype:
            context = context.to(input_dtype)

        return self.out_proj(context).unsqueeze(0), None
|
|
|
|
class IsaacSiglip2EncoderLayer(nn.Module):
    """Pre-norm transformer block: varlen self-attention then MLP, each with a residual connection."""

    def __init__(self, config: PixelShuffleSiglip2VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # Registration order is kept identical so parameter names/ordering do not change.
        self.self_attn = Siglip2VariableLengthAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor = None,
        max_seqlen: int = None,
    ) -> tuple[torch.FloatTensor]:
        """Apply the block and return a one-element tuple with the updated hidden states."""
        attn_out, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = hidden_states + attn_out

        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
        return (hidden_states,)
|
|
|
|
class IsaacEncoder(nn.Module):
    """Stack of ``IsaacSiglip2EncoderLayer`` blocks sharing one set of packed-sequence boundaries."""

    def __init__(self, config: PixelShuffleSiglip2VisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([IsaacSiglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        inputs_embeds,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: int | None = None,
        output_hidden_states: bool = False,
    ):
        """Run every layer in order.

        Returns:
            ``(last_hidden_state, all_hidden_states, None)``. ``all_hidden_states`` is ``None``
            unless ``output_hidden_states`` is truthy. In that case it is a tuple holding the input
            followed by each layer's output. The third slot is always ``None``.
        """
        collected = [] if output_hidden_states else None
        hidden_states = inputs_embeds

        for layer in self.layers:
            if collected is not None:
                collected.append(hidden_states)
            hidden_states = layer(hidden_states, cu_seqlens, max_seqlen)[0]

        if collected is not None:
            collected.append(hidden_states)
            collected = tuple(collected)

        return hidden_states, collected, None
|
|
|
|
| def create_pixel_shuffle_index_map( |
| seq_sizes: torch.Tensor, |
| token_grids: torch.Tensor, |
| scale_factor: int = 1, |
| device: torch.device | None = None, |
| ) -> torch.Tensor: |
| """ |
| Build a gather-index map that tells us, for every *output* token after |
| pixel-shuffle, which `scale_factor**2` *input* tokens are being merged. |
| |
| Args |
| ---- |
| seq_sizes : (num_images,) - #patches in each image (row-major order) |
| token_grids : (num_images,2) - (height, width) for every image |
| scale_factor : spatial down-scale factor (≥2) |
| device : (optional) overrides `seq_sizes.device` |
| |
| Returns |
| ------- |
| gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor. |
| gather_idx[i, j] is the *flat* index into the *original* |
| packed sequence for the j-th sub-patch that forms the |
| i-th output token. |
| """ |
| if device is None: |
| device = seq_sizes.device |
|
|
| r = int(scale_factor) |
| if r < 2: |
| raise ValueError("`scale_factor` must be ≥ 2") |
|
|
| |
| |
| if not torch.compiler.is_compiling(): |
| if not ((token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all()): |
| raise AssertionError( |
| f"Every (H,W) in `token_grids` must be divisible by scale_factor={r}, got {token_grids.tolist()}" |
| ) |
|
|
| gather_chunks: list[torch.Tensor] = [] |
| tok_offset = 0 |
|
|
| for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist(), strict=False): |
| |
| grid = torch.arange(seq_len, device=device, dtype=torch.int64) + tok_offset |
| grid = grid.view(h, w) |
|
|
| |
| |
| grid = grid.view(h, w // r, r) |
| |
| grid = grid.view(h // r, r, w // r, r) |
| |
| grid = grid.permute(0, 2, 1, 3).contiguous() |
| |
| gather_chunks.append(grid.reshape(-1, r * r)) |
|
|
| tok_offset += seq_len |
|
|
| |
| gather_idx = torch.cat(gather_chunks, dim=0) |
| return gather_idx |
|
|
|
|
def pixel_shuffle_varlen(
    x: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
) -> torch.Tensor:
    r"""Apply pixel shuffle to a packed vision sequence without unpacking per image.

    Args:
        x (`torch.Tensor`):
            Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)`
            shapes produced by stacking image patches.
        token_grids (`torch.Tensor`):
            Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid sizes
            corresponding to each image segment inside `x`.
        scale_factor (`int`, *optional*, defaults to 1):
            Spatial down-sampling factor. A value of 1 is the identity shuffle and returns a copy of `x`.
            Values greater than one merge `scale_factor**2` neighboring patches into a single embedding
            channel-group.

    Returns:
        `torch.Tensor`: Pixel-shuffled embeddings with shape matching the input convention:
        `(new_seq_len, hidden_size * scale_factor**2)` when the input was 2D, or
        `(1, new_seq_len, hidden_size * scale_factor**2)` if the singleton batch dimension was present.

    Raises:
        AssertionError: If `x` is 3D and its batch dimension is not 1, or (via
            `create_pixel_shuffle_index_map`) if a grid is not divisible by `scale_factor`.
        ValueError: If `scale_factor` is smaller than 1 (raised by `create_pixel_shuffle_index_map`).
    """
    keep_batch_dim = x.dim() == 3
    if keep_batch_dim:
        if x.size(0) != 1:
            raise AssertionError("Packed sequence is expected to have batch_size == 1")
        x_ = x.squeeze(0)
    else:
        x_ = x

    r = int(scale_factor)
    if r == 1:
        # A 1x1 shuffle maps every token to itself. Return a fresh tensor (like the gather
        # path does) in the caller's original layout.
        return x.clone()

    embed_dim = x_.size(-1)

    # Number of patches per image, in packing order.
    seq_sizes = torch.prod(token_grids, dim=-1)

    # (new_seq_len, r*r) flat indices of the patches merged into each output token.
    gather_idx = create_pixel_shuffle_index_map(
        seq_sizes=seq_sizes,
        token_grids=token_grids,
        scale_factor=r,
        device=x_.device,
    )

    # (new_seq_len, r*r, D) -> (new_seq_len, r*r*D): concatenate the merged patches' channels.
    gathered = x_[gather_idx]
    out = gathered.reshape(gathered.size(0), embed_dim * r * r)

    if keep_batch_dim:
        out = out.unsqueeze(0)
    return out
|
|
|
|
class Siglip2SequenceVisionTransformer(nn.Module):
    """Packed-sequence SigLIP2 vision tower.

    Pipeline: patch + position embeddings -> variable-length encoder -> post layer norm ->
    optional pixel shuffle. The returned tensor is always 2-D (packed tokens x features).
    """

    def __init__(self, config: PixelShuffleSiglip2VisionConfig):
        super().__init__()
        self.config = config
        self.embeddings = Siglip2VariableSequenceEmbeddings(config)
        self.encoder = IsaacEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor

    def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
        """Encode a packed batch of images.

        Args:
            packed_seq_patches: ``(seq_patches, token_grids)``. ``seq_patches`` holds the flattened
                patches of all images, concatenated. ``token_grids`` is ``(num_images, 2)`` with each
                image's ``(height, width)`` patch grid.

        Returns:
            ``(total_patches, hidden_size)`` features. When
            ``pixel_shuffle_scale_factor > 1`` the shape is instead
            ``(total_patches / r**2, hidden_size * r**2)``.
        """
        seq_patches, token_grids = packed_seq_patches
        seq_sizes = torch.prod(token_grids, dim=-1)

        # (total_patches, hidden_size)
        hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids))

        # The encoder works on a pseudo batch of size 1 holding all images back to back.
        hidden_states = hidden_states.unsqueeze(0)

        # Segment boundaries keep attention within each image.
        cu_seqlens, max_seqlen = create_cumulative_seq_lengths(seq_sizes, hidden_states.device)

        hidden_states, _, _ = self.encoder(
            inputs_embeds=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        hidden_states = self.post_layernorm(hidden_states)

        if self.pixel_shuffle_scale_factor > 1:
            hidden_states = pixel_shuffle_varlen(
                x=hidden_states,
                token_grids=token_grids,
                scale_factor=self.pixel_shuffle_scale_factor,
            )

        # Drop the pseudo batch dimension in every case, so the output rank does not depend on
        # whether pixel shuffle ran.
        return hidden_states.squeeze(0)
|
|
|
|
| |
| |
| |
|
|
# Images whose width * height exceeds this many pixels are rejected by `extract_image_pil`.
MAX_PIXELS = 60 * 1_000_000

# Pixel preprocessing constants: values are multiplied by VISION_SCALE and then normalized
# per RGB channel as (x - mean) / std (see `prepare_image_tensor`).
VISION_MEAN = (0.5,) * 3
VISION_STD = (0.5,) * 3
VISION_SCALE = 1 / 255
|
|
|
|
| def _make_writeable(arr: np.ndarray) -> np.ndarray: |
| """Return *arr* itself if it is already writeable, otherwise try to flip the |
| write flag in-place and finally fall back to `arr.copy()`. |
| This guarantees the buffer handed to `torch.from_numpy()` is always |
| writeable, silencing the PyTorch warning about undefined behaviour. |
| """ |
| if arr.flags.writeable: |
| return arr |
|
|
| |
| |
| try: |
| arr.setflags(write=True) |
| return arr |
| except ValueError: |
| |
| return arr.copy() |
|
|
|
|
def extract_image_pil(image: PIL.Image.Image) -> torch.Tensor | None:
    """Convert a PIL image into a writeable RGB ``torch`` tensor.

    Non-RGB images are converted to RGB first. The array comes from ``np.asarray`` on the RGB
    image, so it is presumably ``(height, width, 3)`` uint8 (confirm for unusual PIL modes).

    Raises:
        ValueError: If ``width * height`` exceeds ``MAX_PIXELS``.
    """
    too_large = image.width * image.height > MAX_PIXELS
    if too_large:
        raise ValueError(f"Image (w={image.width}, h={image.height}) > MAX=`{MAX_PIXELS}`")

    rgb = image.convert("RGB") if image.mode != "RGB" else image
    pixels = _make_writeable(np.asarray(rgb))
    return torch.from_numpy(pixels)
|
|
|
|
def get_image_size_for_max_num_patches(
    image_height: int,
    image_width: int,
    patch_size: int,
    max_num_patches: int,
    min_num_patches: int | None = None,
    eps: float = 1e-5,
    pixel_shuffle_scale: int = 1,
) -> tuple[int, int]:
    r"""Compute a target resolution whose patch grid satisfies the patch-count constraints.

    Every candidate size is rounded up to a multiple of ``patch_size * pixel_shuffle_scale``
    (and is never smaller than one such multiple). There are three cases:

    1. If ``min_num_patches`` is given and the rounded original size has fewer patches, a binary
       search over a scale in ``[1, 100]`` picks the smallest upscale that reaches the minimum.
       This path does not re-check ``max_num_patches``.
    2. Otherwise, if the rounded original size fits within ``max_num_patches``, it is returned.
    3. Otherwise, a binary search over a scale in ``(eps / 10, 1]`` picks the largest downscale
       that fits.

    Args:
        image_height (`int`): Source image height in pixels.
        image_width (`int`): Source image width in pixels.
        patch_size (`int`): Edge length of the square patches used by the vision encoder.
        max_num_patches (`int`): Upper bound on ``(height / patch_size) * (width / patch_size)``.
        min_num_patches (`int`, *optional*): Lower bound on the number of patches; triggers upscaling.
        eps (`float`, *optional*, defaults to 1e-5): Convergence tolerance of the binary searches used
            to determine the target dimensions.
        pixel_shuffle_scale (`int`, *optional*, defaults to 1): Extra stride multiplier so the patch grid
            stays divisible by the later pixel-shuffle factor.

    Returns:
        `tuple[int, int]`: ``(height, width)`` in pixels, both multiples of ``patch_size * pixel_shuffle_scale``.
    """
    divisor = patch_size * pixel_shuffle_scale

    def _snap(scale, original_size):
        # Scale one side, then round up to the next multiple of `divisor` (at least one divisor).
        snapped = math.ceil(scale * original_size / divisor) * divisor
        return int(max(divisor, snapped))

    def _sizes_at(scale):
        return _snap(scale, image_height), _snap(scale, image_width)

    def _count(height, width):
        return (height / patch_size) * (width / patch_size)

    base_height = max(divisor, math.ceil(image_height / divisor) * divisor)
    base_width = max(divisor, math.ceil(image_width / divisor) * divisor)
    base_count = _count(base_height, base_width)

    if min_num_patches is not None and base_count < min_num_patches:
        # Smallest upscale that reaches the minimum: `hi` always satisfies the bound.
        lo, hi = 1.0, 100.0
        while (hi - lo) >= eps:
            mid = (lo + hi) / 2
            if _count(*_sizes_at(mid)) >= min_num_patches:
                hi = mid
            else:
                lo = mid
        return _sizes_at(hi)

    if base_count <= max_num_patches:
        return base_height, base_width

    # Largest downscale that still fits: `lo` always satisfies the bound once it has moved.
    lo, hi = eps / 10, 1.0
    while (hi - lo) >= eps:
        mid = (lo + hi) / 2
        if _count(*_sizes_at(mid)) <= max_num_patches:
            lo = mid
        else:
            hi = mid
    return _sizes_at(lo)
|
|
|
|
# Per-channel normalisation statistics as float32 tensors of shape (1, 1, 1, C). They broadcast
# against channels-last image batches in `prepare_image_tensor`.
_MEAN_TENSOR = torch.as_tensor(VISION_MEAN, dtype=torch.float32).reshape(1, 1, 1, -1)
_STD_TENSOR = torch.as_tensor(VISION_STD, dtype=torch.float32).reshape(1, 1, 1, -1)
|
|
|
|
def prepare_image_tensor(
    image: torch.Tensor,
    scale: float = VISION_SCALE,
) -> torch.Tensor:
    r"""Rescale and normalize RGB images ahead of patch extraction.

    Computes ``(image * scale - mean) / std`` with the module-level per-channel statistics.

    Args:
        image (`torch.Tensor`):
            Tensor of shape `(..., height, width, 3)` with RGB values. Non-floating inputs are converted
            with `.float()` first.
        scale (`float`, *optional*, defaults to `VISION_SCALE`):
            Multiplier applied before normalization.

    Returns:
        `torch.Tensor`: Normalized tensor with the same shape as the input. Integer inputs and float32
        inputs yield float32. Other floating inputs follow PyTorch type promotion against the float32
        statistics (e.g. float64 stays float64).
    """
    as_float = image if torch.is_floating_point(image) else image.float()
    device = as_float.device
    return (as_float * scale - _MEAN_TENSOR.to(device)) / _STD_TENSOR.to(device)
|
|
|
|
def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor:
    r"""Cut channels-last images into flattened, non-overlapping square patches.

    Args:
        image (`torch.Tensor`):
            Tensor of shape `(num_images, height, width, channels)`.
        patch_size (`int`):
            Edge length of the square patches.

    Returns:
        `torch.Tensor`: Shape `(num_images, height // patch_size, width // patch_size,
        channels * patch_size**2)`. Each entry holds one patch flattened in (row, column, channel) order.

    Raises:
        ValueError: If `height` or `width` is not divisible by `patch_size`.
    """
    num_images, height, width, channels = image.shape
    if height % patch_size or width % patch_size:
        raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.")

    rows, cols = height // patch_size, width // patch_size
    tiles = image.reshape(num_images, rows, patch_size, cols, patch_size, channels)
    # Bring the two grid axes together ahead of the intra-patch axes.
    tiles = tiles.permute(0, 1, 3, 2, 4, 5)
    return tiles.reshape(num_images, rows, cols, channels * patch_size * patch_size)
|
|
|
|
def process_vision_for_patches(
    images: torch.Tensor,
    patch_size: int,
    max_num_patches: int,
    min_num_patches: int | None = None,
    pixel_shuffle_scale: int = 1,
) -> tuple[torch.Tensor, list[int]]:
    r"""Resize, normalize, and patchify RGB images for the vision encoder.

    Args:
        images (`torch.Tensor`):
            Either `(height, width, channels)` for a single image or `(num_images, height, width, channels)`
            for a batch. Channels are expected to be RGB.
        patch_size (`int`):
            Edge length of square patches; implicitly controls the resize grid granularity.
        max_num_patches (`int`):
            Maximum number of patches allowed after resizing.
        min_num_patches (`int`, *optional*):
            Minimum number of patches. If provided, images are upsampled as needed to satisfy the lower bound.
        pixel_shuffle_scale (`int`, *optional*, defaults to 1):
            Pixel shuffle scale factor; the target grid is kept divisible by it.

    Returns:
        `tuple[torch.Tensor, list[int]]`: A pair `(patches, dims_virtual)`.

        * `patches` has shape `(num_images, target_h / patch_size, target_w / patch_size, channels * patch_size**2)`.
        * `dims_virtual` is `[1, grid_h, grid_w]`, with the grid divided by `pixel_shuffle_scale` when the scale
          is not 1. The leading entry is always 1, regardless of `num_images`.
    """
    batched = images.unsqueeze(0) if images.dim() == 3 else images

    # NHWC -> NCHW for F.interpolate.
    nchw = batched.permute(0, 3, 1, 2)
    _, _, src_height, src_width = nchw.shape

    dst_height, dst_width = get_image_size_for_max_num_patches(
        src_height,
        src_width,
        patch_size,
        max_num_patches,
        min_num_patches=min_num_patches,
        pixel_shuffle_scale=pixel_shuffle_scale,
    )
    resized = F.interpolate(
        nchw,
        size=(dst_height, dst_width),
        mode="bilinear",
        align_corners=False,
    )

    # Back to NHWC, then rescale/normalize and cut into patches.
    normalized = prepare_image_tensor(resized.permute(0, 2, 3, 1))
    patches = patchify_vision(normalized, patch_size=patch_size)

    _, grid_h, grid_w, _ = patches.shape
    if pixel_shuffle_scale == 1:
        dims_virtual = [1, grid_h, grid_w]
    else:
        dims_virtual = [1, grid_h // pixel_shuffle_scale, grid_w // pixel_shuffle_scale]

    return patches, dims_virtual
|
|
|
|
def precompute_inv_freq(theta: float, dim: int) -> torch.Tensor:
    """Return float32 rotary inverse frequencies ``1 / theta**(2i / dim)`` for ``i`` in ``[0, dim // 2)``.

    The result has shape ``(ceil(dim / 2),)``, which is ``(dim // 2,)`` for even ``dim``.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    return 1.0 / (theta**exponents)
|
|
|
|
def precompute_cos_sin_3d(
    position_ids: torch.Tensor,
    inv_freq: torch.Tensor,
    mrope_half_section: list[int],
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Generate 3-axis rotary (MRoPE) cosine/sine tables.

    The frequency vector is split into three consecutive blocks of sizes `mrope_half_section[0..2]`. Block `d`
    is driven by the positions of axis `d`. Each block is written into both halves of the output so the two
    rotary halves share frequencies. Frequency slots not covered by the sections remain zero.

    Args:
        position_ids (`torch.Tensor`):
            Tensor of shape `(3, batch_size, seq_len)` with per-axis positional indices.
        inv_freq (`torch.Tensor`):
            Inverse frequency vector of length `dim // 2`.
        mrope_half_section (`list[int]`):
            Sizes of the three axis-specific frequency blocks.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: float32 cosine and sine tensors, each of shape
        `(batch_size, seq_len, 2 * len(inv_freq))`.
    """
    batch, length = position_ids.shape[1], position_ids.shape[2]
    half = inv_freq.shape[0]
    out_shape = (batch, length, 2 * half)

    cos_out = torch.zeros(out_shape, dtype=torch.float32, device=position_ids.device)
    sin_out = torch.zeros(out_shape, dtype=torch.float32, device=position_ids.device)

    start = 0
    for axis in range(3):
        stop = start + mrope_half_section[axis]
        # (batch, seq, 1) * (block,) -> (batch, seq, block) phase angles for this axis.
        angles = position_ids[axis].unsqueeze(-1).float() * inv_freq[start:stop]
        cos_vals, sin_vals = angles.cos(), angles.sin()

        # Mirror the block into the first and second rotary halves.
        for base in (0, half):
            cos_out[:, :, base + start : base + stop] = cos_vals
            sin_out[:, :, base + start : base + stop] = sin_vals

        start = stop

    return cos_out, sin_out
|
|
|
|
class RopeScaling(TypedDict, total=False):
    """Typed schema for a ``rope_scaling`` configuration dict.

    Declared with ``total=False``, so every key is optional. How each key is consumed is not visible
    here (NOTE(review): confirm against the rotary-embedding code that reads it).
    """

    rope_type: str
    factor: float
    # NOTE(review): presumably the per-axis MRoPE frequency split; confirm with the consumer.
    mrope_section: list[int]
    mrope_interleaved: bool
    low_freq_factor: float
    high_freq_factor: float
    original_max_position_embeddings: int
|
|
|
|
class IsaacConfig(Qwen3Config):
    """Configuration class for the Isaac multimodal model.

    Extends ``Qwen3Config`` (text backbone) with:

    * a nested ``PixelShuffleSiglip2VisionConfig``;
    * the vision preprocessing settings read by ``IsaacProcessor``;
    * the maximum stream length and the placeholder token marking image positions in prompts.
    """

    model_type = "isaac"
    sub_configs = {"vision_config": PixelShuffleSiglip2VisionConfig}

    def __init__(
        self,
        vision_config=None,
        vision_patch_size: int = 16,
        vision_max_num_patches: int = 256,
        vision_min_num_patches: int | None = None,
        pixel_shuffle_scale: int = 1,
        max_sequence_length: int = 16384,
        vision_token: str = "<image>",
        **kwargs,
    ):
        # The patch size is stored under the attribute name `video_patch_size`, so that is the key a
        # serialized config carries. Without this pop, the key would reach the parent through
        # **kwargs and then be overwritten below by the `vision_patch_size` default, silently
        # resetting a non-default patch size on reload. When present, the stored value wins.
        legacy_patch_size = kwargs.pop("video_patch_size", None)
        if legacy_patch_size is not None:
            vision_patch_size = legacy_patch_size

        super().__init__(**kwargs)

        # Accept the vision config as a dict, an instance, or None (defaults).
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()
        else:
            self.vision_config = vision_config

        # Vision preprocessing parameters (read by IsaacProcessor).
        self.video_patch_size = vision_patch_size
        self.vision_max_num_patches = vision_max_num_patches
        self.vision_min_num_patches = vision_min_num_patches
        self.pixel_shuffle_scale = pixel_shuffle_scale

        # Processing parameters.
        self.max_sequence_length = max_sequence_length
        self.vision_token = vision_token
|
|
|
|
| |
| |
| |
|
|
|
|
def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> Event:
    r"""Tokenize a text fragment and wrap it as a TensorStream ``Event``.

    Args:
        tokenizer (`AutoTokenizer`):
            Tokenizer used to convert text into vocabulary ids. Special tokens are not added.
        text (`str`):
            Plain-text fragment to encode.
        time (`float`, *optional*, defaults to 0.0):
            Timeline coordinate; used as both start and end time of the event.

    Returns:
        `Event`: A ``TextType.text`` event holding a `(num_tokens, 1)` tensor of token ids. It has
        `dims_virtual == dims_real == [num_tokens, 1]` and `idx_range == (0, num_tokens)`.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").squeeze(0)
    count = len(token_ids)

    # Give the ids a trailing feature axis: (num_tokens,) -> (num_tokens, 1).
    if token_ids.dim() == 1:
        token_ids = token_ids.unsqueeze(-1)

    return Event(
        data=token_ids,
        type=TextType.text,
        time=(time, time),
        dims_virtual=[count, 1],
        dims_real=[count, 1],
        idx_range=(0, count),
    )
|
|
|
|
| |
| |
| |
|
|
|
|
class IsaacProcessor(ProcessorMixin):
    """Turn a prompt with vision placeholders plus PIL images into a TensorStream batch.

    Text fragments are tokenized with the wrapped Qwen2 tokenizer. Each occurrence of
    ``config.vision_token`` in the prompt is matched, in order, with the next image. That image is
    resized, normalized and patchified. Only a single prompt (batch size 1) is supported.
    """

    attributes = ["tokenizer"]
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self,
        tokenizer: Qwen2Tokenizer,
        config: IsaacConfig | dict,
    ):
        super().__init__(tokenizer)
        self.tokenizer = tokenizer

        # Accept either a ready IsaacConfig or its dict form.
        if isinstance(config, dict):
            config = IsaacConfig(**config)
        self.config = config

        # Placeholder string marking where each image goes in the prompt.
        self.vision_token = config.vision_token

        # Streams longer than this are truncated from the left in __call__.
        self.max_sequence_length = config.max_sequence_length

        # Vision preprocessing parameters forwarded to process_vision_for_patches.
        self.patch_size = config.video_patch_size
        self.max_num_patches = config.vision_max_num_patches
        self.min_num_patches = config.vision_min_num_patches
        self.pixel_shuffle_scale = config.pixel_shuffle_scale

    def apply_chat_template(
        self,
        messages: list[dict[str, Any]],
        tokenize: bool = False,
        add_generation_prompt: bool = False,
        **kwargs,
    ) -> Any:
        """Delegate chat templating to the wrapped tokenizer."""
        return self.tokenizer.apply_chat_template(
            messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt, **kwargs
        )

    def build_event_stream_simple(
        self,
        text: str,
        images: list[PIL.Image.Image] | None = None,
    ) -> Stream:
        """Split ``text`` on the vision token and build a scheduled event ``Stream``.

        The index of each split fragment is used as its event time. Placeholders beyond the
        number of supplied images are skipped. ``images=None`` is treated as "no images".
        """
        # Treat a missing image list as empty instead of failing on len(None).
        if images is None:
            images = []

        events = []

        # The capturing group keeps the placeholder itself in the split output.
        pattern = re.escape(self.vision_token)
        parts = re.split(f"({pattern})", text)

        image_idx = 0
        for current_time, part in enumerate(parts):
            if part == self.vision_token:
                if image_idx < len(images):
                    image_tensor = extract_image_pil(images[image_idx])
                    if image_tensor is not None:
                        vision_event = Event(
                            data=image_tensor.unsqueeze(0),
                            type=VisionType.image,
                            time=(current_time, current_time),
                        )
                        events.append(vision_event)
                    image_idx += 1
            elif part:
                # Empty fragments (e.g. between adjacent placeholders) produce no event.
                text_event = create_text_event(self.tokenizer, part, time=current_time)
                events.append(text_event)

        if any(event.type == VisionType.image for event in events):
            text_events = [event for event in events if event.type == TextType.text]
            vision_events = [event for event in events if event.type == VisionType.image]

            processed_vision_events = []
            for vision_event in vision_events:
                # Resize, normalize and patchify the raw (H, W, 3) image.
                patches, dims_virtual = process_vision_for_patches(
                    vision_event.data.squeeze(0),
                    patch_size=self.patch_size,
                    max_num_patches=self.max_num_patches,
                    min_num_patches=self.min_num_patches,
                    pixel_shuffle_scale=self.pixel_shuffle_scale,
                )

                vision_event.data = patches.unsqueeze(1)
                vision_event.dims_virtual = dims_virtual
                # dims_real is the patch grid before pixel shuffle (virtual grid scaled back up).
                vision_event.dims_real = (
                    dims_virtual
                    if self.pixel_shuffle_scale == 1
                    else [
                        dims_virtual[0],
                        dims_virtual[1] * self.pixel_shuffle_scale,
                        dims_virtual[2] * self.pixel_shuffle_scale,
                    ]
                )
                vision_event.idx_range = (0, math.prod(dims_virtual))

                # Flatten to one row per patch.
                vision_event.data = vision_event.data.reshape(-1, vision_event.data.shape[-1])
                processed_vision_events.append(vision_event)

            events = text_events + processed_vision_events

        # NOTE(review): schedule=True presumably orders events by their time coordinate; see create_stream.
        return create_stream(events, priority=[TextType.text, VisionType.image], schedule=True)

    def __call__(
        self,
        text: Union[str, list[str]],
        images: Union[PIL.Image.Image, list[PIL.Image.Image], None] = None,
        return_tensors: str | TensorType | None = TensorType.PYTORCH,
        **kwargs,
    ) -> BatchFeature:
        """
        Process text and images into TensorStream format.

        Args:
            text: Input text or a one-element list of texts containing vision tokens.
            images: PIL image or list of images (optional).
            return_tensors: Format for output tensors. ``"pt"``/``TensorType.PYTORCH`` yields a long
                tensor for ``input_ids``; anything else returns the raw token view.

        Returns:
            BatchFeature with ``input_ids`` and ``tensor_stream``.

        Raises:
            ValueError: If more than one text is given, or if the number of vision tokens in the text
                differs from the number of images (``images=None`` counts as zero images).
        """
        texts = [text] if isinstance(text, str) else text

        if images is None:
            images_list = None
        elif isinstance(images, PIL.Image.Image):
            images_list = [images]
        else:
            images_list = images

        if len(texts) != 1:
            raise ValueError("IsaacProcessor currently supports batch_size=1")

        # Always validate placeholder/image counts so a prompt with placeholders but no images
        # fails with a clear message rather than deep inside stream construction.
        vision_token_count = texts[0].count(self.vision_token)
        num_images = 0 if images_list is None else len(images_list)
        if vision_token_count != num_images:
            raise ValueError(
                f"Number of {self.vision_token} tokens in text ({vision_token_count}) "
                f"must match number of images ({num_images})"
            )

        stream = self.build_event_stream_simple(
            text=texts[0],
            images=images_list,
        )
        tensor_stream = TensorStream([stream])

        # Keep only the most recent max_sequence_length positions.
        _, T = tensor_stream.shape
        if T > self.max_sequence_length:
            tensor_stream = ts_slice(tensor_stream, start=T - self.max_sequence_length, end=T)

        tokens = tensor_stream_token_view(tensor_stream)
        if return_tensors in (TensorType.PYTORCH, "pt"):
            input_ids = torch.as_tensor(tokens, dtype=torch.long)
        else:
            input_ids = tokens

        data = {
            "input_ids": input_ids,
            "tensor_stream": tensor_stream,
        }
        return BatchFeature(data=data)
|
|
|
|
| |
| |
| |
|
|
|
|
def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
    r"""Build text-only 3-axis position ids for a batch of token ids.

    Every token receives its sequence index, repeated on all three channels, so the
    result can be fed straight into the 3-axis MRoPE rotary embedding.

    Args:
        input_ids (`torch.Tensor`):
            Tensor of shape `(batch_size, seq_len)`. Only its shape and device are read.

    Returns:
        `torch.Tensor`: Shape `(batch_size, seq_len, 3)` with `out[b, t, c] == t`. The
        result is an expanded view that shares memory, so clone it before writing in place.
    """
    batch_size, seq_length = input_ids.shape
    steps = torch.arange(seq_length, device=input_ids.device)
    return steps[None, :, None].expand(batch_size, seq_length, 3)
|
|
|
|
class IsaacRotaryEmbedding(nn.Module):
    """3-axis (MRoPE) rotary embedding for Isaac's text decoder.

    Tokens whose modality is `VisionType.image` keep all three position channels.
    For every other token, channel 0 is copied onto all channels before the cos/sin
    tables are computed.
    """

    def __init__(self, config: IsaacConfig, device=None):
        super().__init__()

        # NOTE: `device` is accepted for signature compatibility but is not used here.
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        head_dim = self.head_dim

        scaling_cfg = getattr(config, "rope_scaling", None) or {}
        self.rope_type = scaling_cfg.get("rope_type", "default")

        # Per-axis section sizes handed to `precompute_cos_sin_3d`
        # (presumably how the rotary dims are divided among the 3 axes).
        self.mrope_section = [head_dim // 4, head_dim // 8, head_dim // 8]

        theta = getattr(config, "rope_theta", 10000.0)
        self.register_buffer("inv_freq", precompute_inv_freq(theta, head_dim), persistent=False)

    def forward(self, position_ids: torch.Tensor, modality_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the `(cos, sin)` rotary tables for `position_ids`.

        Args:
            position_ids: Positions whose last dimension holds the per-axis values;
                boolean-indexed with a mask of `modality_tensor`'s shape
                (presumably `(batch, seq, 3)` — confirm against callers).
            modality_tensor: Per-token modality codes compared against
                `VisionType.image.value`.
        """
        with torch.no_grad():
            is_image = modality_tensor == VisionType.image.value
            is_text_like = ~is_image
            mixed = position_ids.clone()
            # Broadcast channel 0 across all channels for non-image tokens.
            first_axis = position_ids[is_text_like][..., 0].unsqueeze(-1)
            mixed[is_text_like] = first_axis.expand(-1, position_ids.shape[-1])
            # Move the channel axis to the front before building the tables.
            cos, sin = precompute_cos_sin_3d(mixed.permute(2, 0, 1), self.inv_freq, self.mrope_section)

        return cos, sin
|
|
|
|
class IsaacModel(Qwen3Model):
    """Qwen3 text decoder extended with a vision encoder and 3-axis (MRoPE) positions.

    Inputs are either a `TensorStream` of mixed text/vision events, embedded per
    modality, or plain `input_ids` / `inputs_embeds`.
    """

    def __init__(self, config: IsaacConfig):
        super().__init__(config)
        text_cfg = getattr(config, "get_text_config", lambda: config)()
        # Rebuild the decoder stack from the text sub-config.
        self.layers = torch.nn.ModuleList(
            [Qwen3DecoderLayer(text_cfg, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Replace the parent's rotary embedding with the 3-axis variant.
        self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)

        vision_cfg = config.vision_config
        if vision_cfg is None:
            raise ValueError("IsaacConfig should always have vision_config")

        # The projector's input width is hidden_size * scale**2 (presumably the encoder
        # concatenates scale**2 patch features per output token).
        hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
        self.vision_embedding = nn.Sequential(
            Siglip2SequenceVisionTransformer(vision_cfg),
            nn.Linear(
                hidden_dim,
                4 * hidden_dim,
                bias=False,
            ),
            nn.SiLU(),
            nn.Linear(4 * hidden_dim, config.hidden_size, bias=False),
        )

        # Embedding function per modality, keyed by the modality enum class.
        self.embed_fns = {
            TextType: self.embed_text_tokens,
            VisionType: self.embed_vision,
        }

    def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed text tokens, dropping a singleton second-to-last dimension if present."""
        h = self.embed_tokens(token_ids)
        if h.dim() >= 2 and h.size(-2) == 1:
            h = h[..., 0, :]
        return h

    def embed_vision(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Embed vision tokens with the vision encoder followed by the MLP projector."""
        return self.vision_embedding(vision_tokens)

    def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor:
        """
        Embed each modality stream independently, preserving the original TensorStream
        structure.
        """
        flat_stream = tensor_stream.flat_stream()
        per_modality_stream = group_streams(flat_stream, group_fn=lambda ev: ev.type, schedule=False)
        per_modality_compact_stream = {k: v.compact() for k, v in per_modality_stream.items()}

        # Per event type, the real (non-virtual) dims of every event, in stream order.
        token_grids = defaultdict(list)
        for stream in tensor_stream.streams:
            for event in stream:
                token_grids[event.type].append(event.dims(virtual=False))

        embedded_compact = {}
        for stream_type, modality_payload_tensor in per_modality_compact_stream.items():
            if stream_type.modality == VisionType:
                grids = token_grids.get(stream_type, [])
                if len(grids) == 0:
                    input_tensor = modality_payload_tensor
                else:
                    # Drop the first grid dimension (presumably temporal) and keep the rest.
                    token_grids_tensor = torch.tensor(grids, dtype=torch.long, device=tensor_stream.device)[:, 1:]
                    input_tensor = (modality_payload_tensor, token_grids_tensor)
                embedded_compact[stream_type] = self.embed_fns[stream_type.modality](input_tensor)
            else:
                embedded_compact[stream_type] = self.embed_fns[stream_type.modality](modality_payload_tensor)

        # Scatter the per-modality embeddings back into the original stream layout.
        embedded_ts = reconstruct_tensor_stream_from_compact_dict(tensor_stream, embedded_compact)
        h = embedded_ts.compact()
        return h

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        tensor_stream: TensorStream | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        modality_tensor: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPast:
        """
        Forward pass with MRoPE position embeddings.

        Exactly one of `tensor_stream`, `input_ids` or `inputs_embeds` supplies the input.
        Any missing `cache_position`, `position_ids` or `modality_tensor` is derived from
        that input. The rotary cos/sin are computed once and shared by all layers.

        Returns:
            `BaseModelOutputWithPast` regardless of `return_dict`; `return_dict` is accepted
            only for API compatibility. `hidden_states` holds the layer inputs plus the
            final normed output when `output_hidden_states` is enabled.

        Raises:
            ValueError: If conflicting inputs are given, or none at all.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if tensor_stream is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both tensor_stream and inputs_embeds")
        elif tensor_stream is not None:
            inputs_embeds = self.embed_stream(tensor_stream)
            if modality_tensor is None:
                modality_tensor = modality_mask(tensor_stream)
        elif input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)
        elif inputs_embeds is None:
            raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds")

        batch_size, seq_length = inputs_embeds.shape[:2]
        device = inputs_embeds.device

        if modality_tensor is None:
            # Without a TensorStream every position is plain text.
            modality_tensor = torch.full(
                (batch_size, seq_length), TextType.text.value, device=device, dtype=torch.long
            )

        if cache_position is None:
            # Absolute indices of the new tokens. The 4D mask builder below requires them.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)

        if position_ids is None:
            if tensor_stream is not None:
                position_ids = compute_mrope_pos_tensor(tensor_stream)
            else:
                # 1D text positions continuing after the cached prefix, repeated on all 3 axes.
                position_ids = cache_position.view(1, -1, 1).expand(batch_size, -1, 3)

        cos, sin = self.rotary_emb(position_ids, modality_tensor)
        cos = cos.to(inputs_embeds.dtype)
        sin = sin.to(inputs_embeds.dtype)

        # Build the mask even when `attention_mask` is None. Eager attention applies no causal
        # masking by itself, so skipping this would make it attend bidirectionally. The
        # flash/sdpa branches return None when the kernel can handle causality on its own.
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, False
        )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=(cos, sin),
                **kwargs,
            )
            # Accept either a tuple or a bare tensor from the decoder layer.
            hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Return the attention mask the decoder layers should receive, or None.

        Flash-attention gets the 2D padding mask only when it has padding. SDPA may get
        None when the mask can be skipped. Otherwise a 4D additive mask is built from
        `attention_mask` and `cache_position`.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        # SDPA can often rely on its own causal flag, in which case no mask is materialised.
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # Fixed-size caches need a mask as wide as the whole cache.
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # transformers helper that lets fully-masked rows (e.g. left padding) attend
            # everywhere; presumably a workaround for SDPA kernels on these devices.
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config: Qwen3Config,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence. Must not be None
                when a mask is built here.
            batch_size (`int`):
                Batch size.
            config (`Qwen3Config`):
                The model's configuration class
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # Already 4D: use as-is.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            # True for key positions after the query's own position (the future), which stay masked.
            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            if config.sliding_window is not None:
                # Also mask keys that fall outside the sliding window. This is skipped for a
                # SlidingWindowCache (presumably it already enforces the window) unless the
                # input is longer than the cache.
                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                # Materialise the expanded view before editing it in place.
                causal_mask = causal_mask.clone()
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                # Positions that are causally visible (0) but padded in the 2D mask (0) sum to 0.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask
|
|
|
|
|
|
class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
    """Isaac multimodal model for conditional generation.

    Wraps `IsaacModel` with a language-modelling head. It also tracks the MRoPE offset
    (`rope_deltas`) produced by a TensorStream prefill so that later text-only decoding
    steps continue from the right position.
    """

    config_class = IsaacConfig

    def __init__(self, config: IsaacConfig):
        # Bypass Qwen3ForCausalLM.__init__ so the backbone is an IsaacModel.
        Qwen3PreTrainedModel.__init__(self, config)
        self.model = IsaacModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # (B, 1) offset set by `get_rope_index` during a TensorStream prefill.
        self.rope_deltas = None

        self.config = config
        # PreTrainedModel post-construction hook (weight init / tying).
        self.post_init()

    def get_rope_index(
        self,
        input_ids: torch.Tensor | None,
        tensor_stream: TensorStream | None,
        attention_mask: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute MRoPE position ids from a TensorStream (or 1D fallback).

        Returns (position_ids, rope_deltas). position_ids is (B,L,3) for MRoPE.
        rope_deltas is (B,1): the largest position plus one, minus the number of
        attended tokens; it is used to advance positions in decode.
        """
        if tensor_stream is None and input_ids is None:
            raise ValueError("`tensor_stream` or `input_ids` must be provided to compute rope indices")

        if tensor_stream is not None:
            pos_3d = compute_mrope_pos_tensor(tensor_stream)
        else:
            pos_3d = compute_position_ids_input_ids(input_ids)
        B, L, _ = pos_3d.shape

        # Largest position id per sequence, over all tokens and axes.
        m_per_batch = pos_3d.amax(dim=(1, 2))

        # Number of attended tokens per sequence.
        if attention_mask is None:
            seq_lens = torch.full_like(m_per_batch, L)
        else:
            seq_lens = attention_mask.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=m_per_batch.device)

        rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=pos_3d.dtype).unsqueeze(1)
        return pos_3d, rope_deltas

    def _text_position_ids(
        self,
        input_ids: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        cache_position: torch.Tensor | None,
    ) -> torch.Tensor:
        """Build 1D (text-only) MRoPE position ids, shifted for cached decoding.

        Positions start at `cache_position[0]`, plus the per-sequence `self.rope_deltas`
        when a previous TensorStream prefill stored them. A fresh sequence, meaning
        `cache_position` starts at 0, clears stale deltas left by an earlier call.
        """
        # Only the (B, L) shape and device of the reference tensor are used.
        reference = input_ids if input_ids is not None else inputs_embeds[..., 0]
        position_ids = compute_position_ids_input_ids(reference)
        if cache_position is None:
            return position_ids

        start = cache_position[0]
        if start == 0:
            # New text-only sequence: deltas from an earlier multimodal call do not apply.
            self.rope_deltas = None
            return position_ids

        shift = start.to(position_ids.device)
        if self.rope_deltas is not None:
            deltas = self.rope_deltas.to(position_ids.device)
            # Repeat the prefill deltas up to the current batch size
            # (presumably expanded by beam search / num_return_sequences).
            deltas = deltas.repeat_interleave(position_ids.shape[0] // deltas.shape[0], dim=0)
            # (B, 1) -> (B, 1, 1) so each sequence's offset broadcasts over (B, L, 3);
            # a bare (B, 1) would broadcast to (B, B, 3) when L == 1 and B > 1.
            shift = shift + deltas.view(-1, 1, 1)
        return position_ids + shift

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        tensor_stream: TensorStream | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs,
    ) -> tuple | CausalLMOutputWithPast:
        """
        Forward pass for conditional generation supporting both standard inputs and TensorStream.
        Uses our embed_stream approach for multimodal inputs.

        When `tensor_stream` is given, `input_ids` is ignored, MRoPE positions come from the
        stream, and `self.rope_deltas` is updated. Otherwise text-only positions are built
        from `input_ids` (or the shape of `inputs_embeds`) and shifted by the cache offset.

        Raises:
            ValueError: If none of `input_ids`, `inputs_embeds`, `tensor_stream` is given.
        """
        if tensor_stream is not None:
            # The stream carries its own tokens; input_ids are superseded.
            input_ids = None
        if input_ids is None and inputs_embeds is None and tensor_stream is None:
            raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.")

        if position_ids is None:
            if tensor_stream is not None:
                position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask)
            else:
                position_ids = self._text_position_ids(input_ids, inputs_embeds, cache_position)

        if tensor_stream is not None:
            modality_tensor = modality_mask(tensor_stream)
        else:
            reference = input_ids if input_ids is not None else inputs_embeds
            batch_size, seq_len = reference.shape[:2]
            modality_tensor = torch.full(
                (batch_size, seq_len), TextType.text.value, dtype=torch.long, device=position_ids.device
            )

        outputs = self.model(
            input_ids=input_ids,
            tensor_stream=tensor_stream,
            attention_mask=attention_mask,
            position_ids=position_ids,
            modality_tensor=modality_tensor,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: list[torch.FloatTensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        tensor_stream: TensorStream | None = None,
        cache_position: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        use_cache: bool = True,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Prepare inputs for generation, handling TensorStream inputs properly.

        The TensorStream is forwarded only on the prefill step (no cache position, or one
        starting at 0). `position_ids` is always cleared so that `forward` recomputes it.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            **kwargs,
        )

        if tensor_stream is not None and (cache_position is None or cache_position[0] == 0):
            model_inputs["tensor_stream"] = tensor_stream
        # Let `forward` derive MRoPE / cache-shifted positions itself.
        model_inputs["position_ids"] = None
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["tensor_stream"] = None
        return model_inputs

    @classmethod
    def can_generate(cls) -> bool:
        """Isaac always supports `generate()`; callable on the class or an instance."""
        return True
|
|
|
|
# Public names exported by `from <module> import *`.
__all__ = ["IsaacConfig", "IsaacModel", "IsaacForConditionalGeneration", "IsaacProcessor"]
|
|