# vil-tracker: vil_tracker/utils/helpers.py
# Uploaded via huggingface_hub (commit dc54883, author omar-ah).
"""
Utility functions for model analysis and summary.
Provides:
- count_parameters: total and trainable parameter counts
- estimate_flops: rough GFLOPs estimate for the tracker
- estimate_model_size: model file size in MB
- print_model_summary: formatted constraint compliance check
"""
import torch
import torch.nn as nn
def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Return the number of parameters in *model*.

    Args:
        model: PyTorch module to inspect.
        trainable_only: when True, count only parameters with
            ``requires_grad`` set.

    Returns:
        Total parameter count.
    """
    params = model.parameters()
    if trainable_only:
        # Keep only the parameters the optimizer would actually update.
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
def estimate_flops(model, template_size=128, search_size=256, dim=384, depth=24,
                   patch_size=16, proj_factor=2.0, qkv_proj_blocksize=4,
                   mlp_ratio=4.0, num_heads=4):
    """Rough GFLOPs estimate for one forward pass of the tracker.

    Only the dominant operations are counted:
      * patch embedding (Conv2d over the template and search crops),
      * per mLSTM block: proj_up, depthwise conv1d, headwise Q/K/V,
        i/f gates, the two attention matmuls, and proj_down,
      * a flat ~0.5 GFLOPs allowance for the prediction heads.

    ``model`` and ``mlp_ratio`` are accepted for interface compatibility
    but do not enter the estimate: standard ViL blocks carry no separate
    MLP/FFN (the gated proj_up -> z-gate -> proj_down path plays that
    role); only TMoE blocks in the last 2 layers add one, and the layers
    are approximated uniformly here.
    """
    import math

    # Token counts for the two crops and their concatenated sequence.
    tokens_template = (template_size // patch_size) ** 2   # 64 at defaults
    tokens_search = (search_size // patch_size) ** 2       # 256 at defaults
    seq = tokens_template + tokens_search                  # 320 at defaults

    # Up-projected inner width, rounded up to a multiple of 64 (768 at defaults).
    inner = math.ceil(proj_factor * dim / 64) * 64
    qkv_heads = inner // qkv_proj_blocksize   # headwise-expand groups (192)
    qkv_head_dim = inner // qkv_heads         # equals qkv_proj_blocksize (4)
    cell_head_dim = inner // num_heads        # per-head width in the cell

    # Patch embedding: Conv2d(3, dim, patch_size, stride=patch_size),
    # 2 * C_in * k * k * C_out FLOPs per output position.
    flops = 2 * 3 * patch_size * patch_size * dim * seq

    # Every mLSTM block costs the same, so price one and scale by depth.
    block = 2 * dim * 2 * inner * seq                                # proj_up: Linear(dim, 2*inner)
    block += 2 * inner * 4 * seq                                     # depthwise Conv1d, kernel 4
    block += 3 * qkv_heads * 2 * qkv_head_dim * qkv_head_dim * seq   # Q/K/V (LinearHeadwiseExpand)
    block += 2 * 2 * 3 * inner * num_heads * seq                     # i/f gates: Linear(3*inner, num_heads)
    block += 2 * num_heads * seq * cell_head_dim * seq               # attention: Q @ K^T
    block += 2 * num_heads * seq * seq * cell_head_dim               # attention @ V
    block += 2 * inner * dim * seq                                   # proj_down: Linear(inner, dim)
    flops += depth * block

    # Prediction heads: flat ~0.5 GFLOPs allowance.
    flops += 0.5e9

    return flops / 1e9
def estimate_model_size(model: nn.Module, dtype_bytes: int = 4) -> float:
    """Estimate the serialized model size in megabytes.

    Counts parameters only (no buffers or file-format overhead).

    Args:
        model: PyTorch module to measure.
        dtype_bytes: bytes per parameter (4 for fp32, 2 for fp16).

    Returns:
        Estimated size in MB.
    """
    total = sum(p.numel() for p in model.parameters())
    return total * dtype_bytes / (1024 * 1024)
def print_model_summary(model, config=None):
    """Print a formatted model summary and check deployment constraints.

    Constraints reported:
      * parameters within 50M
      * GFLOPs within 20
      * fp16 model size within 500MB (deployment format)
      * latency within 30ms (shown as pending; needs a GPU benchmark)

    Args:
        model: tracker module to summarize.
        config: optional dict overriding FLOP-estimate hyperparameters
            (dim, depth, patch_size, proj_factor, qkv_proj_blocksize,
            mlp_ratio, num_heads).

    Returns:
        Dict with raw metrics and per-constraint booleans.
    """
    n_total = count_parameters(model)
    n_train = count_parameters(model, trainable_only=True)
    cfg = config or {}
    gflops = estimate_flops(
        model,
        dim=cfg.get('dim', 384),
        depth=cfg.get('depth', 24),
        patch_size=cfg.get('patch_size', 16),
        proj_factor=cfg.get('proj_factor', 2.0),
        qkv_proj_blocksize=cfg.get('qkv_proj_blocksize', 4),
        mlp_ratio=cfg.get('mlp_ratio', 4.0),
        num_heads=cfg.get('num_heads', 4),
    )
    mb_fp32 = estimate_model_size(model, dtype_bytes=4)
    mb_fp16 = estimate_model_size(model, dtype_bytes=2)

    rule = "=" * 60
    print(rule)
    print("ViL Tracker Model Summary")
    print(rule)
    print(f"Total Parameters: {n_total:>12,} ({n_total/1e6:.2f}M)")
    print(f"Trainable Parameters: {n_train:>12,} ({n_train/1e6:.2f}M)")
    print(f"Estimated GFLOPs: {gflops:>12.2f}")
    print(f"Model Size (fp32): {mb_fp32:>12.1f} MB")
    print(f"Model Size (fp16): {mb_fp16:>12.1f} MB")
    print("-" * 60)
    print("Constraint Compliance:")
    ok_params = n_total <= 50e6
    ok_flops = gflops <= 20
    ok_size = mb_fp16 <= 500  # Using fp16 for deployment
    print(f"  Parameters ≀50M: {'βœ…' if ok_params else '❌'} ({n_total/1e6:.2f}M)")
    print(f"  GFLOPs ≀20: {'βœ…' if ok_flops else '❌'} ({gflops:.2f})")
    print(f"  Size ≀500MB: {'βœ…' if ok_size else '❌'} ({mb_fp16:.1f}MB fp16)")
    print(f"  Latency ≀30ms: ⏳ (requires GPU benchmark)")
    print(rule)

    # Parameter breakdown over the model's direct children only.
    print("\nParameter Breakdown:")
    for child_name, child in model.named_children():
        child_params = count_parameters(child)
        print(f"  {child_name:30s} {child_params:>10,} ({child_params/1e6:.2f}M)")

    return {
        'total_params': n_total,
        'trainable_params': n_train,
        'gflops': gflops,
        'size_fp32_mb': mb_fp32,
        'size_fp16_mb': mb_fp16,
        'param_ok': ok_params,
        'flop_ok': ok_flops,
        'size_ok': ok_size,
    }