# embed_vision_gemma3n.py
# -*- coding: utf-8 -*-
import os
from typing import Optional, Tuple, Dict
import torch
import torch.nn as nn
from safetensors.torch import load_file as safetensors_load_file
from transformers import AutoConfig, AutoModel
from transformers.models.gemma3n.modeling_gemma3n import Gemma3nMultimodalEmbedder # noqa
from utils import load_json
def _split_state_dict_from_tmp(sd: Dict[str, torch.Tensor]) \
-> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
"""
Model extractor saved tmp.state_dict() where tmp has attributes:
- vision_tower
- embed_vision (optional)
So keys look like:
- vision_tower.xxx
- embed_vision.xxx
"""
vt = {}
ev = {}
for k, v in sd.items():
if k.startswith("vision_tower."):
vt[k[len("vision_tower."):]] = v
elif k.startswith("embed_vision."):
ev[k[len("embed_vision."):]] = v
return vt, ev
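# A minimal illustration of the split (hypothetical key names, for clarity only):
#     {"vision_tower.conv.weight": w1, "embed_vision.embedding.weight": w2}
# splits into
#     vt = {"conv.weight": w1}   and   ev = {"embedding.weight": w2}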
# ============================================================
# Optional lightweight learnable token reducer
# ============================================================
class VisionTokenReducer(nn.Module):
"""
Perceiver-style learnable cross-attention pooling with optional bottleneck.
Base (no bottleneck):
[B,T,D] -> [B,K,D]
Bottleneck mode (bottleneck_dim=d):
[B,T,D] -> down -> [B,T,d] -> cross-attn -> [B,K,d] -> (optional up) -> [B,K,D]
Notes:
    - num_heads does NOT change the parameter count of MultiheadAttention
      (it depends only on the attention working dim: D, or d in bottleneck mode).
- perform_norm_latent controls whether to pre-norm the learnable latent queries.
"""
def __init__(
self,
vision_dim: int,
num_output_tokens: int,
num_heads: int = 4,
perform_norm_latent: bool = True,
bottleneck_dim: Optional[int] = None,
project_back: bool = True,
):
super().__init__()
self.vision_dim = int(vision_dim)
self.num_output_tokens = int(num_output_tokens)
self.num_heads = int(num_heads)
self.perform_norm_latent = bool(perform_norm_latent)
self.bottleneck_dim = None if bottleneck_dim is None else int(bottleneck_dim)
self.project_back = bool(project_back)
# Decide the attention working dimension: D (base) or d (bottleneck)
attn_dim = self.vision_dim if self.bottleneck_dim is None else self.bottleneck_dim
if attn_dim % self.num_heads != 0:
raise ValueError(f"embed_dim ({attn_dim}) must be divisible by num_heads ({self.num_heads})")
# Optional projection layers for bottleneck mode
if self.bottleneck_dim is None:
self.down = None
self.up = None
else:
# bias=False keeps it lightweight; switch to True if you prefer
self.down = nn.Linear(self.vision_dim, attn_dim, bias=False)
self.up = nn.Linear(attn_dim, self.vision_dim, bias=False) if self.project_back else None
# Learnable latent tokens (K, attn_dim)
self.latents = nn.Parameter(torch.randn(self.num_output_tokens, attn_dim) * 0.02)
# Separate norms: typically more stable than sharing one LN
self.norm_latents = nn.LayerNorm(attn_dim)
self.norm_x = nn.LayerNorm(attn_dim)
# Cross-attention: query=latents, key/value=x
self.attn = nn.MultiheadAttention(
embed_dim=attn_dim,
num_heads=self.num_heads,
batch_first=True,
)
def init_weights(self, std: float = 0.02):
# Optional bottleneck projections
if self.down is not None:
nn.init.normal_(self.down.weight, std=std)
if self.up is not None:
nn.init.normal_(self.up.weight, std=std)
# Learnable latent queries
nn.init.normal_(self.latents, std=std)
# LayerNorm
nn.init.ones_(self.norm_latents.weight)
nn.init.zeros_(self.norm_latents.bias)
nn.init.ones_(self.norm_x.weight)
nn.init.zeros_(self.norm_x.bias)
# MultiheadAttention: use PyTorch's own reset only
self.attn._reset_parameters() # noqa
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [B, T, D] where D == vision_dim
Returns:
out: [B, K, D] if (bottleneck_dim is None) or project_back=True
[B, K, d] if bottleneck_dim is not None and project_back=False
"""
if x.dim() != 3:
raise ValueError(f"Expected x [B,T,D], got {tuple(x.shape)}")
if x.size(-1) != self.vision_dim:
raise ValueError(f"Expected last dim D={self.vision_dim}, got {x.size(-1)}")
B = x.size(0)
# Bottleneck projection if enabled
if self.down is not None:
x = self.down(x) # [B,T,d]
# Expand learnable latents across batch
latents = self.latents.unsqueeze(0).expand(B, -1, -1) # [B,K,attn_dim]
# Pre-norm (optional for latents, always for input tokens)
if self.perform_norm_latent:
latents = self.norm_latents(latents)
x = self.norm_x(x)
# Cross-attention pooling
out, _ = self.attn(query=latents, key=x, value=x) # [B,K,attn_dim]
# Project back to original dim if requested
if self.up is not None:
out = self.up(out) # [B,K,D]
return out
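# A minimal usage sketch of VisionTokenReducer (hypothetical sizes, for
# illustration only; not tied to any particular checkpoint):
#     reducer = VisionTokenReducer(vision_dim=2048, num_output_tokens=32,
#                                  num_heads=4, bottleneck_dim=768)
#     reducer.init_weights()
#     x = torch.randn(2, 256, 2048)   # [B=2, T=256, D=2048]
#     out = reducer(x)                # [2, 32, 2048] since project_back=True by default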
# ============================================================
# Main body
# ============================================================
class Gemma3nVisionFeatureExtractor(nn.Module):
"""
Vision-only feature extractor for Gemma-3n that matches transformers' Gemma3nModel.get_image_features().
Input: pixel_values [B, 3, H, W]
Output: image_features [B, vision_soft_tokens_per_image, text_hidden_size]
"""
def __init__(
self,
vision_tower: nn.Module,
embed_vision: Optional[nn.Module],
vision_hidden_size: int,
vision_soft_tokens_per_image: int,
text_hidden_size: int,
num_output_tokens_reduced: Optional[int] = None,
num_heads_for_token_reduction: int = 4,
perform_norm_latent_for_token_reduction: bool = True,
reducer_bottleneck_dim: Optional[int] = None,
reducer_project_back: bool = True,
):
super().__init__()
self.vision_tower = vision_tower
self.embed_vision = embed_vision
self.vision_hidden_size = int(vision_hidden_size)
self.vision_soft_tokens_per_image = int(vision_soft_tokens_per_image)
self.text_hidden_size = int(text_hidden_size)
self.has_embed_vision = embed_vision is not None
# Freeze vision modules
self.vision_tower.requires_grad_(False)
if self.embed_vision is not None:
self.embed_vision.requires_grad_(False)
# Reduce number of tokens
if num_output_tokens_reduced is not None:
reducer_dim = text_hidden_size if self.has_embed_vision else vision_hidden_size
self.reducer = VisionTokenReducer(
vision_dim=reducer_dim,
num_output_tokens=num_output_tokens_reduced,
num_heads=num_heads_for_token_reduction,
perform_norm_latent=perform_norm_latent_for_token_reduction,
bottleneck_dim=reducer_bottleneck_dim,
project_back=reducer_project_back,
)
else:
self.reducer = None
def init_weights(self, std: float = 0.02):
if self.reducer is not None:
self.reducer.init_weights(std)
def get_actual_hidden_dim(self) -> int:
"""
Return the actual feature hidden dimension produced by this extractor.
The output dimension depends on:
- whether embed_vision is used
- whether a reducer is present
- reducer bottleneck + project_back configuration
Returns:
int: feature hidden size of output tokens
"""
# Base dimension before reducer
base_dim = self.text_hidden_size if self.has_embed_vision else self.vision_hidden_size
# No reducer
if self.reducer is None:
return base_dim
# Reducer without bottleneck
if self.reducer.bottleneck_dim is None:
return base_dim
# Bottleneck reducer
if self.reducer.project_back:
return base_dim
# Bottleneck without projection back
return int(self.reducer.bottleneck_dim)
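    # Quick reference for get_actual_hidden_dim() (hypothetical sizes, assuming
    # base_dim=2048 and bottleneck_dim=768):
    #     no reducer                               -> 2048
    #     reducer, no bottleneck                   -> 2048
    #     reducer + bottleneck, project_back=True  -> 2048
    #     reducer + bottleneck, project_back=False -> 768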
def train(self, mode: bool = True) -> "Gemma3nVisionFeatureExtractor":
""" Override train(): vision is not trainable"""
super().train(mode=mode)
self.vision_tower.eval()
if self.embed_vision is not None:
self.embed_vision.eval()
return self
def forward(
self,
pixel_values: torch.Tensor,
valid_positions: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
pixel_values: [B, 3, H, W]
valid_positions:
Indicates which samples have valid images.
Supported formats:
- BoolTensor [B] where True means "has image"
- LongTensor [K] with indices of samples that have images
If None: assume all samples have images.
Returns:
features: [B, T_img, D]
vision_mask: [B, T_img] (1=valid vision token, 0=masked out)
"""
if pixel_values.dim() != 4:
raise ValueError(f"pixel_values must be [B,3,H,W], got {tuple(pixel_values.shape)}")
B = pixel_values.size(0)
device = next(self.vision_tower.parameters()).device
dtype = next(self.vision_tower.parameters()).dtype
# --------------------------------------------------------
# Build per-sample valid-image mask
# --------------------------------------------------------
if valid_positions is None:
valid_mask = torch.ones(B, dtype=torch.bool, device=pixel_values.device)
else:
if valid_positions.dtype == torch.bool:
if valid_positions.shape != (B,):
raise ValueError(f"valid_positions (bool) must be [B], got {tuple(valid_positions.shape)}")
valid_mask = valid_positions.to(device=pixel_values.device)
else:
if valid_positions.dim() != 1:
raise ValueError(f"valid_positions (indices) must be 1D, got {tuple(valid_positions.shape)}")
valid_mask = torch.zeros(B, dtype=torch.bool, device=pixel_values.device)
valid_mask[valid_positions.to(device=pixel_values.device, dtype=torch.long)] = True
num_valid = int(valid_mask.sum().item())
# --------------------------------------------------------
# Figure out final output shape in advance
# --------------------------------------------------------
if self.reducer is None:
T_img = self.vision_soft_tokens_per_image
else:
T_img = self.reducer.num_output_tokens
D_out = self.get_actual_hidden_dim()
# vision_mask always returned for full batch
vision_mask = valid_mask[:, None].expand(B, T_img).to(dtype=torch.long)
# Fast path: no valid image at all
if num_valid == 0:
features = torch.zeros(B, T_img, D_out, device=device, dtype=dtype)
return features, vision_mask
# --------------------------------------------------------
# Run only valid samples through frozen vision stack
# --------------------------------------------------------
pixel_values_valid = pixel_values[valid_mask].to(device=device, dtype=dtype)
with torch.no_grad():
vision_last = self.vision_tower(
pixel_values=pixel_values_valid,
do_pooling=False,
return_dict=True,
).last_hidden_state
if vision_last.dim() != 4:
raise RuntimeError(f"Expected vision last_hidden_state (B,C,h,w), got {tuple(vision_last.shape)}")
Bv, C, h, w = vision_last.shape
if Bv != num_valid:
raise RuntimeError("Batch size mismatch between valid pixel_values and vision_last")
if C != self.vision_hidden_size:
raise RuntimeError(f"Expected vision_hidden_size={self.vision_hidden_size}, got C={C}")
if h * w != self.vision_soft_tokens_per_image:
raise RuntimeError(
f"Expected h*w={self.vision_soft_tokens_per_image}, got {h * w}. "
f"Check processor image size/crop or config."
)
# (Bv, C, h, w) -> (Bv, C, HW) -> (Bv, HW, C)
vision_tokens = vision_last.reshape(Bv, C, self.vision_soft_tokens_per_image).permute(0, 2, 1).contiguous()
# Scale by sqrt(C) (matches Gemma codepath)
vision_tokens = vision_tokens * (self.vision_hidden_size ** 0.5)
# --------------------------------------------------------
# Extract valid-image features only
# --------------------------------------------------------
if not self.has_embed_vision:
valid_features = vision_tokens # [Bv, HW, C]
if self.reducer is not None:
valid_features = self.reducer(valid_features) # [Bv, T_img, C or d]
else:
with torch.no_grad():
valid_features = self.embed_vision(inputs_embeds=vision_tokens)
if valid_features.shape != (Bv, self.vision_soft_tokens_per_image, self.text_hidden_size):
raise RuntimeError(
f"Bad output shape {tuple(valid_features.shape)}; expected "
f"({Bv}, {self.vision_soft_tokens_per_image}, {self.text_hidden_size})"
)
if self.reducer is not None:
valid_features = self.reducer(valid_features)
# --------------------------------------------------------
# Scatter back to full batch; invalid samples stay zero
# --------------------------------------------------------
if valid_features.size(1) != T_img:
raise RuntimeError(f"T_img mismatch: expected {T_img}, got {valid_features.size(1)}")
if valid_features.size(2) != D_out:
raise RuntimeError(f"D_out mismatch: expected {D_out}, got {valid_features.size(2)}")
features = torch.zeros(B, T_img, D_out, device=valid_features.device, dtype=valid_features.dtype)
features[valid_mask] = valid_features
return features, vision_mask
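    # Usage sketch for forward() (hypothetical shapes; `extractor` stands for an
    # instance of this class):
    #     pixel_values = torch.zeros(4, 3, 768, 768)
    #     idx = torch.tensor([0, 2])          # only samples 0 and 2 have images
    #     feats, mask = extractor(pixel_values, valid_positions=idx)
    #     # feats[1] and feats[3] are all zeros; mask rows 1 and 3 are all 0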
@classmethod
def from_pretrained_vision_only_dir(
cls,
model_dir: str,
map_location: str = "cpu",
num_output_tokens_reduced: Optional[int] = None,
num_heads_for_token_reduction: int = 4,
perform_norm_latent_for_token_reduction: bool = True,
reducer_bottleneck_dim: Optional[int] = None,
reducer_project_back: bool = True,
) -> "Gemma3nVisionFeatureExtractor":
weights_path = os.path.join(model_dir, "model.safetensors")
if not os.path.isfile(weights_path):
raise FileNotFoundError(f"Missing weights: {weights_path}")
ve_cfg_path = os.path.join(model_dir, "vision_extractor_config.json")
if not os.path.isfile(ve_cfg_path):
raise FileNotFoundError(f"Missing {ve_cfg_path}")
ve_cfg = load_json(ve_cfg_path)
vision_soft_tokens_per_image = int(ve_cfg.get("vision_soft_tokens_per_image", 256))
vision_hidden_size = int(ve_cfg.get("vision_hidden_size", -1))
text_hidden_size = int(ve_cfg.get("text_hidden_size", -1))
has_embed_vision = bool(ve_cfg.get("has_embed_vision", True))
if vision_hidden_size <= 0:
raise ValueError("vision_hidden_size missing/invalid in vision_extractor_config.json")
if has_embed_vision and text_hidden_size <= 0:
raise ValueError("text_hidden_size missing/invalid in vision_extractor_config.json")
cfg = AutoConfig.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
vision_cfg = getattr(cfg, "vision_config", cfg)
text_cfg = getattr(cfg, "text_config", None)
vision_tower = AutoModel.from_config(vision_cfg, trust_remote_code=True)
embed_vision = None
if has_embed_vision:
if text_cfg is None:
raise RuntimeError(
"config.json does not contain text_config, but has_embed_vision=True. "
"You need a Gemma3nConfig-like config.json in this folder."
)
embed_vision = Gemma3nMultimodalEmbedder(vision_cfg, text_cfg)
sd = safetensors_load_file(weights_path, device=map_location)
vt_sd, ev_sd = _split_state_dict_from_tmp(sd)
if not vt_sd:
raise RuntimeError("No vision_tower.* keys found in model.safetensors")
if has_embed_vision and not ev_sd:
raise RuntimeError("has_embed_vision=True but no embed_vision.* keys found in model.safetensors")
        # strict=False so that mismatches reach our own, more descriptive errors
        # (strict=True would raise inside load_state_dict before the checks below)
        missing_vt, unexpected_vt = vision_tower.load_state_dict(vt_sd, strict=False)
        if missing_vt or unexpected_vt:
            raise RuntimeError(f"vision_tower load mismatch: missing={missing_vt}, unexpected={unexpected_vt}")
        if has_embed_vision:
            missing_ev, unexpected_ev = embed_vision.load_state_dict(ev_sd, strict=False)
            if missing_ev or unexpected_ev:
                raise RuntimeError(f"embed_vision load mismatch: missing={missing_ev}, unexpected={unexpected_ev}")
vision_tower.eval()
if embed_vision is not None:
embed_vision.eval()
model = cls(
vision_tower=vision_tower,
embed_vision=embed_vision,
vision_hidden_size=vision_hidden_size,
vision_soft_tokens_per_image=vision_soft_tokens_per_image,
text_hidden_size=text_hidden_size if has_embed_vision else vision_hidden_size,
num_output_tokens_reduced=num_output_tokens_reduced,
num_heads_for_token_reduction=num_heads_for_token_reduction,
perform_norm_latent_for_token_reduction=perform_norm_latent_for_token_reduction,
reducer_bottleneck_dim=reducer_bottleneck_dim,
reducer_project_back=reducer_project_back,
)
model.eval()
return model
def _demo_main():
import argparse
from PIL import Image
from transformers import AutoProcessor
from pathlib import Path
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, default="./model_weights/gemma3n_E2B_vision_only")
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--dtype", type=str, default="float32", choices=["bfloat16", "float16", "float32"])
parser.add_argument("--num_output_tokens_reduced", type=int, default=32)
parser.add_argument("--reducer_bottleneck_dim", type=int, default=768)
parser.add_argument("--reducer_project_back", action="store_true")
args = parser.parse_args()
model_dir = str(Path(args.model_dir).resolve())
# Force local loading
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
model = Gemma3nVisionFeatureExtractor.from_pretrained_vision_only_dir(
model_dir=model_dir,
map_location="cpu",
num_output_tokens_reduced=args.num_output_tokens_reduced,
num_heads_for_token_reduction=4,
reducer_bottleneck_dim=args.reducer_bottleneck_dim,
reducer_project_back=args.reducer_project_back,
)
    model.init_weights()
    # args.dtype is a string; map it to an actual torch.dtype before calling .to()
    torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16,
                   "float32": torch.float32}[args.dtype]
    model.to(device=args.device, dtype=torch_dtype)
model.eval()
def count_params(module):
return sum(p.numel() for p in module.parameters())
vision_params = count_params(model.vision_tower)
embed_params = 0
if model.has_embed_vision and model.embed_vision is not None:
embed_params = count_params(model.embed_vision)
reducer_params = 0
if model.reducer is not None:
reducer_params = count_params(model.reducer)
frozen_params = vision_params + embed_params
total_params = frozen_params + reducer_params
print(f"Vision tower parameters (frozen): {vision_params:,}")
if model.has_embed_vision:
print(f"Embed vision parameters (frozen): {embed_params:,}")
else:
print("Embed vision: NONE")
if model.reducer is not None:
print(f"Reducer parameters (trainable): {reducer_params:,}")
else:
print("Reducer: NONE")
print(f"Total frozen parameters: {frozen_params:,}")
print(f"Total trainable parameters: {reducer_params:,}")
print(f"Total parameters: {total_params:,}")
img1 = Image.new("RGB", (768, 768), color=(0, 0, 0))
img2 = Image.new("RGB", (768, 768), color=(255, 255, 255))
inputs = processor(
text=["", ""],
images=[[img1], [img2]],
return_tensors="pt",
)
pixel_values = inputs["pixel_values"].to(
device=next(model.parameters()).device,
dtype=next(model.parameters()).dtype,
)
print("pixel_values:", tuple(pixel_values.shape), pixel_values.dtype, pixel_values.device)
with torch.no_grad():
feats, masks = model(pixel_values)
print("features:", tuple(feats.shape), feats.dtype, feats.device)
print("masks:", tuple(masks.shape), masks.dtype, masks.device)
if __name__ == "__main__":
_demo_main()