omar-ah commited on
Commit
dc54883
·
verified ·
1 Parent(s): 4e8b763

Upload vil_tracker/utils/helpers.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vil_tracker/utils/helpers.py +170 -0
vil_tracker/utils/helpers.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for model analysis and summary.
3
+
4
+ Provides:
5
+ - count_parameters: total and trainable parameter counts
6
+ - estimate_flops: rough GFLOPs estimate for the tracker
7
+ - estimate_model_size: model file size in MB
8
+ - print_model_summary: formatted constraint compliance check
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
+ def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
16
+ """Count model parameters.
17
+
18
+ Args:
19
+ model: PyTorch module
20
+ trainable_only: if True, count only trainable parameters
21
+ Returns:
22
+ Total number of parameters
23
+ """
24
+ if trainable_only:
25
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
26
+ return sum(p.numel() for p in model.parameters())
27
+
28
+
29
+ def estimate_flops(model, template_size=128, search_size=256, dim=384, depth=24,
30
+ patch_size=16, proj_factor=2.0, qkv_proj_blocksize=4,
31
+ mlp_ratio=4.0, num_heads=4):
32
+ """Estimate GFLOPs for a forward pass.
33
+
34
+ Rough calculation based on major operations:
35
+ - Patch embedding: Conv2d FLOPs
36
+ - Per mLSTM block: proj_up + conv1d + QKV + attention + proj_down + MLP
37
+ - Prediction heads: Conv2d FLOPs
38
+ """
39
+ import math
40
+
41
+ n_template = (template_size // patch_size) ** 2 # 64
42
+ n_search = (search_size // patch_size) ** 2 # 256
43
+ S = n_template + n_search # 320
44
+ inner_dim = math.ceil(proj_factor * dim / 64) * 64 # 768
45
+ num_proj_heads = inner_dim // qkv_proj_blocksize # 192
46
+ head_dim = inner_dim // num_proj_heads # 4
47
+
48
+ total_flops = 0
49
+
50
+ # Patch embedding: Conv2d(3, dim, patch_size, stride=patch_size)
51
+ # FLOPs per output pixel: 2 * in_channels * kernel_h * kernel_w * out_channels
52
+ # For template: (128/16)^2 = 64 output positions
53
+ # For search: (256/16)^2 = 256 output positions
54
+ patch_flops = 2 * 3 * patch_size * patch_size * dim * (n_template + n_search)
55
+ total_flops += patch_flops
56
+
57
+ # Per mLSTM block
58
+ for _ in range(depth):
59
+ # proj_up: Linear(dim, 2*inner_dim): 2 * dim * 2*inner_dim * S
60
+ total_flops += 2 * dim * 2 * inner_dim * S
61
+
62
+ # Conv1d (depthwise, k=4): 2 * inner_dim * 4 * S
63
+ total_flops += 2 * inner_dim * 4 * S
64
+
65
+ # Q/K/V (LinearHeadwiseExpand): 3 * num_proj_heads * 2 * head_dim * head_dim * S
66
+ total_flops += 3 * num_proj_heads * 2 * head_dim * head_dim * S
67
+
68
+ # Gates: Linear(3*inner, num_heads): 2 * 2 * 3*inner_dim * num_heads * S
69
+ total_flops += 2 * 2 * 3 * inner_dim * num_heads * S
70
+
71
+ # Attention: Q @ K^T: 2 * num_heads * S * (inner_dim//num_heads) * S
72
+ head_dim_cell = inner_dim // num_heads
73
+ total_flops += 2 * num_heads * S * head_dim_cell * S
74
+
75
+ # Attention @ V: 2 * num_heads * S * S * head_dim_cell
76
+ total_flops += 2 * num_heads * S * S * head_dim_cell
77
+
78
+ # proj_down: Linear(inner_dim, dim): 2 * inner_dim * dim * S
79
+ total_flops += 2 * inner_dim * dim * S
80
+
81
+ # NOTE: Standard ViL blocks do NOT have a separate MLP/FFN.
82
+ # The gated output inside mLSTMCell (proj_up → z-gate → proj_down) serves as MLP.
83
+ # Only TMoE blocks in the last 2 layers add an MLP, but we approximate uniformly.
84
+
85
+ # Prediction heads (rough estimate): ~0.5G
86
+ total_flops += 0.5e9
87
+
88
+ gflops = total_flops / 1e9
89
+ return gflops
90
+
91
+
92
+ def estimate_model_size(model: nn.Module, dtype_bytes: int = 4) -> float:
93
+ """Estimate model file size in MB.
94
+
95
+ Args:
96
+ model: PyTorch module
97
+ dtype_bytes: bytes per parameter (4 for fp32, 2 for fp16)
98
+ Returns:
99
+ Estimated size in MB
100
+ """
101
+ num_params = count_parameters(model)
102
+ size_bytes = num_params * dtype_bytes
103
+ size_mb = size_bytes / (1024 * 1024)
104
+ return size_mb
105
+
106
+
107
+ def print_model_summary(model, config=None):
108
+ """Print formatted model summary with constraint compliance.
109
+
110
+ Constraints:
111
+ - ≤50M parameters
112
+ - ≤30ms latency
113
+ - ≤20 GFLOPs
114
+ - ≤500MB model size
115
+ """
116
+ total_params = count_parameters(model)
117
+ trainable_params = count_parameters(model, trainable_only=True)
118
+
119
+ config = config or {}
120
+ gflops = estimate_flops(
121
+ model,
122
+ dim=config.get('dim', 384),
123
+ depth=config.get('depth', 24),
124
+ patch_size=config.get('patch_size', 16),
125
+ proj_factor=config.get('proj_factor', 2.0),
126
+ qkv_proj_blocksize=config.get('qkv_proj_blocksize', 4),
127
+ mlp_ratio=config.get('mlp_ratio', 4.0),
128
+ num_heads=config.get('num_heads', 4),
129
+ )
130
+
131
+ size_fp32 = estimate_model_size(model, dtype_bytes=4)
132
+ size_fp16 = estimate_model_size(model, dtype_bytes=2)
133
+
134
+ print("=" * 60)
135
+ print("ViL Tracker Model Summary")
136
+ print("=" * 60)
137
+ print(f"Total Parameters: {total_params:>12,} ({total_params/1e6:.2f}M)")
138
+ print(f"Trainable Parameters: {trainable_params:>12,} ({trainable_params/1e6:.2f}M)")
139
+ print(f"Estimated GFLOPs: {gflops:>12.2f}")
140
+ print(f"Model Size (fp32): {size_fp32:>12.1f} MB")
141
+ print(f"Model Size (fp16): {size_fp16:>12.1f} MB")
142
+ print("-" * 60)
143
+ print("Constraint Compliance:")
144
+
145
+ param_ok = total_params <= 50e6
146
+ flop_ok = gflops <= 20
147
+ size_ok = size_fp16 <= 500 # Using fp16 for deployment
148
+
149
+ print(f" Parameters ≤50M: {'✅' if param_ok else '❌'} ({total_params/1e6:.2f}M)")
150
+ print(f" GFLOPs ≤20: {'✅' if flop_ok else '❌'} ({gflops:.2f})")
151
+ print(f" Size ≤500MB: {'✅' if size_ok else '❌'} ({size_fp16:.1f}MB fp16)")
152
+ print(f" Latency ≤30ms: ⏳ (requires GPU benchmark)")
153
+ print("=" * 60)
154
+
155
+ # Per-component breakdown
156
+ print("\nParameter Breakdown:")
157
+ for name, module in model.named_children():
158
+ mod_params = count_parameters(module)
159
+ print(f" {name:30s} {mod_params:>10,} ({mod_params/1e6:.2f}M)")
160
+
161
+ return {
162
+ 'total_params': total_params,
163
+ 'trainable_params': trainable_params,
164
+ 'gflops': gflops,
165
+ 'size_fp32_mb': size_fp32,
166
+ 'size_fp16_mb': size_fp16,
167
+ 'param_ok': param_ok,
168
+ 'flop_ok': flop_ok,
169
+ 'size_ok': size_ok,
170
+ }