""" Utility functions for model analysis and summary. Provides: - count_parameters: total and trainable parameter counts - estimate_flops: rough GFLOPs estimate for the tracker - estimate_model_size: model file size in MB - print_model_summary: formatted constraint compliance check """ import torch import torch.nn as nn def count_parameters(model: nn.Module, trainable_only: bool = False) -> int: """Count model parameters. Args: model: PyTorch module trainable_only: if True, count only trainable parameters Returns: Total number of parameters """ if trainable_only: return sum(p.numel() for p in model.parameters() if p.requires_grad) return sum(p.numel() for p in model.parameters()) def estimate_flops(model, template_size=128, search_size=256, dim=384, depth=24, patch_size=16, proj_factor=2.0, qkv_proj_blocksize=4, mlp_ratio=4.0, num_heads=4): """Estimate GFLOPs for a forward pass. Rough calculation based on major operations: - Patch embedding: Conv2d FLOPs - Per mLSTM block: proj_up + conv1d + QKV + attention + proj_down + MLP - Prediction heads: Conv2d FLOPs """ import math n_template = (template_size // patch_size) ** 2 # 64 n_search = (search_size // patch_size) ** 2 # 256 S = n_template + n_search # 320 inner_dim = math.ceil(proj_factor * dim / 64) * 64 # 768 num_proj_heads = inner_dim // qkv_proj_blocksize # 192 head_dim = inner_dim // num_proj_heads # 4 total_flops = 0 # Patch embedding: Conv2d(3, dim, patch_size, stride=patch_size) # FLOPs per output pixel: 2 * in_channels * kernel_h * kernel_w * out_channels # For template: (128/16)^2 = 64 output positions # For search: (256/16)^2 = 256 output positions patch_flops = 2 * 3 * patch_size * patch_size * dim * (n_template + n_search) total_flops += patch_flops # Per mLSTM block for _ in range(depth): # proj_up: Linear(dim, 2*inner_dim): 2 * dim * 2*inner_dim * S total_flops += 2 * dim * 2 * inner_dim * S # Conv1d (depthwise, k=4): 2 * inner_dim * 4 * S total_flops += 2 * inner_dim * 4 * S # Q/K/V (LinearHeadwiseExpand): 3 * num_proj_heads * 2 * head_dim * head_dim * S total_flops += 3 * num_proj_heads * 2 * head_dim * head_dim * S # Gates: Linear(3*inner, num_heads): 2 * 2 * 3*inner_dim * num_heads * S total_flops += 2 * 2 * 3 * inner_dim * num_heads * S # Attention: Q @ K^T: 2 * num_heads * S * (inner_dim//num_heads) * S head_dim_cell = inner_dim // num_heads total_flops += 2 * num_heads * S * head_dim_cell * S # Attention @ V: 2 * num_heads * S * S * head_dim_cell total_flops += 2 * num_heads * S * S * head_dim_cell # proj_down: Linear(inner_dim, dim): 2 * inner_dim * dim * S total_flops += 2 * inner_dim * dim * S # NOTE: Standard ViL blocks do NOT have a separate MLP/FFN. # The gated output inside mLSTMCell (proj_up → z-gate → proj_down) serves as MLP. # Only TMoE blocks in the last 2 layers add an MLP, but we approximate uniformly. # Prediction heads (rough estimate): ~0.5G total_flops += 0.5e9 gflops = total_flops / 1e9 return gflops def estimate_model_size(model: nn.Module, dtype_bytes: int = 4) -> float: """Estimate model file size in MB. Args: model: PyTorch module dtype_bytes: bytes per parameter (4 for fp32, 2 for fp16) Returns: Estimated size in MB """ num_params = count_parameters(model) size_bytes = num_params * dtype_bytes size_mb = size_bytes / (1024 * 1024) return size_mb def print_model_summary(model, config=None): """Print formatted model summary with constraint compliance. Constraints: - ≤50M parameters - ≤30ms latency - ≤20 GFLOPs - ≤500MB model size """ total_params = count_parameters(model) trainable_params = count_parameters(model, trainable_only=True) config = config or {} gflops = estimate_flops( model, dim=config.get('dim', 384), depth=config.get('depth', 24), patch_size=config.get('patch_size', 16), proj_factor=config.get('proj_factor', 2.0), qkv_proj_blocksize=config.get('qkv_proj_blocksize', 4), mlp_ratio=config.get('mlp_ratio', 4.0), num_heads=config.get('num_heads', 4), ) size_fp32 = estimate_model_size(model, dtype_bytes=4) size_fp16 = estimate_model_size(model, dtype_bytes=2) print("=" * 60) print("ViL Tracker Model Summary") print("=" * 60) print(f"Total Parameters: {total_params:>12,} ({total_params/1e6:.2f}M)") print(f"Trainable Parameters: {trainable_params:>12,} ({trainable_params/1e6:.2f}M)") print(f"Estimated GFLOPs: {gflops:>12.2f}") print(f"Model Size (fp32): {size_fp32:>12.1f} MB") print(f"Model Size (fp16): {size_fp16:>12.1f} MB") print("-" * 60) print("Constraint Compliance:") param_ok = total_params <= 50e6 flop_ok = gflops <= 20 size_ok = size_fp16 <= 500 # Using fp16 for deployment print(f" Parameters ≤50M: {'✅' if param_ok else '❌'} ({total_params/1e6:.2f}M)") print(f" GFLOPs ≤20: {'✅' if flop_ok else '❌'} ({gflops:.2f})") print(f" Size ≤500MB: {'✅' if size_ok else '❌'} ({size_fp16:.1f}MB fp16)") print(f" Latency ≤30ms: ⏳ (requires GPU benchmark)") print("=" * 60) # Per-component breakdown print("\nParameter Breakdown:") for name, module in model.named_children(): mod_params = count_parameters(module) print(f" {name:30s} {mod_params:>10,} ({mod_params/1e6:.2f}M)") return { 'total_params': total_params, 'trainable_params': trainable_params, 'gflops': gflops, 'size_fp32_mb': size_fp32, 'size_fp16_mb': size_fp16, 'param_ok': param_ok, 'flop_ok': flop_ok, 'size_ok': size_ok, }