File size: 6,237 Bytes
dc54883
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
Utility functions for model analysis and summary.

Provides:
- count_parameters: total and trainable parameter counts
- estimate_flops: rough GFLOPs estimate for the tracker
- estimate_model_size: model file size in MB
- print_model_summary: formatted constraint compliance check
"""

import torch
import torch.nn as nn


def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Return the number of parameters in *model*.

    Args:
        model: PyTorch module to inspect.
        trainable_only: when True, include only parameters with
            ``requires_grad`` set.

    Returns:
        Total element count across the selected parameters.
    """
    params = model.parameters()
    if trainable_only:
        # Narrow the iterator to trainable tensors before counting.
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)


def estimate_flops(model, template_size=128, search_size=256, dim=384, depth=24, 
                   patch_size=16, proj_factor=2.0, qkv_proj_blocksize=4,
                   mlp_ratio=4.0, num_heads=4):
    """Rough GFLOPs estimate for one forward pass of the tracker.

    Accounts for the patch embedding, the per-block mLSTM operations
    (up-projection, depthwise conv, headwise Q/K/V, gate projections,
    attention-style matmuls, down-projection), and a fixed ~0.5G allowance
    for the prediction heads.  ``model`` and ``mlp_ratio`` are accepted for
    interface compatibility but do not influence the estimate.

    Returns:
        Estimated GFLOPs as a float.
    """
    import math
    
    # Token counts for each image stream and the concatenated sequence.
    tokens_template = (template_size // patch_size) ** 2
    tokens_search = (search_size // patch_size) ** 2
    seq_len = tokens_template + tokens_search
    
    # Inner width: proj_factor * dim rounded up to a multiple of 64.
    inner_dim = math.ceil(proj_factor * dim / 64) * 64
    proj_heads = inner_dim // qkv_proj_blocksize
    proj_head_dim = inner_dim // proj_heads
    cell_head_dim = inner_dim // num_heads
    
    # Patch embedding Conv2d(3, dim, patch_size, stride=patch_size):
    # 2 * C_in * k * k * C_out FLOPs per output position.
    total = 2 * 3 * patch_size * patch_size * dim * seq_len
    
    # Every mLSTM block costs the same, so price one block and scale.
    per_block = (
        2 * dim * 2 * inner_dim * seq_len                               # proj_up: Linear(dim, 2*inner)
        + 2 * inner_dim * 4 * seq_len                                   # depthwise Conv1d, kernel 4
        + 3 * proj_heads * 2 * proj_head_dim * proj_head_dim * seq_len  # headwise Q/K/V projections
        + 2 * 2 * 3 * inner_dim * num_heads * seq_len                   # i/f gates: Linear(3*inner, heads)
        + 2 * num_heads * seq_len * cell_head_dim * seq_len             # Q @ K^T
        + 2 * num_heads * seq_len * seq_len * cell_head_dim             # scores @ V
        + 2 * inner_dim * dim * seq_len                                 # proj_down: Linear(inner, dim)
    )
    # NOTE: standard ViL blocks have no separate MLP/FFN -- the gated
    # proj_up -> z-gate -> proj_down path serves that role.  TMoE blocks in
    # the last 2 layers do add one, but we approximate uniformly.
    total += depth * per_block
    
    # Prediction heads: flat ~0.5 GFLOPs allowance.
    total += 0.5e9
    
    return total / 1e9


def estimate_model_size(model: nn.Module, dtype_bytes: int = 4) -> float:
    """Estimate the serialized model size in MB.

    Args:
        model: PyTorch module whose parameters are measured.
        dtype_bytes: storage bytes per parameter (4 for fp32, 2 for fp16).

    Returns:
        Estimated size in MB (parameters only; buffers are not counted).
    """
    # params * bytes-per-param, converted from bytes to mebibytes.
    return count_parameters(model) * dtype_bytes / (1024 * 1024)


def print_model_summary(model: nn.Module, config=None) -> dict:
    """Print formatted model summary with constraint compliance.
    
    Prints parameter counts, a rough GFLOPs estimate, fp32/fp16 size
    estimates, pass/fail marks for each deployment constraint, and a
    per-child parameter breakdown; returns the computed metrics.

    Constraints:
    - ≀50M parameters
    - ≀30ms latency
    - ≀20 GFLOPs
    - ≀500MB model size

    Args:
        model: PyTorch module to summarize.
        config: optional dict of architecture hyperparameters forwarded to
            estimate_flops (dim, depth, patch_size, proj_factor,
            qkv_proj_blocksize, mlp_ratio, num_heads); defaults fill any
            missing key.

    Returns:
        Dict with metric values and boolean compliance flags.
    """
    total_params = count_parameters(model)
    trainable_params = count_parameters(model, trainable_only=True)
    
    # Fall back to an empty dict so .get() defaults apply uniformly.
    config = config or {}
    gflops = estimate_flops(
        model,
        dim=config.get('dim', 384),
        depth=config.get('depth', 24),
        patch_size=config.get('patch_size', 16),
        proj_factor=config.get('proj_factor', 2.0),
        qkv_proj_blocksize=config.get('qkv_proj_blocksize', 4),
        mlp_ratio=config.get('mlp_ratio', 4.0),
        num_heads=config.get('num_heads', 4),
    )
    
    # Report both precisions; fp16 is the assumed deployment format.
    size_fp32 = estimate_model_size(model, dtype_bytes=4)
    size_fp16 = estimate_model_size(model, dtype_bytes=2)
    
    print("=" * 60)
    print("ViL Tracker Model Summary")
    print("=" * 60)
    print(f"Total Parameters:     {total_params:>12,} ({total_params/1e6:.2f}M)")
    print(f"Trainable Parameters: {trainable_params:>12,} ({trainable_params/1e6:.2f}M)")
    print(f"Estimated GFLOPs:     {gflops:>12.2f}")
    print(f"Model Size (fp32):    {size_fp32:>12.1f} MB")
    print(f"Model Size (fp16):    {size_fp16:>12.1f} MB")
    print("-" * 60)
    print("Constraint Compliance:")
    
    # Compliance flags against the deployment budget.
    param_ok = total_params <= 50e6
    flop_ok = gflops <= 20
    size_ok = size_fp16 <= 500  # Using fp16 for deployment
    
    print(f"  Parameters ≀50M:  {'βœ…' if param_ok else '❌'} ({total_params/1e6:.2f}M)")
    print(f"  GFLOPs ≀20:       {'βœ…' if flop_ok else '❌'} ({gflops:.2f})")
    print(f"  Size ≀500MB:      {'βœ…' if size_ok else '❌'} ({size_fp16:.1f}MB fp16)")
    # Latency cannot be estimated statically; needs an on-device benchmark.
    print(f"  Latency ≀30ms:    ⏳ (requires GPU benchmark)")
    print("=" * 60)
    
    # Per-component breakdown over direct children only (not recursive).
    print("\nParameter Breakdown:")
    for name, module in model.named_children():
        mod_params = count_parameters(module)
        print(f"  {name:30s} {mod_params:>10,} ({mod_params/1e6:.2f}M)")
    
    return {
        'total_params': total_params,
        'trainable_params': trainable_params,
        'gflops': gflops,
        'size_fp32_mb': size_fp32,
        'size_fp16_mb': size_fp16,
        'param_ok': param_ok,
        'flop_ok': flop_ok,
        'size_ok': size_ok,
    }