omar-ah commited on
Commit
fc9248f
verified
1 Parent(s): 01f95f3

Upload vil_tracker/training/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vil_tracker/training/train.py +244 -0
vil_tracker/training/train.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training script for ViL Tracker.
3
+
4
+ Two-phase training:
5
+ Phase 1: Standard supervised training on GOT-10k + LaSOT + TrackingNet
6
+ - Full model training with focal + GIoU + size losses
7
+ - ACL curriculum (progressive difficulty ramp-up)
8
+ - 300 epochs, lr=1e-4 with cosine decay, warmup=5 epochs
9
+
10
+ Phase 2: Fine-tuning with TMoE and distillation
11
+ - Freeze shared experts in TMoE blocks
12
+ - Add contrastive loss on temporal features
13
+ - Optional AFKD distillation from MCITrack teacher
14
+ - 100 epochs, lr=1e-5
15
+
16
+ Hardware: Designed for A10G (24GB) or A100 (80GB)
17
+ """
18
+
19
+ import os
20
+ import json
21
+ import math
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.optim as optim
25
+ from torch.utils.data import DataLoader
26
+ from torch.cuda.amp import autocast, GradScaler
27
+
28
+
29
+ def build_optimizer(model, lr=1e-4, weight_decay=0.05, backbone_lr_scale=0.1):
30
+ """Build AdamW optimizer with layer-wise learning rate decay."""
31
+ backbone_params = []
32
+ head_params = []
33
+ other_params = []
34
+
35
+ for name, param in model.named_parameters():
36
+ if not param.requires_grad:
37
+ continue
38
+ if 'backbone' in name:
39
+ backbone_params.append(param)
40
+ elif 'center_head' in name or 'uncertainty_head' in name:
41
+ head_params.append(param)
42
+ else:
43
+ other_params.append(param)
44
+
45
+ param_groups = [
46
+ {'params': backbone_params, 'lr': lr * backbone_lr_scale},
47
+ {'params': head_params, 'lr': lr},
48
+ {'params': other_params, 'lr': lr * 0.5},
49
+ ]
50
+
51
+ return optim.AdamW(param_groups, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
52
+
53
+
54
+ def build_scheduler(optimizer, total_epochs, warmup_epochs=5):
55
+ """Cosine annealing with linear warmup."""
56
+ def lr_lambda(epoch):
57
+ if epoch < warmup_epochs:
58
+ return epoch / warmup_epochs
59
+ progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
60
+ return 0.5 * (1 + math.cos(math.pi * progress))
61
+
62
+ return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
63
+
64
+
65
+ def train_one_epoch(
66
+ model, dataloader, optimizer, scheduler, scaler, loss_fn, device,
67
+ epoch, total_epochs, acl_lambda=None, grad_clip=1.0,
68
+ ):
69
+ """Train for one epoch with AMP and gradient clipping."""
70
+ model.train()
71
+ total_loss = 0
72
+ num_batches = 0
73
+
74
+ for batch_idx, batch in enumerate(dataloader):
75
+ template = batch['template'].to(device)
76
+ search = batch['search'].to(device)
77
+ gt_heatmap = batch['heatmap'].to(device)
78
+ gt_size = batch['size'].to(device)
79
+ gt_boxes = batch['boxes'].to(device)
80
+
81
+ optimizer.zero_grad()
82
+
83
+ with autocast(enabled=scaler is not None):
84
+ pred = model(template, search, use_temporal=False)
85
+ loss_dict = loss_fn(pred, gt_heatmap, gt_size, gt_boxes)
86
+ loss = loss_dict['total']
87
+
88
+ # ACL difficulty weighting
89
+ if acl_lambda is not None:
90
+ loss = loss * acl_lambda
91
+
92
+ if scaler is not None:
93
+ scaler.scale(loss).backward()
94
+ scaler.unscale_(optimizer)
95
+ nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
96
+ scaler.step(optimizer)
97
+ scaler.update()
98
+ else:
99
+ loss.backward()
100
+ nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
101
+ optimizer.step()
102
+
103
+ total_loss += loss.item()
104
+ num_batches += 1
105
+
106
+ if batch_idx % 100 == 0:
107
+ print(f" Epoch {epoch}/{total_epochs} | Batch {batch_idx} | "
108
+ f"Loss: {loss.item():.4f} | "
109
+ f"Heatmap: {loss_dict['heatmap']:.4f} | "
110
+ f"GIoU: {loss_dict['giou']:.4f} | "
111
+ f"Size: {loss_dict['size']:.4f}")
112
+
113
+ avg_loss = total_loss / max(num_batches, 1)
114
+ return avg_loss
115
+
116
+
117
+ def train_phase1(
118
+ model, train_dataset, config, device='cuda',
119
+ num_epochs=300, lr=1e-4, batch_size=32, num_workers=4,
120
+ save_dir='./checkpoints', push_to_hub=False, hub_model_id=None,
121
+ ):
122
+ """Phase 1: Standard supervised training."""
123
+ print(f"=== Phase 1 Training: {num_epochs} epochs ===")
124
+
125
+ os.makedirs(save_dir, exist_ok=True)
126
+
127
+ from .losses import CombinedTrackingLoss
128
+ loss_fn = CombinedTrackingLoss(use_uncertainty=True, use_adw=True).to(device)
129
+
130
+ model = model.to(device)
131
+ optimizer = build_optimizer(model, lr=lr)
132
+ scheduler = build_scheduler(optimizer, num_epochs)
133
+ scaler = GradScaler() if device == 'cuda' else None
134
+
135
+ dataloader = DataLoader(
136
+ train_dataset, batch_size=batch_size, shuffle=True,
137
+ num_workers=num_workers, pin_memory=True, drop_last=True,
138
+ )
139
+
140
+ best_loss = float('inf')
141
+
142
+ for epoch in range(num_epochs):
143
+ # ACL curriculum: linear ramp-up of difficulty
144
+ acl_lambda = min(1.0, (epoch + 1) / 50) # Ramp up over 50 epochs
145
+
146
+ avg_loss = train_one_epoch(
147
+ model, dataloader, optimizer, scheduler, scaler, loss_fn,
148
+ device, epoch, num_epochs, acl_lambda=acl_lambda,
149
+ )
150
+
151
+ scheduler.step()
152
+
153
+ print(f"Epoch {epoch}/{num_epochs} | Avg Loss: {avg_loss:.4f} | "
154
+ f"LR: {scheduler.get_last_lr()[0]:.6f} | ACL 位: {acl_lambda:.2f}")
155
+
156
+ # Save best
157
+ if avg_loss < best_loss:
158
+ best_loss = avg_loss
159
+ torch.save({
160
+ 'epoch': epoch,
161
+ 'model_state_dict': model.state_dict(),
162
+ 'optimizer_state_dict': optimizer.state_dict(),
163
+ 'loss': best_loss,
164
+ }, os.path.join(save_dir, 'best_phase1.pth'))
165
+
166
+ # Save periodic
167
+ if (epoch + 1) % 50 == 0:
168
+ torch.save({
169
+ 'epoch': epoch,
170
+ 'model_state_dict': model.state_dict(),
171
+ 'optimizer_state_dict': optimizer.state_dict(),
172
+ 'loss': avg_loss,
173
+ }, os.path.join(save_dir, f'phase1_epoch{epoch+1}.pth'))
174
+
175
+ if push_to_hub and hub_model_id:
176
+ _push_checkpoint_to_hub(model, save_dir, hub_model_id, 'phase1')
177
+
178
+ return model
179
+
180
+
181
+ def train_phase2(
182
+ model, train_dataset, config, device='cuda',
183
+ num_epochs=100, lr=1e-5, batch_size=32, num_workers=4,
184
+ save_dir='./checkpoints', push_to_hub=False, hub_model_id=None,
185
+ ):
186
+ """Phase 2: Fine-tuning with frozen shared experts."""
187
+ print(f"=== Phase 2 Training: {num_epochs} epochs ===")
188
+
189
+ # Freeze shared experts
190
+ model.freeze_backbone_shared_experts()
191
+
192
+ from .losses import CombinedTrackingLoss
193
+ loss_fn = CombinedTrackingLoss(use_uncertainty=True, use_adw=True).to(device)
194
+
195
+ model = model.to(device)
196
+ optimizer = build_optimizer(model, lr=lr, backbone_lr_scale=0.01)
197
+ scheduler = build_scheduler(optimizer, num_epochs, warmup_epochs=2)
198
+ scaler = GradScaler() if device == 'cuda' else None
199
+
200
+ dataloader = DataLoader(
201
+ train_dataset, batch_size=batch_size, shuffle=True,
202
+ num_workers=num_workers, pin_memory=True, drop_last=True,
203
+ )
204
+
205
+ best_loss = float('inf')
206
+
207
+ for epoch in range(num_epochs):
208
+ avg_loss = train_one_epoch(
209
+ model, dataloader, optimizer, scheduler, scaler, loss_fn,
210
+ device, epoch, num_epochs,
211
+ )
212
+
213
+ scheduler.step()
214
+
215
+ print(f"Phase2 Epoch {epoch}/{num_epochs} | Avg Loss: {avg_loss:.4f} | "
216
+ f"LR: {scheduler.get_last_lr()[0]:.6f}")
217
+
218
+ if avg_loss < best_loss:
219
+ best_loss = avg_loss
220
+ torch.save({
221
+ 'epoch': epoch,
222
+ 'model_state_dict': model.state_dict(),
223
+ 'loss': best_loss,
224
+ }, os.path.join(save_dir, 'best_phase2.pth'))
225
+
226
+ if push_to_hub and hub_model_id:
227
+ _push_checkpoint_to_hub(model, save_dir, hub_model_id, 'phase2')
228
+
229
+ return model
230
+
231
+
232
+ def _push_checkpoint_to_hub(model, save_dir, hub_model_id, phase):
233
+ """Push checkpoint to HuggingFace Hub."""
234
+ try:
235
+ from huggingface_hub import HfApi
236
+ api = HfApi()
237
+ api.upload_folder(
238
+ folder_path=save_dir,
239
+ repo_id=hub_model_id,
240
+ path_in_repo=f'checkpoints/{phase}',
241
+ )
242
+ print(f"Pushed {phase} checkpoint to {hub_model_id}")
243
+ except Exception as e:
244
+ print(f"Warning: Could not push to hub: {e}")