# embed_vision_gemma3n.py # -*- coding: utf-8 -*- import os from typing import Optional, Tuple, Dict import torch import torch.nn as nn from safetensors.torch import load_file as safetensors_load_file from transformers import AutoConfig, AutoModel from transformers.models.gemma3n.modeling_gemma3n import Gemma3nMultimodalEmbedder # noqa from utils import load_json def _split_state_dict_from_tmp(sd: Dict[str, torch.Tensor]) \ -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: """ Model extractor saved tmp.state_dict() where tmp has attributes: - vision_tower - embed_vision (optional) So keys look like: - vision_tower.xxx - embed_vision.xxx """ vt = {} ev = {} for k, v in sd.items(): if k.startswith("vision_tower."): vt[k[len("vision_tower."):]] = v elif k.startswith("embed_vision."): ev[k[len("embed_vision."):]] = v return vt, ev # ============================================================ # Optional lightweight learnable token reducer # ============================================================ class VisionTokenReducer(nn.Module): """ Perceiver-style learnable cross-attention pooling with optional bottleneck. Base (no bottleneck): [B,T,D] -> [B,K,D] Bottleneck mode (bottleneck_dim=d): [B,T,D] -> down -> [B,T,d] -> cross-attn -> [B,K,d] -> (optional up) -> [B,K,D] Notes: - num_heads does NOT change parameter count of MultiheadAttention (depends on D only). - perform_norm_latent controls whether to pre-norm the learnable latent queries. """ def __init__( self, vision_dim: int, num_output_tokens: int, num_heads: int = 4, perform_norm_latent: bool = True, bottleneck_dim: Optional[int] = None, project_back: bool = True, ): super().__init__() self.vision_dim = int(vision_dim) self.num_output_tokens = int(num_output_tokens) self.num_heads = int(num_heads) self.perform_norm_latent = bool(perform_norm_latent) self.bottleneck_dim = None if bottleneck_dim is None else int(bottleneck_dim) self.project_back = bool(project_back) # Decide the attention working dimension: D (base) or d (bottleneck) attn_dim = self.vision_dim if self.bottleneck_dim is None else self.bottleneck_dim if attn_dim % self.num_heads != 0: raise ValueError(f"embed_dim ({attn_dim}) must be divisible by num_heads ({self.num_heads})") # Optional projection layers for bottleneck mode if self.bottleneck_dim is None: self.down = None self.up = None else: # bias=False keeps it lightweight; switch to True if you prefer self.down = nn.Linear(self.vision_dim, attn_dim, bias=False) self.up = nn.Linear(attn_dim, self.vision_dim, bias=False) if self.project_back else None # Learnable latent tokens (K, attn_dim) self.latents = nn.Parameter(torch.randn(self.num_output_tokens, attn_dim) * 0.02) # Separate norms: typically more stable than sharing one LN self.norm_latents = nn.LayerNorm(attn_dim) self.norm_x = nn.LayerNorm(attn_dim) # Cross-attention: query=latents, key/value=x self.attn = nn.MultiheadAttention( embed_dim=attn_dim, num_heads=self.num_heads, batch_first=True, ) def init_weights(self, std: float = 0.02): # Optional bottleneck projections if self.down is not None: nn.init.normal_(self.down.weight, std=std) if self.up is not None: nn.init.normal_(self.up.weight, std=std) # Learnable latent queries nn.init.normal_(self.latents, std=std) # LayerNorm nn.init.ones_(self.norm_latents.weight) nn.init.zeros_(self.norm_latents.bias) nn.init.ones_(self.norm_x.weight) nn.init.zeros_(self.norm_x.bias) # MultiheadAttention: use PyTorch's own reset only self.attn._reset_parameters() # noqa def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: [B, T, D] where D == vision_dim Returns: out: [B, K, D] if (bottleneck_dim is None) or project_back=True [B, K, d] if bottleneck_dim is not None and project_back=False """ if x.dim() != 3: raise ValueError(f"Expected x [B,T,D], got {tuple(x.shape)}") if x.size(-1) != self.vision_dim: raise ValueError(f"Expected last dim D={self.vision_dim}, got {x.size(-1)}") B = x.size(0) # Bottleneck projection if enabled if self.down is not None: x = self.down(x) # [B,T,d] # Expand learnable latents across batch latents = self.latents.unsqueeze(0).expand(B, -1, -1) # [B,K,attn_dim] # Pre-norm (optional for latents, always for input tokens) if self.perform_norm_latent: latents = self.norm_latents(latents) x = self.norm_x(x) # Cross-attention pooling out, _ = self.attn(query=latents, key=x, value=x) # [B,K,attn_dim] # Project back to original dim if requested if self.up is not None: out = self.up(out) # [B,K,D] return out # ============================================================ # Main body # ============================================================ class Gemma3nVisionFeatureExtractor(nn.Module): """ Vision-only feature extractor for Gemma-3n that matches transformers' Gemma3nModel.get_image_features(). Input: pixel_values [B, 3, H, W] Output: image_features [B, vision_soft_tokens_per_image, text_hidden_size] """ def __init__( self, vision_tower: nn.Module, embed_vision: Optional[nn.Module], vision_hidden_size: int, vision_soft_tokens_per_image: int, text_hidden_size: int, num_output_tokens_reduced: Optional[int] = None, num_heads_for_token_reduction: int = 4, perform_norm_latent_for_token_reduction: bool = True, reducer_bottleneck_dim: Optional[int] = None, reducer_project_back: bool = True, ): super().__init__() self.vision_tower = vision_tower self.embed_vision = embed_vision self.vision_hidden_size = int(vision_hidden_size) self.vision_soft_tokens_per_image = int(vision_soft_tokens_per_image) self.text_hidden_size = int(text_hidden_size) self.has_embed_vision = embed_vision is not None # Freeze vision modules self.vision_tower.requires_grad_(False) if self.embed_vision is not None: self.embed_vision.requires_grad_(False) # Reduce number of tokens if num_output_tokens_reduced is not None: reducer_dim = text_hidden_size if self.has_embed_vision else vision_hidden_size self.reducer = VisionTokenReducer( vision_dim=reducer_dim, num_output_tokens=num_output_tokens_reduced, num_heads=num_heads_for_token_reduction, perform_norm_latent=perform_norm_latent_for_token_reduction, bottleneck_dim=reducer_bottleneck_dim, project_back=reducer_project_back, ) else: self.reducer = None def init_weights(self, std: float = 0.02): if self.reducer is not None: self.reducer.init_weights(std) def get_actual_hidden_dim(self) -> int: """ Return the actual feature hidden dimension produced by this extractor. The output dimension depends on: - whether embed_vision is used - whether a reducer is present - reducer bottleneck + project_back configuration Returns: int: feature hidden size of output tokens """ # Base dimension before reducer base_dim = self.text_hidden_size if self.has_embed_vision else self.vision_hidden_size # No reducer if self.reducer is None: return base_dim # Reducer without bottleneck if self.reducer.bottleneck_dim is None: return base_dim # Bottleneck reducer if self.reducer.project_back: return base_dim # Bottleneck without projection back return int(self.reducer.bottleneck_dim) def train(self, mode: bool = True) -> "Gemma3nVisionFeatureExtractor": """ Override train(): vision is not trainable""" super().train(mode=mode) self.vision_tower.eval() if self.embed_vision is not None: self.embed_vision.eval() return self def forward( self, pixel_values: torch.Tensor, valid_positions: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: pixel_values: [B, 3, H, W] valid_positions: Indicates which samples have valid images. Supported formats: - BoolTensor [B] where True means "has image" - LongTensor [K] with indices of samples that have images If None: assume all samples have images. Returns: features: [B, T_img, D] vision_mask: [B, T_img] (1=valid vision token, 0=masked out) """ if pixel_values.dim() != 4: raise ValueError(f"pixel_values must be [B,3,H,W], got {tuple(pixel_values.shape)}") B = pixel_values.size(0) device = next(self.vision_tower.parameters()).device dtype = next(self.vision_tower.parameters()).dtype # -------------------------------------------------------- # Build per-sample valid-image mask # -------------------------------------------------------- if valid_positions is None: valid_mask = torch.ones(B, dtype=torch.bool, device=pixel_values.device) else: if valid_positions.dtype == torch.bool: if valid_positions.shape != (B,): raise ValueError(f"valid_positions (bool) must be [B], got {tuple(valid_positions.shape)}") valid_mask = valid_positions.to(device=pixel_values.device) else: if valid_positions.dim() != 1: raise ValueError(f"valid_positions (indices) must be 1D, got {tuple(valid_positions.shape)}") valid_mask = torch.zeros(B, dtype=torch.bool, device=pixel_values.device) valid_mask[valid_positions.to(device=pixel_values.device, dtype=torch.long)] = True num_valid = int(valid_mask.sum().item()) # -------------------------------------------------------- # Figure out final output shape in advance # -------------------------------------------------------- if self.reducer is None: T_img = self.vision_soft_tokens_per_image else: T_img = self.reducer.num_output_tokens D_out = self.get_actual_hidden_dim() # vision_mask always returned for full batch vision_mask = valid_mask[:, None].expand(B, T_img).to(dtype=torch.long) # Fast path: no valid image at all if num_valid == 0: features = torch.zeros(B, T_img, D_out, device=device, dtype=dtype) return features, vision_mask # -------------------------------------------------------- # Run only valid samples through frozen vision stack # -------------------------------------------------------- pixel_values_valid = pixel_values[valid_mask].to(device=device, dtype=dtype) with torch.no_grad(): vision_last = self.vision_tower( pixel_values=pixel_values_valid, do_pooling=False, return_dict=True, ).last_hidden_state if vision_last.dim() != 4: raise RuntimeError(f"Expected vision last_hidden_state (B,C,h,w), got {tuple(vision_last.shape)}") Bv, C, h, w = vision_last.shape if Bv != num_valid: raise RuntimeError("Batch size mismatch between valid pixel_values and vision_last") if C != self.vision_hidden_size: raise RuntimeError(f"Expected vision_hidden_size={self.vision_hidden_size}, got C={C}") if h * w != self.vision_soft_tokens_per_image: raise RuntimeError( f"Expected h*w={self.vision_soft_tokens_per_image}, got {h * w}. " f"Check processor image size/crop or config." ) # (Bv, C, h, w) -> (Bv, C, HW) -> (Bv, HW, C) vision_tokens = vision_last.reshape(Bv, C, self.vision_soft_tokens_per_image).permute(0, 2, 1).contiguous() # Scale by sqrt(C) (matches Gemma codepath) vision_tokens = vision_tokens * (self.vision_hidden_size ** 0.5) # -------------------------------------------------------- # Extract valid-image features only # -------------------------------------------------------- if not self.has_embed_vision: valid_features = vision_tokens # [Bv, HW, C] if self.reducer is not None: valid_features = self.reducer(valid_features) # [Bv, T_img, C or d] else: with torch.no_grad(): valid_features = self.embed_vision(inputs_embeds=vision_tokens) if valid_features.shape != (Bv, self.vision_soft_tokens_per_image, self.text_hidden_size): raise RuntimeError( f"Bad output shape {tuple(valid_features.shape)}; expected " f"({Bv}, {self.vision_soft_tokens_per_image}, {self.text_hidden_size})" ) if self.reducer is not None: valid_features = self.reducer(valid_features) # -------------------------------------------------------- # Scatter back to full batch; invalid samples stay zero # -------------------------------------------------------- if valid_features.size(1) != T_img: raise RuntimeError(f"T_img mismatch: expected {T_img}, got {valid_features.size(1)}") if valid_features.size(2) != D_out: raise RuntimeError(f"D_out mismatch: expected {D_out}, got {valid_features.size(2)}") features = torch.zeros(B, T_img, D_out, device=valid_features.device, dtype=valid_features.dtype) features[valid_mask] = valid_features return features, vision_mask @classmethod def from_pretrained_vision_only_dir( cls, model_dir: str, map_location: str = "cpu", num_output_tokens_reduced: Optional[int] = None, num_heads_for_token_reduction: int = 4, perform_norm_latent_for_token_reduction: bool = True, reducer_bottleneck_dim: Optional[int] = None, reducer_project_back: bool = True, ) -> "Gemma3nVisionFeatureExtractor": weights_path = os.path.join(model_dir, "model.safetensors") if not os.path.isfile(weights_path): raise FileNotFoundError(f"Missing weights: {weights_path}") ve_cfg_path = os.path.join(model_dir, "vision_extractor_config.json") if not os.path.isfile(ve_cfg_path): raise FileNotFoundError(f"Missing {ve_cfg_path}") ve_cfg = load_json(ve_cfg_path) vision_soft_tokens_per_image = int(ve_cfg.get("vision_soft_tokens_per_image", 256)) vision_hidden_size = int(ve_cfg.get("vision_hidden_size", -1)) text_hidden_size = int(ve_cfg.get("text_hidden_size", -1)) has_embed_vision = bool(ve_cfg.get("has_embed_vision", True)) if vision_hidden_size <= 0: raise ValueError("vision_hidden_size missing/invalid in vision_extractor_config.json") if has_embed_vision and text_hidden_size <= 0: raise ValueError("text_hidden_size missing/invalid in vision_extractor_config.json") cfg = AutoConfig.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True) vision_cfg = getattr(cfg, "vision_config", cfg) text_cfg = getattr(cfg, "text_config", None) vision_tower = AutoModel.from_config(vision_cfg, trust_remote_code=True) embed_vision = None if has_embed_vision: if text_cfg is None: raise RuntimeError( "config.json does not contain text_config, but has_embed_vision=True. " "You need a Gemma3nConfig-like config.json in this folder." ) embed_vision = Gemma3nMultimodalEmbedder(vision_cfg, text_cfg) sd = safetensors_load_file(weights_path, device=map_location) vt_sd, ev_sd = _split_state_dict_from_tmp(sd) if not vt_sd: raise RuntimeError("No vision_tower.* keys found in model.safetensors") if has_embed_vision and not ev_sd: raise RuntimeError("has_embed_vision=True but no embed_vision.* keys found in model.safetensors") missing_vt, unexpected_vt = vision_tower.load_state_dict(vt_sd, strict=True) if missing_vt or unexpected_vt: raise RuntimeError(f"vision_tower load mismatch: missing={missing_vt}, unexpected={unexpected_vt}") if has_embed_vision: missing_ev, unexpected_ev = embed_vision.load_state_dict(ev_sd, strict=True) if missing_ev or unexpected_ev: raise RuntimeError(f"embed_vision load mismatch: missing={missing_ev}, unexpected={unexpected_ev}") vision_tower.eval() if embed_vision is not None: embed_vision.eval() model = cls( vision_tower=vision_tower, embed_vision=embed_vision, vision_hidden_size=vision_hidden_size, vision_soft_tokens_per_image=vision_soft_tokens_per_image, text_hidden_size=text_hidden_size if has_embed_vision else vision_hidden_size, num_output_tokens_reduced=num_output_tokens_reduced, num_heads_for_token_reduction=num_heads_for_token_reduction, perform_norm_latent_for_token_reduction=perform_norm_latent_for_token_reduction, reducer_bottleneck_dim=reducer_bottleneck_dim, reducer_project_back=reducer_project_back, ) model.eval() return model def _demo_main(): import argparse from PIL import Image from transformers import AutoProcessor from pathlib import Path parser = argparse.ArgumentParser() parser.add_argument("--model_dir", type=str, default="./model_weights/gemma3n_E2B_vision_only") parser.add_argument("--device", type=str, default=None) parser.add_argument("--dtype", type=str, default="float32", choices=["bfloat16", "float16", "float32"]) parser.add_argument("--num_output_tokens_reduced", type=int, default=32) parser.add_argument("--reducer_bottleneck_dim", type=int, default=768) parser.add_argument("--reducer_project_back", action="store_true") args = parser.parse_args() model_dir = str(Path(args.model_dir).resolve()) # Force local loading processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True) model = Gemma3nVisionFeatureExtractor.from_pretrained_vision_only_dir( model_dir=model_dir, map_location="cpu", num_output_tokens_reduced=args.num_output_tokens_reduced, num_heads_for_token_reduction=4, reducer_bottleneck_dim=args.reducer_bottleneck_dim, reducer_project_back=args.reducer_project_back, ) model.init_weights() model.to(device=args.device, dtype=args.dtype) model.eval() def count_params(module): return sum(p.numel() for p in module.parameters()) vision_params = count_params(model.vision_tower) embed_params = 0 if model.has_embed_vision and model.embed_vision is not None: embed_params = count_params(model.embed_vision) reducer_params = 0 if model.reducer is not None: reducer_params = count_params(model.reducer) frozen_params = vision_params + embed_params total_params = frozen_params + reducer_params print(f"Vision tower parameters (frozen): {vision_params:,}") if model.has_embed_vision: print(f"Embed vision parameters (frozen): {embed_params:,}") else: print("Embed vision: NONE") if model.reducer is not None: print(f"Reducer parameters (trainable): {reducer_params:,}") else: print("Reducer: NONE") print(f"Total frozen parameters: {frozen_params:,}") print(f"Total trainable parameters: {reducer_params:,}") print(f"Total parameters: {total_params:,}") img1 = Image.new("RGB", (768, 768), color=(0, 0, 0)) img2 = Image.new("RGB", (768, 768), color=(255, 255, 255)) inputs = processor( text=["", ""], images=[[img1], [img2]], return_tensors="pt", ) pixel_values = inputs["pixel_values"].to( device=next(model.parameters()).device, dtype=next(model.parameters()).dtype, ) print("pixel_values:", tuple(pixel_values.shape), pixel_values.dtype, pixel_values.device) with torch.no_grad(): feats, masks = model(pixel_values) print("features:", tuple(feats.shape), feats.dtype, feats.device) print("masks:", tuple(masks.shape), masks.dtype, masks.device) if __name__ == "__main__": _demo_main()