# vil_tracker/inference/online_tracker.py
"""
Online Tracker: Full inference pipeline for ViL Tracker.

Pipeline per frame:
1. Crop search region around predicted position
2. Run model: template + search → heatmap, size, offset
3. Decode predictions → candidate box
4. Apply Kalman filter for temporal smoothing
5. Update search region for next frame

Features:
- Adaptive search region scaling
- Confidence-based template update (skip when uncertain)
- Kalman filter with uncertainty-adaptive noise
"""
import numpy as np
import torch
import torch.nn.functional as F

from .kalman import KalmanFilter
from ..models.heads import create_hanning_window, decode_predictions
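
# The KalmanFilter from .kalman is used through four calls in this file:
#   initialize(state)             seed the filter with [cx, cy, w, h]
#   predict() -> state            propagate the motion model one step
#   update(measurement, noise)    correct with a measurement; `noise` scales the
#                                 measurement-noise covariance
#   get_state() -> state          current smoothed [cx, cy, w, h]
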
class OnlineTracker:
"""Online single-object tracker using ViL backbone.

    Combines:
    - Kalman filter for dynamic motion-model-based search centering (handles UAV ego-motion)
    - Hanning window for positional prior penalty on heatmap (suppresses edge false positives)
    - Uncertainty-adaptive Kalman measurement noise
    - Confidence-gated template update

    Usage:
        tracker = OnlineTracker(model, device='cuda')
        tracker.initialize(first_frame, init_bbox)  # [x, y, w, h]
        for frame in video[1:]:
            bbox = tracker.track(frame)  # returns [x, y, w, h]
    """
def __init__(
self,
model,
device: str = 'cuda',
template_size: int = 128,
search_size: int = 256,
search_scale: float = 4.0,
confidence_threshold: float = 0.3,
template_update_threshold: float = 0.8,
use_hanning: bool = True,
):
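        """
        Args:
            model: ViL tracking model; maps (template, search) to a dict with
                'heatmap', 'size', 'offset' (and optionally 'log_variance')
            device: torch device string
            template_size: side length of the template crop, in pixels
            search_size: side length of the search crop, in pixels
            search_scale: search crop side as a multiple of the target's longer side
            confidence_threshold: minimum score to accept a detection
            template_update_threshold: minimum score to refresh the template
            use_hanning: apply a Hanning positional prior to the heatmap
        """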
self.model = model
self.device = device
self.template_size = template_size
self.search_size = search_size
self.search_scale = search_scale
self.confidence_threshold = confidence_threshold
self.template_update_threshold = template_update_threshold
self.model.eval()
        # Hanning window for positional prior (generated once, reused every frame)
        feat_size = search_size // 16  # stride-16 feature map: 256 / 16 = 16
        if use_hanning:
            self.hanning_window = create_hanning_window(feat_size).to(device)
        else:
            self.hanning_window = None
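        # The window is expected to be an (n, n) map that peaks at the center and
        # decays toward the borders; a typical construction (a sketch, the real
        # helper lives in models/heads.py) is the outer product of 1-D Hann windows:
        #     torch.outer(torch.hann_window(n, periodic=False),
        #                 torch.hann_window(n, periodic=False))
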
# State
self.template = None
self.kalman = KalmanFilter()
self.target_pos = None # [cx, cy]
self.target_sz = None # [w, h]
self.frame_count = 0

    def initialize(self, frame: np.ndarray, bbox: list):
        """Initialize tracker with first frame and bounding box.

        Args:
            frame: (H, W, 3) BGR or RGB numpy array
            bbox: [x, y, w, h] initial bounding box (top-left format)
        """
x, y, w, h = bbox
self.target_pos = np.array([x + w / 2, y + h / 2])
self.target_sz = np.array([w, h])
# Crop and embed template
self.template = self._crop_and_preprocess(
frame, self.target_pos, self.target_sz,
output_size=self.template_size,
scale_factor=2.0,
)
# Initialize Kalman filter
self.kalman.initialize(np.array([
self.target_pos[0], self.target_pos[1],
self.target_sz[0], self.target_sz[1],
]))
# Reset temporal modulation
self.model.reset_temporal()
self.frame_count = 0

    def track(self, frame: np.ndarray) -> list:
        """Track target in new frame.

        Args:
            frame: (H, W, 3) numpy array

        Returns:
            [x, y, w, h] predicted bounding box (top-left format)
        """
self.frame_count += 1
# Kalman predict
kf_pred = self.kalman.predict()
pred_pos = kf_pred[:2]
pred_sz = kf_pred[2:]
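        # Center the search crop on the motion-model prediction rather than the
        # last box: under UAV ego-motion the target can move a lot between frames.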
# Crop search region around predicted position
search = self._crop_and_preprocess(
frame, pred_pos, pred_sz,
output_size=self.search_size,
scale_factor=self.search_scale,
)
# Run model
with torch.no_grad():
output = self.model(
self.template.to(self.device),
search.to(self.device),
use_temporal=(self.frame_count > 1),
)
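        # Temporal modulation is skipped on the first tracked frame: the model
        # has no history yet right after initialize().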
        # Extract predictions; re-decode with the Hanning window applied for inference
        boxes_tensor, scores_tensor = decode_predictions(
            output['heatmap'],
            output['size'],
            output['offset'],
            search_size=self.search_size,
            feat_size=self.search_size // 16,
            hanning_window=self.hanning_window,
        )
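        # decode_predictions is expected to apply the Hanning window to the
        # heatmap, pick the peak, and refine it with the offset and size maps;
        # the exact decode lives in models/heads.py.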
boxes = boxes_tensor.cpu().numpy()[0] # [cx, cy, w, h] in search region
score = scores_tensor.cpu().item()
# Map back to original frame coordinates
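        # The crop covered search_scale * max(pred_sz) frame pixels and was
        # resized to search_size, so one search-region pixel equals scale_factor
        # frame pixels; shifting by the crop center recovers frame coordinates.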
scale_factor = self.search_scale * max(pred_sz) / self.search_size
cx = (boxes[0] - self.search_size / 2) * scale_factor + pred_pos[0]
cy = (boxes[1] - self.search_size / 2) * scale_factor + pred_pos[1]
w = boxes[2] * scale_factor
h = boxes[3] * scale_factor
# Confidence-based update
if score > self.confidence_threshold:
# Get uncertainty for Kalman noise adaptation
uncertainty = 1.0
if 'log_variance' in output:
log_var = output['log_variance'].mean().cpu().item()
uncertainty = max(0.5, min(3.0, np.exp(log_var / 2)))
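            # exp(log_var / 2) is the predicted standard deviation; the clamp to
            # [0.5, 3.0] keeps the Kalman measurement noise bounded even when the
            # uncertainty head is miscalibrated.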
self.kalman.update(np.array([cx, cy, w, h]), uncertainty)
            # Refresh the template at most once every 10 frames, and only when very
            # confident: re-embedding from a drifted box would poison the template
if score > self.template_update_threshold and self.frame_count % 10 == 0:
self.template = self._crop_and_preprocess(
frame, np.array([cx, cy]), np.array([w, h]),
output_size=self.template_size,
scale_factor=2.0,
)
# Use Kalman-smoothed state
state = self.kalman.get_state()
self.target_pos = state[:2]
self.target_sz = state[2:]
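        # On low-confidence frames no Kalman update happened above, so this state
        # is pure motion-model prediction: the tracker coasts until a confident
        # detection re-anchors it.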
# Return top-left format [x, y, w, h]
return [
self.target_pos[0] - self.target_sz[0] / 2,
self.target_pos[1] - self.target_sz[1] / 2,
self.target_sz[0],
self.target_sz[1],
]

    def _crop_and_preprocess(
self,
frame: np.ndarray,
center: np.ndarray,
size: np.ndarray,
output_size: int,
scale_factor: float,
) -> torch.Tensor:
"""Crop and preprocess image region.

        Args:
            frame: (H, W, 3) numpy array
            center: [cx, cy] crop center
            size: [w, h] target size
            output_size: side length of the square output
            scale_factor: how much larger than the target to crop

        Returns:
            (1, 3, output_size, output_size) preprocessed tensor
        """
H, W = frame.shape[:2]
# Compute crop size
crop_size = max(size[0], size[1]) * scale_factor
crop_size = max(crop_size, 10) # minimum crop size
# Crop coordinates
x1 = int(center[0] - crop_size / 2)
y1 = int(center[1] - crop_size / 2)
x2 = int(x1 + crop_size)
y2 = int(y1 + crop_size)
# Handle boundaries with padding
pad_left = max(0, -x1)
pad_top = max(0, -y1)
pad_right = max(0, x2 - W)
pad_bottom = max(0, y2 - H)
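        # Pads are measured against the unclamped coordinates, so any part of the
        # crop that falls outside the frame is zero-filled after slicing.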
x1 = max(0, x1)
y1 = max(0, y1)
x2 = min(W, x2)
y2 = min(H, y2)
crop = frame[y1:y2, x1:x2]
if pad_left > 0 or pad_top > 0 or pad_right > 0 or pad_bottom > 0:
crop = np.pad(crop, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
mode='constant', constant_values=0)
        # Resize to output_size
        if crop.shape[0] > 0 and crop.shape[1] > 0:
            crop_tensor = torch.from_numpy(crop).float().permute(2, 0, 1).unsqueeze(0)
            crop_tensor = F.interpolate(crop_tensor, size=(output_size, output_size),
                                        mode='bilinear', align_corners=False)
else:
crop_tensor = torch.zeros(1, 3, output_size, output_size)
# Normalize to [0, 1]
crop_tensor = crop_tensor / 255.0
return crop_tensor