import os
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
from safetensors.torch import load_file as safetensors_load_file
from transformers import AutoConfig, AutoModel
from transformers.models.gemma3n.modeling_gemma3n import Gemma3nMultimodalEmbedder

from utils import load_json


def _split_state_dict_from_tmp(
    sd: Dict[str, torch.Tensor],
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """
    Split the combined checkpoint into vision-tower and embedder parts.

    The model extractor saved tmp.state_dict(), where tmp has attributes:
      - vision_tower
      - embed_vision (optional)
    so keys look like:
      - vision_tower.xxx
      - embed_vision.xxx
    """
    vt = {}
    ev = {}
    for k, v in sd.items():
        if k.startswith("vision_tower."):
            vt[k[len("vision_tower."):]] = v
        elif k.startswith("embed_vision."):
            ev[k[len("embed_vision."):]] = v
    return vt, ev
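

# A minimal sketch (defined for illustration only, never called anywhere):
# how the prefix split above behaves on a toy state dict. The key names
# below are made up and do not come from any real checkpoint.
def _example_split_state_dict():
    sd = {
        "vision_tower.blocks.0.weight": torch.zeros(1),
        "embed_vision.embedding.weight": torch.zeros(1),
    }
    vt, ev = _split_state_dict_from_tmp(sd)
    assert set(vt) == {"blocks.0.weight"}
    assert set(ev) == {"embedding.weight"}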


class VisionTokenReducer(nn.Module):
    """
    Perceiver-style learnable cross-attention pooling with an optional bottleneck.

    Base (no bottleneck):
        [B, T, D] -> [B, K, D]

    Bottleneck mode (bottleneck_dim=d):
        [B, T, D] -> down -> [B, T, d] -> cross-attn -> [B, K, d] -> (optional up) -> [B, K, D]

    Notes:
        - num_heads does NOT change the parameter count of MultiheadAttention
          (that depends only on the attention embed dim).
        - perform_norm_latent controls whether both the learnable latent
          queries AND the keys/values are LayerNorm-ed before attention.
    """

    def __init__(
        self,
        vision_dim: int,
        num_output_tokens: int,
        num_heads: int = 4,
        perform_norm_latent: bool = True,
        bottleneck_dim: Optional[int] = None,
        project_back: bool = True,
    ):
        super().__init__()

        self.vision_dim = int(vision_dim)
        self.num_output_tokens = int(num_output_tokens)
        self.num_heads = int(num_heads)
        self.perform_norm_latent = bool(perform_norm_latent)

        self.bottleneck_dim = None if bottleneck_dim is None else int(bottleneck_dim)
        self.project_back = bool(project_back)

        # Attention runs at the bottleneck width when one is configured.
        attn_dim = self.vision_dim if self.bottleneck_dim is None else self.bottleneck_dim
        if attn_dim % self.num_heads != 0:
            raise ValueError(f"embed_dim ({attn_dim}) must be divisible by num_heads ({self.num_heads})")

        # Optional down/up projections around the attention.
        if self.bottleneck_dim is None:
            self.down = None
            self.up = None
        else:
            self.down = nn.Linear(self.vision_dim, attn_dim, bias=False)
            self.up = nn.Linear(attn_dim, self.vision_dim, bias=False) if self.project_back else None

        # K learnable latent queries.
        self.latents = nn.Parameter(torch.randn(self.num_output_tokens, attn_dim) * 0.02)

        # Pre-attention norms for the queries and the keys/values.
        self.norm_latents = nn.LayerNorm(attn_dim)
        self.norm_x = nn.LayerNorm(attn_dim)

        self.attn = nn.MultiheadAttention(
            embed_dim=attn_dim,
            num_heads=self.num_heads,
            batch_first=True,
        )

    def init_weights(self, std: float = 0.02):
        if self.down is not None:
            nn.init.normal_(self.down.weight, std=std)
        if self.up is not None:
            nn.init.normal_(self.up.weight, std=std)

        nn.init.normal_(self.latents, std=std)

        nn.init.ones_(self.norm_latents.weight)
        nn.init.zeros_(self.norm_latents.bias)
        nn.init.ones_(self.norm_x.weight)
        nn.init.zeros_(self.norm_x.bias)

        # Reuse MultiheadAttention's own (private) initializer.
        self.attn._reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, D] where D == vision_dim

        Returns:
            out: [B, K, D] if bottleneck_dim is None or project_back=True
                 [B, K, d] if bottleneck_dim is not None and project_back=False
        """
        if x.dim() != 3:
            raise ValueError(f"Expected x [B,T,D], got {tuple(x.shape)}")
        if x.size(-1) != self.vision_dim:
            raise ValueError(f"Expected last dim D={self.vision_dim}, got {x.size(-1)}")

        B = x.size(0)

        if self.down is not None:
            x = self.down(x)

        # Broadcast the learnable latents across the batch.
        latents = self.latents.unsqueeze(0).expand(B, -1, -1)

        if self.perform_norm_latent:
            latents = self.norm_latents(latents)
            x = self.norm_x(x)

        # Latent queries attend to the (possibly bottlenecked) vision tokens.
        out, _ = self.attn(query=latents, key=x, value=x)

        if self.up is not None:
            out = self.up(out)

        return out
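

# A minimal shape sketch (defined for illustration only, never called).
# All sizes below are made up; they do not come from any Gemma-3n config.
def _example_reducer_shapes():
    reducer = VisionTokenReducer(
        vision_dim=2048,       # D: width of incoming vision tokens
        num_output_tokens=32,  # K: number of latent queries / output tokens
        num_heads=4,
        bottleneck_dim=256,    # d: attention runs at this width
        project_back=True,     # map [B, K, d] back up to [B, K, D]
    )
    x = torch.randn(2, 256, 2048)      # [B, T, D]
    out = reducer(x)
    assert out.shape == (2, 32, 2048)  # project_back=True restores D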


class Gemma3nVisionFeatureExtractor(nn.Module):
    """
    Vision-only feature extractor for Gemma-3n that matches transformers'
    Gemma3nModel.get_image_features().

    Input:  pixel_values [B, 3, H, W]
    Output: image_features [B, vision_soft_tokens_per_image, text_hidden_size]
    """

    def __init__(
        self,
        vision_tower: nn.Module,
        embed_vision: Optional[nn.Module],
        vision_hidden_size: int,
        vision_soft_tokens_per_image: int,
        text_hidden_size: int,
        num_output_tokens_reduced: Optional[int] = None,
        num_heads_for_token_reduction: int = 4,
        perform_norm_latent_for_token_reduction: bool = True,
        reducer_bottleneck_dim: Optional[int] = None,
        reducer_project_back: bool = True,
    ):
        super().__init__()
        self.vision_tower = vision_tower
        self.embed_vision = embed_vision
        self.vision_hidden_size = int(vision_hidden_size)
        self.vision_soft_tokens_per_image = int(vision_soft_tokens_per_image)
        self.text_hidden_size = int(text_hidden_size)
        self.has_embed_vision = embed_vision is not None

        # The pretrained vision modules stay frozen; only the reducer trains.
        self.vision_tower.requires_grad_(False)
        if self.embed_vision is not None:
            self.embed_vision.requires_grad_(False)

        if num_output_tokens_reduced is not None:
            # The reducer sees text-width tokens when embed_vision is applied,
            # otherwise raw vision-width tokens.
            reducer_dim = text_hidden_size if self.has_embed_vision else vision_hidden_size
            self.reducer = VisionTokenReducer(
                vision_dim=reducer_dim,
                num_output_tokens=num_output_tokens_reduced,
                num_heads=num_heads_for_token_reduction,
                perform_norm_latent=perform_norm_latent_for_token_reduction,
                bottleneck_dim=reducer_bottleneck_dim,
                project_back=reducer_project_back,
            )
        else:
            self.reducer = None

    def init_weights(self, std: float = 0.02):
        if self.reducer is not None:
            self.reducer.init_weights(std)

    def get_actual_hidden_dim(self) -> int:
        """
        Return the actual feature hidden dimension produced by this extractor.

        The output dimension depends on:
          - whether embed_vision is used
          - whether a reducer is present
          - the reducer's bottleneck + project_back configuration

        Returns:
            int: feature hidden size of the output tokens
        """
        base_dim = self.text_hidden_size if self.has_embed_vision else self.vision_hidden_size

        # No reducer, no bottleneck, or projected back: width is unchanged.
        if self.reducer is None:
            return base_dim
        if self.reducer.bottleneck_dim is None:
            return base_dim
        if self.reducer.project_back:
            return base_dim

        # Bottlenecked and not projected back: tokens stay at bottleneck width.
        return int(self.reducer.bottleneck_dim)

    def train(self, mode: bool = True) -> "Gemma3nVisionFeatureExtractor":
        """Override train(): the frozen vision modules always stay in eval mode."""
        super().train(mode=mode)
        self.vision_tower.eval()
        if self.embed_vision is not None:
            self.embed_vision.eval()
        return self

    def forward(
        self,
        pixel_values: torch.Tensor,
        valid_positions: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            pixel_values: [B, 3, H, W]
            valid_positions:
                Indicates which samples have valid images.
                Supported formats:
                  - BoolTensor [B] where True means "has image"
                  - LongTensor [K] with indices of samples that have images
                If None: assume all samples have images.

        Returns:
            features: [B, T_img, D]
            vision_mask: [B, T_img] (1 = valid vision token, 0 = masked out)
        """
        if pixel_values.dim() != 4:
            raise ValueError(f"pixel_values must be [B,3,H,W], got {tuple(pixel_values.shape)}")

        B = pixel_values.size(0)
        device = next(self.vision_tower.parameters()).device
        dtype = next(self.vision_tower.parameters()).dtype

        # Build the valid mask on pixel_values' device (it indexes pixel_values
        # below); a copy on the compute device is used for the outputs.
        if valid_positions is None:
            valid_mask = torch.ones(B, dtype=torch.bool, device=pixel_values.device)
        else:
            if valid_positions.dtype == torch.bool:
                if valid_positions.shape != (B,):
                    raise ValueError(f"valid_positions (bool) must be [B], got {tuple(valid_positions.shape)}")
                valid_mask = valid_positions.to(device=pixel_values.device)
            else:
                if valid_positions.dim() != 1:
                    raise ValueError(f"valid_positions (indices) must be 1D, got {tuple(valid_positions.shape)}")
                valid_mask = torch.zeros(B, dtype=torch.bool, device=pixel_values.device)
                valid_mask[valid_positions.to(device=pixel_values.device, dtype=torch.long)] = True

        valid_mask_dev = valid_mask.to(device)
        num_valid = int(valid_mask.sum().item())

        # Output token count: the reducer's output size if present,
        # else the native soft-token count.
        if self.reducer is None:
            T_img = self.vision_soft_tokens_per_image
        else:
            T_img = self.reducer.num_output_tokens

        D_out = self.get_actual_hidden_dim()

        vision_mask = valid_mask_dev[:, None].expand(B, T_img).to(dtype=torch.long)

        # Nothing to encode: return all-zero features.
        if num_valid == 0:
            features = torch.zeros(B, T_img, D_out, device=device, dtype=dtype)
            return features, vision_mask

        # Encode only the samples that actually have images.
        pixel_values_valid = pixel_values[valid_mask].to(device=device, dtype=dtype)

        with torch.no_grad():
            vision_last = self.vision_tower(
                pixel_values=pixel_values_valid,
                do_pooling=False,
                return_dict=True,
            ).last_hidden_state

        if vision_last.dim() != 4:
            raise RuntimeError(f"Expected vision last_hidden_state (B,C,h,w), got {tuple(vision_last.shape)}")

        Bv, C, h, w = vision_last.shape
        if Bv != num_valid:
            raise RuntimeError("Batch size mismatch between valid pixel_values and vision_last")
        if C != self.vision_hidden_size:
            raise RuntimeError(f"Expected vision_hidden_size={self.vision_hidden_size}, got C={C}")
        if h * w != self.vision_soft_tokens_per_image:
            raise RuntimeError(
                f"Expected h*w={self.vision_soft_tokens_per_image}, got {h * w}. "
                f"Check processor image size/crop or config."
            )

        # (Bv, C, h, w) -> (Bv, h*w, C): one token per spatial position.
        vision_tokens = vision_last.reshape(Bv, C, self.vision_soft_tokens_per_image).permute(0, 2, 1).contiguous()

        # Same sqrt(hidden) scaling as Gemma3nModel.get_image_features().
        vision_tokens = vision_tokens * (self.vision_hidden_size ** 0.5)

        # Project to text space (if an embedder is present), then reduce.
        if not self.has_embed_vision:
            valid_features = vision_tokens
            if self.reducer is not None:
                valid_features = self.reducer(valid_features)
        else:
            with torch.no_grad():
                valid_features = self.embed_vision(inputs_embeds=vision_tokens)

            if valid_features.shape != (Bv, self.vision_soft_tokens_per_image, self.text_hidden_size):
                raise RuntimeError(
                    f"Bad output shape {tuple(valid_features.shape)}; expected "
                    f"({Bv}, {self.vision_soft_tokens_per_image}, {self.text_hidden_size})"
                )

            if self.reducer is not None:
                valid_features = self.reducer(valid_features)

        # Sanity-check the advertised output shape.
        if valid_features.size(1) != T_img:
            raise RuntimeError(f"T_img mismatch: expected {T_img}, got {valid_features.size(1)}")
        if valid_features.size(2) != D_out:
            raise RuntimeError(f"D_out mismatch: expected {D_out}, got {valid_features.size(2)}")

        # Scatter the encoded samples back into a full-batch tensor.
        features = torch.zeros(B, T_img, D_out, device=valid_features.device, dtype=valid_features.dtype)
        features[valid_mask_dev] = valid_features

        return features, vision_mask

    @classmethod
    def from_pretrained_vision_only_dir(
        cls,
        model_dir: str,
        map_location: str = "cpu",
        num_output_tokens_reduced: Optional[int] = None,
        num_heads_for_token_reduction: int = 4,
        perform_norm_latent_for_token_reduction: bool = True,
        reducer_bottleneck_dim: Optional[int] = None,
        reducer_project_back: bool = True,
    ) -> "Gemma3nVisionFeatureExtractor":
        weights_path = os.path.join(model_dir, "model.safetensors")
        if not os.path.isfile(weights_path):
            raise FileNotFoundError(f"Missing weights: {weights_path}")

        ve_cfg_path = os.path.join(model_dir, "vision_extractor_config.json")
        if not os.path.isfile(ve_cfg_path):
            raise FileNotFoundError(f"Missing {ve_cfg_path}")
        ve_cfg = load_json(ve_cfg_path)

        vision_soft_tokens_per_image = int(ve_cfg.get("vision_soft_tokens_per_image", 256))
        vision_hidden_size = int(ve_cfg.get("vision_hidden_size", -1))
        text_hidden_size = int(ve_cfg.get("text_hidden_size", -1))
        has_embed_vision = bool(ve_cfg.get("has_embed_vision", True))

        if vision_hidden_size <= 0:
            raise ValueError("vision_hidden_size missing/invalid in vision_extractor_config.json")
        if has_embed_vision and text_hidden_size <= 0:
            raise ValueError("text_hidden_size missing/invalid in vision_extractor_config.json")

        cfg = AutoConfig.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
        vision_cfg = getattr(cfg, "vision_config", cfg)
        text_cfg = getattr(cfg, "text_config", None)

        vision_tower = AutoModel.from_config(vision_cfg, trust_remote_code=True)

        embed_vision = None
        if has_embed_vision:
            if text_cfg is None:
                raise RuntimeError(
                    "config.json does not contain text_config, but has_embed_vision=True. "
                    "You need a Gemma3nConfig-like config.json in this folder."
                )
            embed_vision = Gemma3nMultimodalEmbedder(vision_cfg, text_cfg)

        sd = safetensors_load_file(weights_path, device=map_location)

        vt_sd, ev_sd = _split_state_dict_from_tmp(sd)
        if not vt_sd:
            raise RuntimeError("No vision_tower.* keys found in model.safetensors")
        if has_embed_vision and not ev_sd:
            raise RuntimeError("has_embed_vision=True but no embed_vision.* keys found in model.safetensors")

        # strict=False so that missing/unexpected keys reach our own error
        # message; with strict=True load_state_dict raises on its own and the
        # checks below would be unreachable.
        missing_vt, unexpected_vt = vision_tower.load_state_dict(vt_sd, strict=False)
        if missing_vt or unexpected_vt:
            raise RuntimeError(f"vision_tower load mismatch: missing={missing_vt}, unexpected={unexpected_vt}")

        if has_embed_vision:
            missing_ev, unexpected_ev = embed_vision.load_state_dict(ev_sd, strict=False)
            if missing_ev or unexpected_ev:
                raise RuntimeError(f"embed_vision load mismatch: missing={missing_ev}, unexpected={unexpected_ev}")

        vision_tower.eval()
        if embed_vision is not None:
            embed_vision.eval()

        model = cls(
            vision_tower=vision_tower,
            embed_vision=embed_vision,
            vision_hidden_size=vision_hidden_size,
            vision_soft_tokens_per_image=vision_soft_tokens_per_image,
            text_hidden_size=text_hidden_size if has_embed_vision else vision_hidden_size,
            num_output_tokens_reduced=num_output_tokens_reduced,
            num_heads_for_token_reduction=num_heads_for_token_reduction,
            perform_norm_latent_for_token_reduction=perform_norm_latent_for_token_reduction,
            reducer_bottleneck_dim=reducer_bottleneck_dim,
            reducer_project_back=reducer_project_back,
        )
        model.eval()
        return model
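

# Expected layout of `model_dir` for from_pretrained_vision_only_dir(), as
# implied by the loader above. The numeric values are illustrative, not from
# any real checkpoint:
#
#   model_dir/
#     config.json                   # Gemma3n-style config with vision_config (+ text_config)
#     model.safetensors             # tmp.state_dict() with vision_tower.* / embed_vision.* keys
#     vision_extractor_config.json  # e.g.
#       {
#         "vision_soft_tokens_per_image": 256,
#         "vision_hidden_size": 2048,
#         "text_hidden_size": 2048,
#         "has_embed_vision": true
#       }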


def _demo_main():
    import argparse
    from pathlib import Path

    from PIL import Image
    from transformers import AutoProcessor

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="./model_weights/gemma3n_E2B_vision_only")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--dtype", type=str, default="float32", choices=["bfloat16", "float16", "float32"])
    parser.add_argument("--num_output_tokens_reduced", type=int, default=32)
    parser.add_argument("--reducer_bottleneck_dim", type=int, default=768)
    parser.add_argument("--reducer_project_back", action="store_true")
    args = parser.parse_args()

    model_dir = str(Path(args.model_dir).resolve())

    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)

    model = Gemma3nVisionFeatureExtractor.from_pretrained_vision_only_dir(
        model_dir=model_dir,
        map_location="cpu",
        num_output_tokens_reduced=args.num_output_tokens_reduced,
        num_heads_for_token_reduction=4,
        reducer_bottleneck_dim=args.reducer_bottleneck_dim,
        reducer_project_back=args.reducer_project_back,
    )
    model.init_weights()
    # args.dtype is a string; map it to the actual torch dtype before .to().
    model.to(device=args.device, dtype=getattr(torch, args.dtype))
    model.eval()

    def count_params(module):
        return sum(p.numel() for p in module.parameters())

    vision_params = count_params(model.vision_tower)

    embed_params = 0
    if model.has_embed_vision and model.embed_vision is not None:
        embed_params = count_params(model.embed_vision)

    reducer_params = 0
    if model.reducer is not None:
        reducer_params = count_params(model.reducer)

    frozen_params = vision_params + embed_params
    total_params = frozen_params + reducer_params

    print(f"Vision tower parameters (frozen): {vision_params:,}")

    if model.has_embed_vision:
        print(f"Embed vision parameters (frozen): {embed_params:,}")
    else:
        print("Embed vision: NONE")

    if model.reducer is not None:
        print(f"Reducer parameters (trainable): {reducer_params:,}")
    else:
        print("Reducer: NONE")

    print(f"Total frozen parameters: {frozen_params:,}")
    print(f"Total trainable parameters: {reducer_params:,}")
    print(f"Total parameters: {total_params:,}")

    # Two dummy images: all-black and all-white.
    img1 = Image.new("RGB", (768, 768), color=(0, 0, 0))
    img2 = Image.new("RGB", (768, 768), color=(255, 255, 255))

    inputs = processor(
        text=["", ""],
        images=[[img1], [img2]],
        return_tensors="pt",
    )

    pixel_values = inputs["pixel_values"].to(
        device=next(model.parameters()).device,
        dtype=next(model.parameters()).dtype,
    )

    print("pixel_values:", tuple(pixel_values.shape), pixel_values.dtype, pixel_values.device)

    with torch.no_grad():
        feats, masks = model(pixel_values)

    print("features:", tuple(feats.shape), feats.dtype, feats.device)
    print("masks:", tuple(masks.shape), masks.dtype, masks.device)
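

# Example invocation (the script name is illustrative, and the model_dir is
# just the demo's own default; it may not exist on your machine):
#   python gemma3n_vision_extractor.py \
#       --model_dir ./model_weights/gemma3n_E2B_vision_only \
#       --dtype bfloat16 --num_output_tokens_reduced 32 \
#       --reducer_bottleneck_dim 768 --reducer_project_back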
if __name__ == "__main__":
    _demo_main()