"""
Utility functions for model analysis and summary.
Provides:
- count_parameters: total and trainable parameter counts
- estimate_flops: rough GFLOPs estimate for the tracker
- estimate_model_size: model file size in MB
- print_model_summary: formatted constraint compliance check
"""
import torch
import torch.nn as nn
def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Return the number of parameters in *model*.

    Args:
        model: PyTorch module to inspect.
        trainable_only: when True, count only parameters with
            ``requires_grad`` set.

    Returns:
        Summed element count over the selected parameters.
    """
    params = model.parameters()
    if trainable_only:
        # Keep only parameters the optimizer would actually update.
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
def estimate_flops(model, template_size=128, search_size=256, dim=384, depth=24,
                   patch_size=16, proj_factor=2.0, qkv_proj_blocksize=4,
                   mlp_ratio=4.0, num_heads=4):
    """Estimate GFLOPs for one forward pass of the tracker.

    The estimate covers the dominant operations only:
    patch embedding (Conv2d), the mLSTM blocks (proj_up, depthwise conv1d,
    headwise Q/K/V, gates, attention-style matmuls, proj_down), and a flat
    ~0.5 GFLOPs allowance for the prediction heads.

    Note:
        ``model`` and ``mlp_ratio`` are accepted for interface compatibility
        but unused: standard ViL blocks have no separate MLP/FFN — the gated
        path inside the mLSTM cell (proj_up, z-gate, proj_down) plays that
        role.  Only TMoE blocks in the last 2 layers add an MLP; the blocks
        are approximated uniformly here.

    Returns:
        Estimated forward-pass cost in GFLOPs (float).
    """
    import math
    tokens_template = (template_size // patch_size) ** 2   # 64 by default
    tokens_search = (search_size // patch_size) ** 2       # 256 by default
    seq_len = tokens_template + tokens_search              # 320 by default
    # Up-projection width, rounded up to a multiple of 64 (768 by default).
    inner_dim = math.ceil(proj_factor * dim / 64) * 64
    proj_heads = inner_dim // qkv_proj_blocksize           # 192 by default
    proj_head_dim = inner_dim // proj_heads                # 4 by default
    cell_head_dim = inner_dim // num_heads                 # per-head width in the cell

    # Patch embedding: Conv2d(3, dim, patch_size, stride=patch_size);
    # 2 * C_in * k_h * k_w * C_out FLOPs per output position.
    embed_flops = 2 * 3 * patch_size * patch_size * dim * seq_len

    # Cost of a single mLSTM block (all terms scale with seq_len):
    block_flops = (
        2 * dim * 2 * inner_dim * seq_len                                # proj_up: Linear(dim, 2*inner)
        + 2 * inner_dim * 4 * seq_len                                    # depthwise Conv1d, kernel 4
        + 3 * proj_heads * 2 * proj_head_dim * proj_head_dim * seq_len   # Q/K/V (LinearHeadwiseExpand)
        + 2 * 2 * 3 * inner_dim * num_heads * seq_len                    # i/f gates: Linear(3*inner, heads)
        + 2 * num_heads * seq_len * cell_head_dim * seq_len              # Q @ K^T
        + 2 * num_heads * seq_len * seq_len * cell_head_dim              # scores @ V
        + 2 * inner_dim * dim * seq_len                                  # proj_down: Linear(inner, dim)
    )

    # Prediction heads: rough flat allowance of ~0.5 GFLOPs.
    total_flops = embed_flops + depth * block_flops + 0.5e9
    return total_flops / 1e9
def estimate_model_size(model: nn.Module, dtype_bytes: int = 4) -> float:
    """Estimate the serialized size of *model* in megabytes.

    Args:
        model: PyTorch module to measure.
        dtype_bytes: bytes per parameter (4 for fp32, 2 for fp16).

    Returns:
        Estimated size in MB (parameters only; buffers/optimizer state
        are not counted).
    """
    total_bytes = count_parameters(model) * dtype_bytes
    return total_bytes / (1024 ** 2)
def print_model_summary(model, config=None):
    """Print a formatted model summary with constraint-compliance checks.

    Constraints checked:
    - ≤50M parameters
    - ≤30ms latency (not checked here; needs a GPU benchmark)
    - ≤20 GFLOPs
    - ≤500MB model size (fp16, the deployment precision)

    Args:
        model: PyTorch module to summarize.
        config: optional dict of architecture hyperparameters forwarded to
            :func:`estimate_flops`; missing keys fall back to its defaults.

    Returns:
        Dict with the computed metrics and per-constraint booleans.
    """
    total_params = count_parameters(model)
    trainable_params = count_parameters(model, trainable_only=True)
    config = config or {}
    gflops = estimate_flops(
        model,
        dim=config.get('dim', 384),
        depth=config.get('depth', 24),
        patch_size=config.get('patch_size', 16),
        proj_factor=config.get('proj_factor', 2.0),
        qkv_proj_blocksize=config.get('qkv_proj_blocksize', 4),
        mlp_ratio=config.get('mlp_ratio', 4.0),
        num_heads=config.get('num_heads', 4),
    )
    size_fp32 = estimate_model_size(model, dtype_bytes=4)
    size_fp16 = estimate_model_size(model, dtype_bytes=2)
    print("=" * 60)
    print("ViL Tracker Model Summary")
    print("=" * 60)
    print(f"Total Parameters:     {total_params:>12,} ({total_params/1e6:.2f}M)")
    print(f"Trainable Parameters: {trainable_params:>12,} ({trainable_params/1e6:.2f}M)")
    print(f"Estimated GFLOPs:     {gflops:>12.2f}")
    print(f"Model Size (fp32):    {size_fp32:>12.1f} MB")
    print(f"Model Size (fp16):    {size_fp16:>12.1f} MB")
    print("-" * 60)
    print("Constraint Compliance:")
    param_ok = total_params <= 50e6
    flop_ok = gflops <= 20
    size_ok = size_fp16 <= 500  # fp16 is the deployment precision
    # NOTE(review): the original file's status glyphs were mojibaked by a
    # bad encoding round-trip (illegal newline inside the f-string literal);
    # restored here as the intended check/cross/hourglass marks.
    print(f"  Parameters ≤50M:  {'✅' if param_ok else '❌'} ({total_params/1e6:.2f}M)")
    print(f"  GFLOPs ≤20:       {'✅' if flop_ok else '❌'} ({gflops:.2f})")
    print(f"  Size ≤500MB:      {'✅' if size_ok else '❌'} ({size_fp16:.1f}MB fp16)")
    # Plain string: no placeholders, so no f-prefix needed.
    print("  Latency ≤30ms:    ⏳ (requires GPU benchmark)")
    print("=" * 60)
    # Per-component breakdown over the module's direct children.
    print("\nParameter Breakdown:")
    for name, module in model.named_children():
        mod_params = count_parameters(module)
        print(f"  {name:30s} {mod_params:>10,} ({mod_params/1e6:.2f}M)")
    return {
        'total_params': total_params,
        'trainable_params': trainable_params,
        'gflops': gflops,
        'size_fp32_mb': size_fp32,
        'size_fp16_mb': size_fp16,
        'param_ok': param_ok,
        'flop_ok': flop_ok,
        'size_ok': size_ok,
    }