omar-ah commited on
Commit
908010b
·
verified ·
1 Parent(s): 823a1a3

Upload vil_tracker/inference/kalman.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vil_tracker/inference/kalman.py +141 -0
vil_tracker/inference/kalman.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Kalman Filter for online tracking state estimation.
3
+
4
+ 8-state model: [cx, cy, w, h, vx, vy, vw, vh]
5
+ - Position + size (4 states) + velocities (4 states)
6
+ - Constant velocity motion model
7
+ - Adaptive measurement noise based on prediction uncertainty
8
+ """
9
+
10
+ import numpy as np
11
+
12
+
13
+ class KalmanFilter:
14
+ """8-state Kalman filter for bounding box tracking.
15
+
16
+ State: [cx, cy, w, h, vx, vy, vw, vh]
17
+ Measurement: [cx, cy, w, h]
18
+
19
+ Features:
20
+ - Adaptive measurement noise (R) based on prediction uncertainty
21
+ - Chi-squared gating for outlier rejection
22
+ - Velocity damping for stable predictions
23
+ """
24
+
25
+ def __init__(self, dt: float = 1.0):
26
+ self.dt = dt
27
+ self.ndim = 4 # measurement dimensions
28
+ self.nstate = 8 # state dimensions
29
+
30
+ # State transition matrix (constant velocity)
31
+ self.F = np.eye(self.nstate)
32
+ for i in range(self.ndim):
33
+ self.F[i, i + self.ndim] = dt
34
+
35
+ # Measurement matrix
36
+ self.H = np.eye(self.ndim, self.nstate)
37
+
38
+ # Process noise
39
+ self._std_weight_position = 1.0 / 20
40
+ self._std_weight_velocity = 1.0 / 160
41
+
42
+ # State
43
+ self.x = None # State mean
44
+ self.P = None # State covariance
45
+ self._initialized = False
46
+
47
+ def initialize(self, measurement: np.ndarray):
48
+ """Initialize filter with first measurement [cx, cy, w, h]."""
49
+ self.x = np.zeros(self.nstate)
50
+ self.x[:self.ndim] = measurement
51
+
52
+ std = [
53
+ 2 * self._std_weight_position * measurement[2],
54
+ 2 * self._std_weight_position * measurement[3],
55
+ 2 * self._std_weight_position * measurement[2],
56
+ 2 * self._std_weight_position * measurement[3],
57
+ 10 * self._std_weight_velocity * measurement[2],
58
+ 10 * self._std_weight_velocity * measurement[3],
59
+ 10 * self._std_weight_velocity * measurement[2],
60
+ 10 * self._std_weight_velocity * measurement[3],
61
+ ]
62
+ self.P = np.diag(np.square(std))
63
+ self._initialized = True
64
+
65
+ def predict(self) -> np.ndarray:
66
+ """Predict next state. Returns predicted [cx, cy, w, h]."""
67
+ if not self._initialized:
68
+ raise RuntimeError("Filter not initialized")
69
+
70
+ # Process noise
71
+ std = [
72
+ self._std_weight_position * self.x[2],
73
+ self._std_weight_position * self.x[3],
74
+ self._std_weight_position * self.x[2],
75
+ self._std_weight_position * self.x[3],
76
+ self._std_weight_velocity * self.x[2],
77
+ self._std_weight_velocity * self.x[3],
78
+ self._std_weight_velocity * self.x[2],
79
+ self._std_weight_velocity * self.x[3],
80
+ ]
81
+ Q = np.diag(np.square(std))
82
+
83
+ # State prediction
84
+ self.x = self.F @ self.x
85
+ self.P = self.F @ self.P @ self.F.T + Q
86
+
87
+ # Velocity damping
88
+ self.x[self.ndim:] *= 0.95
89
+
90
+ return self.x[:self.ndim].copy()
91
+
92
+ def update(self, measurement: np.ndarray, uncertainty: float = 1.0):
93
+ """Update state with new measurement.
94
+
95
+ Args:
96
+ measurement: [cx, cy, w, h] observed box
97
+ uncertainty: prediction uncertainty (scales measurement noise)
98
+ """
99
+ if not self._initialized:
100
+ self.initialize(measurement)
101
+ return
102
+
103
+ # Measurement noise (adaptive based on uncertainty)
104
+ std = [
105
+ self._std_weight_position * self.x[2] * uncertainty,
106
+ self._std_weight_position * self.x[3] * uncertainty,
107
+ self._std_weight_position * self.x[2] * uncertainty,
108
+ self._std_weight_position * self.x[3] * uncertainty,
109
+ ]
110
+ R = np.diag(np.square(std))
111
+
112
+ # Innovation
113
+ y = measurement - self.H @ self.x
114
+ S = self.H @ self.P @ self.H.T + R
115
+
116
+ # Chi-squared gating (reject outliers)
117
+ mahalanobis = y @ np.linalg.inv(S) @ y
118
+ if mahalanobis > 16.0: # ~99.99% chi-squared threshold for 4 DOF
119
+ return # Reject this measurement
120
+
121
+ # Kalman gain
122
+ K = self.P @ self.H.T @ np.linalg.inv(S)
123
+
124
+ # State update
125
+ self.x = self.x + K @ y
126
+ I_KH = np.eye(self.nstate) - K @ self.H
127
+ self.P = I_KH @ self.P @ I_KH.T + K @ R @ K.T # Joseph form
128
+
129
+ # Ensure w, h stay positive
130
+ self.x[2] = max(self.x[2], 1.0)
131
+ self.x[3] = max(self.x[3], 1.0)
132
+
133
+ def get_state(self) -> np.ndarray:
134
+ """Get current state estimate [cx, cy, w, h]."""
135
+ if not self._initialized:
136
+ return np.zeros(self.ndim)
137
+ return self.x[:self.ndim].copy()
138
+
139
+ @property
140
+ def initialized(self):
141
+ return self._initialized