| """ |
| Utility functions for model analysis and summary. |
| |
| Provides: |
| - count_parameters: total and trainable parameter counts |
| - estimate_flops: rough GFLOPs estimate for the tracker |
| - estimate_model_size: model file size in MB |
| - print_model_summary: formatted constraint compliance check |
| """ |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Count the parameters of a model.

    Args:
        model: PyTorch module
        trainable_only: if True, count only parameters with requires_grad
    Returns:
        Total number of parameters
    """
    params = model.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
|
|
|
|
def estimate_flops(model, template_size=128, search_size=256, dim=384, depth=24,
                   patch_size=16, proj_factor=2.0, qkv_proj_blocksize=4,
                   mlp_ratio=4.0, num_heads=4):
    """Estimate GFLOPs for one forward pass of the tracker.

    Purely analytic estimate based on the major matrix multiplies, counting
    each multiply-add as 2 FLOPs:
    - Patch embedding: Conv2d over the template and search images
    - Per mLSTM block: proj_up + conv1d + blocked QKV + cell interactions
      + proj_down
    - Prediction heads: flat 0.5 GFLOPs allowance

    Args:
        model: accepted for API symmetry only; the estimate never inspects
            the module, so ``None`` is a valid argument.
        template_size: template image side length in pixels.
        search_size: search image side length in pixels.
        dim: embedding dimension.
        depth: number of mLSTM blocks.
        patch_size: patch-embedding kernel/stride.
        proj_factor: up-projection expansion factor inside each block.
        qkv_proj_blocksize: block size of the blocked QKV projection.
        mlp_ratio: accepted for interface compatibility but unused — no
            separate per-block MLP term is counted here.
        num_heads: number of mLSTM cell heads.
    Returns:
        Estimated forward-pass cost in GFLOPs (float).
    """
    import math

    # Token counts per stream; the blocks run on the fused sequence S.
    n_template = (template_size // patch_size) ** 2
    n_search = (search_size // patch_size) ** 2
    S = n_template + n_search

    # Up-projected width, rounded up to a multiple of 64.
    inner_dim = math.ceil(proj_factor * dim / 64) * 64
    num_proj_heads = inner_dim // qkv_proj_blocksize
    head_dim = inner_dim // num_proj_heads

    total_flops = 0

    # Patch embedding: Conv2d(3 -> dim, kernel=stride=patch_size) on both images.
    patch_flops = 2 * 3 * patch_size * patch_size * dim * (n_template + n_search)
    total_flops += patch_flops

    for _ in range(depth):
        # proj_up: dim -> 2*inner_dim (presumably value + gate branches).
        total_flops += 2 * dim * 2 * inner_dim * S

        # Depthwise conv1d over inner_dim channels (kernel size 4, hard-coded).
        total_flops += 2 * inner_dim * 4 * S

        # Blocked Q/K/V projections (block-diagonal linears of size head_dim).
        total_flops += 3 * num_proj_heads * 2 * head_dim * head_dim * S

        # Gate projections — NOTE(review): factor layout (2*2*3) inferred,
        # looks like 3 gates mapping inner_dim -> num_heads; confirm vs model.
        total_flops += 2 * 2 * 3 * inner_dim * num_heads * S

        # mLSTM cell: S x S interaction matrix per head ...
        head_dim_cell = inner_dim // num_heads
        total_flops += 2 * num_heads * S * head_dim_cell * S

        # ... and the attention-style readout against the values.
        total_flops += 2 * num_heads * S * S * head_dim_cell

        # proj_down: inner_dim -> dim.
        total_flops += 2 * inner_dim * dim * S

    # Flat allowance for the prediction heads (Conv2d stacks).
    total_flops += 0.5e9

    gflops = total_flops / 1e9
    return gflops
|
|
|
|
def estimate_model_size(model: nn.Module, dtype_bytes: int = 4) -> float:
    """Estimate model file size in MB.

    Args:
        model: PyTorch module
        dtype_bytes: bytes per parameter (4 for fp32, 2 for fp16)
    Returns:
        Estimated size in MB
    """
    total_bytes = dtype_bytes * sum(p.numel() for p in model.parameters())
    return total_bytes / 2 ** 20
|
|
|
|
def print_model_summary(model, config=None):
    """Print a formatted model summary with constraint compliance checks.

    Constraints:
    - ≤50M parameters
    - ≤30ms latency (needs a GPU benchmark; reported as pending here)
    - ≤20 GFLOPs
    - ≤500MB model size (checked against the fp16 estimate)

    Args:
        model: PyTorch module to summarize.
        config: optional dict of architecture hyperparameters forwarded to
            estimate_flops (dim, depth, patch_size, proj_factor,
            qkv_proj_blocksize, mlp_ratio, num_heads); defaults match
            estimate_flops' own defaults.
    Returns:
        Dict with raw metrics and per-constraint pass/fail booleans.
    """
    total_params = count_parameters(model)
    trainable_params = count_parameters(model, trainable_only=True)

    config = config or {}
    gflops = estimate_flops(
        model,
        dim=config.get('dim', 384),
        depth=config.get('depth', 24),
        patch_size=config.get('patch_size', 16),
        proj_factor=config.get('proj_factor', 2.0),
        qkv_proj_blocksize=config.get('qkv_proj_blocksize', 4),
        mlp_ratio=config.get('mlp_ratio', 4.0),
        num_heads=config.get('num_heads', 4),
    )

    size_fp32 = estimate_model_size(model, dtype_bytes=4)
    size_fp16 = estimate_model_size(model, dtype_bytes=2)

    print("=" * 60)
    print("ViL Tracker Model Summary")
    print("=" * 60)
    print(f"Total Parameters:     {total_params:>12,} ({total_params/1e6:.2f}M)")
    print(f"Trainable Parameters: {trainable_params:>12,} ({trainable_params/1e6:.2f}M)")
    print(f"Estimated GFLOPs:     {gflops:>12.2f}")
    print(f"Model Size (fp32):    {size_fp32:>12.1f} MB")
    print(f"Model Size (fp16):    {size_fp16:>12.1f} MB")
    print("-" * 60)
    print("Constraint Compliance:")

    param_ok = total_params <= 50e6
    flop_ok = gflops <= 20
    size_ok = size_fp16 <= 500

    # NOTE: the status marks below were mojibake in the original source
    # (broken multi-line f-strings); restored to the intended characters.
    print(f"  Parameters ≤50M:  {'✅' if param_ok else '❌'} ({total_params/1e6:.2f}M)")
    print(f"  GFLOPs ≤20:       {'✅' if flop_ok else '❌'} ({gflops:.2f})")
    print(f"  Size ≤500MB:      {'✅' if size_ok else '❌'} ({size_fp16:.1f}MB fp16)")
    print("  Latency ≤30ms:    ⏳ (requires GPU benchmark)")
    print("=" * 60)

    # Per-submodule parameter counts for the top-level children.
    print("\nParameter Breakdown:")
    for name, module in model.named_children():
        mod_params = count_parameters(module)
        print(f"  {name:30s} {mod_params:>10,} ({mod_params/1e6:.2f}M)")

    return {
        'total_params': total_params,
        'trainable_params': trainable_params,
        'gflops': gflops,
        'size_fp32_mb': size_fp32,
        'size_fp16_mb': size_fp16,
        'param_ok': param_ok,
        'flop_ok': flop_ok,
        'size_ok': size_ok,
    }
|
|