Jdice27 committed
Commit e8142ba · verified · 1 Parent(s): 48b8bfe

Add data_pipeline.py

Files changed (1)
data_pipeline.py +960 -0
data_pipeline.py ADDED
"""
AirTrackLM - Data Pipeline
==========================
Converts raw ADS-B (lat, lon, alt, timestamp) to model-ready tensors.

Pipeline:
1. Load trajectories from the `traffic` library or raw CSV
2. Resample to a fixed time interval (default 5 s)
3. Convert lat/lon/alt to ENU (East-North-Up) coordinates using the first point as origin
4. Compute velocity via 3-point central derivative on ENU positions
5. Derive COG, SOG from x-y ground velocity; ROT from COG; altitude rate from z velocity
6. Binary geohash encoding (40 bits per axis, following LLM4STP)
7. Discretize features into bins
8. Compute uncertainty scores
9. Build a sliding-window PyTorch Dataset
"""

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from typing import Optional, Tuple, List, Dict
import pyproj
from dataclasses import dataclass, field


# ============================================================
# 1. ENU Coordinate Conversion
# ============================================================

class ENUConverter:
    """
    Convert WGS84 (lat, lon, alt) to local East-North-Up (ENU) coordinates.
    Origin is set to the first point of each trajectory.

    Uses pyproj for geodetically correct transformations.
    """

    def __init__(self, origin_lat: float, origin_lon: float, origin_alt: float = 0.0):
        self.origin_lat = origin_lat
        self.origin_lon = origin_lon
        self.origin_alt = origin_alt

        # ECEF transformer
        self.ecef = pyproj.Proj(proj='geocent', ellps='WGS84', datum='WGS84')
        self.lla = pyproj.Proj(proj='latlong', ellps='WGS84', datum='WGS84')
        self.transformer_to_ecef = pyproj.Transformer.from_proj(self.lla, self.ecef, always_xy=True)
        self.transformer_to_lla = pyproj.Transformer.from_proj(self.ecef, self.lla, always_xy=True)

        # Origin in ECEF
        self.x0, self.y0, self.z0 = self.transformer_to_ecef.transform(
            origin_lon, origin_lat, origin_alt
        )

        # Rotation matrix (ECEF -> ENU)
        lat_r = np.radians(origin_lat)
        lon_r = np.radians(origin_lon)
        self.R = np.array([
            [-np.sin(lon_r),                np.cos(lon_r),               0.0          ],
            [-np.sin(lat_r)*np.cos(lon_r), -np.sin(lat_r)*np.sin(lon_r), np.cos(lat_r)],
            [ np.cos(lat_r)*np.cos(lon_r),  np.cos(lat_r)*np.sin(lon_r), np.sin(lat_r)]
        ])

    def to_enu(self, lats: np.ndarray, lons: np.ndarray, alts: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Convert arrays of lat/lon/alt to ENU (meters)."""
        # To ECEF
        x, y, z = self.transformer_to_ecef.transform(lons, lats, alts)

        # Offset from origin
        dx = x - self.x0
        dy = y - self.y0
        dz = z - self.z0

        # Rotate to ENU
        ecef_delta = np.stack([dx, dy, dz], axis=0)  # (3, N)
        enu = self.R @ ecef_delta                    # (3, N)

        east = enu[0]   # meters
        north = enu[1]  # meters
        up = enu[2]     # meters

        return east, north, up

    def from_enu(self, east: np.ndarray, north: np.ndarray, up: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Convert ENU back to lat/lon/alt."""
        enu = np.stack([east, north, up], axis=0)
        ecef_delta = self.R.T @ enu

        x = ecef_delta[0] + self.x0
        y = ecef_delta[1] + self.y0
        z = ecef_delta[2] + self.z0

        lons, lats, alts = self.transformer_to_lla.transform(x, y, z)
        return lats, lons, alts

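
# Hedged usage sketch (not part of the pipeline API): round-trip a few
# synthetic points through ENUConverter to check that to_enu/from_enu invert
# each other. The coordinates below are illustrative assumptions only.
def _demo_enu_roundtrip():
    lats = np.array([40.0, 40.001, 40.002])
    lons = np.array([-105.0, -105.001, -105.002])
    alts = np.array([1600.0, 1650.0, 1700.0])
    conv = ENUConverter(lats[0], lons[0], alts[0])  # origin = first point
    east, north, up = conv.to_enu(lats, lons, alts)
    lats2, lons2, alts2 = conv.from_enu(east, north, up)
    assert np.allclose(lats, lats2) and np.allclose(lons, lons2)
    assert np.allclose(alts, alts2, atol=1e-6)
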

# ============================================================
# 2. Three-Point Central Derivative
# ============================================================

def three_point_derivative(values: np.ndarray, dt: np.ndarray) -> np.ndarray:
    """
    Compute derivative using 3-point central difference.

    For interior points (i=1..N-2):
        f'(i) = (f(i+1) - f(i-1)) / (t(i+1) - t(i-1))

    For endpoints:
        f'(0)   = (f(1) - f(0)) / (t(1) - t(0))          # forward difference
        f'(N-1) = (f(N-1) - f(N-2)) / (t(N-1) - t(N-2))  # backward difference

    Args:
        values: shape (N,) — the signal to differentiate
        dt: shape (N,) — cumulative time from start (seconds)

    Returns:
        derivative: shape (N,) — rate of change per second
    """
    N = len(values)
    deriv = np.zeros(N)

    if N < 2:
        return deriv

    # Forward difference for first point
    dt_fwd = dt[1] - dt[0]
    if dt_fwd > 0:
        deriv[0] = (values[1] - values[0]) / dt_fwd

    # Central difference for interior points
    for i in range(1, N - 1):
        dt_span = dt[i + 1] - dt[i - 1]
        if dt_span > 0:
            deriv[i] = (values[i + 1] - values[i - 1]) / dt_span

    # Backward difference for last point
    dt_bwd = dt[-1] - dt[-2]
    if dt_bwd > 0:
        deriv[-1] = (values[-1] - values[-2]) / dt_bwd

    return deriv


def three_point_derivative_vectorized(values: np.ndarray, dt: np.ndarray) -> np.ndarray:
    """Vectorized version of the 3-point central derivative."""
    N = len(values)
    deriv = np.zeros(N)

    if N < 2:
        return deriv

    # Forward difference for first point
    dt_fwd = dt[1] - dt[0]
    if dt_fwd > 0:
        deriv[0] = (values[1] - values[0]) / dt_fwd

    # Central difference for interior points (vectorized)
    if N > 2:
        dt_span = dt[2:] - dt[:-2]           # (N-2,)
        mask = dt_span > 0
        val_diff = values[2:] - values[:-2]  # (N-2,)
        deriv[1:-1] = np.where(mask, val_diff / np.maximum(dt_span, 1e-10), 0.0)

    # Backward difference for last point
    dt_bwd = dt[-1] - dt[-2]
    if dt_bwd > 0:
        deriv[-1] = (values[-1] - values[-2]) / dt_bwd

    return deriv

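
# Hedged sketch: check both derivative implementations on a quadratic sampled
# every 5 s. Central differences are exact for quadratics at interior points;
# the signal and interval are illustrative assumptions only.
def _demo_three_point_derivative():
    t = np.arange(0.0, 50.0, 5.0)  # cumulative seconds
    x = 3.0 * t**2                 # f(t) = 3t², so f'(t) = 6t
    d_loop = three_point_derivative(x, t)
    d_vec = three_point_derivative_vectorized(x, t)
    assert np.allclose(d_loop, d_vec)
    assert np.allclose(d_loop[1:-1], 6.0 * t[1:-1])  # interior points exact
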

# ============================================================
# 3. Feature Derivation from ENU Positions
# ============================================================

def derive_features_enu(
    east: np.ndarray,
    north: np.ndarray,
    up: np.ndarray,
    timestamps: np.ndarray
) -> Dict[str, np.ndarray]:
    """
    Derive COG, SOG, ROT, and altitude rate from ENU positions
    using 3-point central derivatives.

    Args:
        east, north, up: ENU coordinates in meters, shape (N,)
        timestamps: Unix timestamps in seconds, shape (N,)

    Returns:
        dict with keys: 'vx', 'vy', 'vz', 'COG', 'SOG', 'ROT', 'alt_rate'
        Each is shape (N,)
    """
    # Cumulative time from start
    t = timestamps - timestamps[0]

    # 3-point derivative on ENU positions → velocities (m/s)
    vx = three_point_derivative_vectorized(east, t)   # East velocity (m/s)
    vy = three_point_derivative_vectorized(north, t)  # North velocity (m/s)
    vz = three_point_derivative_vectorized(up, t)     # Up velocity (m/s)

    # SOG from ground-plane velocity: sqrt(vx² + vy²), converted m/s → knots
    sog_ms = np.sqrt(vx**2 + vy**2)
    sog_knots = sog_ms * 1.94384  # m/s to knots

    # COG from ground velocity components: atan2(East, North) gives the
    # bearing from North, clockwise, in degrees
    cog_deg = np.degrees(np.arctan2(vx, vy)) % 360

    # ROT: derivative of COG (degrees/second).
    # Handle circular wraparound by unwrapping COG first.
    cog_unwrapped = np.unwrap(np.radians(cog_deg))
    rot_rad_s = three_point_derivative_vectorized(cog_unwrapped, t)
    rot_deg_s = np.degrees(rot_rad_s)

    # Altitude rate: vz converted to ft/min (1 m/s = 196.85 ft/min)
    alt_rate_ftmin = vz * 196.85

    return {
        'vx': vx,
        'vy': vy,
        'vz': vz,
        'COG': cog_deg,
        'SOG': sog_knots,
        'ROT': rot_deg_s,
        'alt_rate': alt_rate_ftmin,
    }

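
# Hedged worked example: an aircraft moving due east at 100 m/s should come
# out as COG ≈ 90° (bearing from North, clockwise) and SOG ≈ 194.4 kt.
# Synthetic straight-line inputs, for illustration only.
def _demo_derive_features():
    ts = np.arange(0.0, 60.0, 5.0)
    east = 100.0 * ts  # due-east motion at 100 m/s
    north = np.zeros_like(ts)
    up = np.zeros_like(ts)
    feats = derive_features_enu(east, north, up, ts)
    assert np.allclose(feats['COG'], 90.0)
    assert np.allclose(feats['SOG'], 100.0 * 1.94384)
    assert np.allclose(feats['ROT'], 0.0)
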

# ============================================================
# 4. Binary Geohash Encoding (following LLM4STP, 40-bit precision)
# ============================================================

def binary_geohash_encode(
    values: np.ndarray,
    precision: int = 40,
    v_min: float = 0.0,
    v_max: float = 1.0
) -> np.ndarray:
    """
    Encode normalized values as binary geohash via successive bisection.
    Matches LLM4STP's num2bits() implementation.

    Args:
        values: shape (N,) — normalized to [v_min, v_max]
        precision: number of bits
        v_min, v_max: range bounds

    Returns:
        bits: shape (N, precision) — binary encoding (0/1)
    """
    N = len(values)
    bits = np.zeros((N, precision), dtype=np.int64)

    _min = np.full(N, v_min)
    _max = np.full(N, v_max)

    for p in range(precision):
        mid = (_min + _max) / 2
        mask = values > mid
        bits[:, p] = mask.astype(np.int64)
        _min = np.where(mask, mid, _min)
        _max = np.where(mask, _max, mid)

    return bits


def binary_geohash_decode(
    bits: np.ndarray,
    precision: int = 40,
    v_min: float = 0.0,
    v_max: float = 1.0
) -> np.ndarray:
    """Decode binary geohash back to values (interval midpoint after bisection)."""
    N = bits.shape[0]
    _min = np.full(N, v_min)
    _max = np.full(N, v_max)

    for p in range(precision):
        mid = (_min + _max) / 2
        mask = bits[:, p].astype(bool)
        _min = np.where(mask, mid, _min)
        _max = np.where(mask, _max, mid)

    return (_min + _max) / 2


class GeohashEncoder:
    """
    3D geohash encoder for aviation.
    Encodes (east, north, up) ENU coordinates as binary geohash.

    Following LLM4STP: 40 bits per axis, coordinates normalized to [0, 1].
    Extended to 3 axes (E, N, U) for aviation 3D trajectories.

    Total encoding: 40 * 3 = 120 bits per timestep.
    """

    def __init__(self, precision: int = 40):
        self.precision = precision
        # Normalization bounds — set from training data
        self.e_min = None
        self.e_max = None
        self.n_min = None
        self.n_max = None
        self.u_min = None
        self.u_max = None

    def fit(self, east: np.ndarray, north: np.ndarray, up: np.ndarray, margin: float = 0.05):
        """Set normalization bounds from training data, padded by a margin."""
        e_range = east.max() - east.min()
        n_range = north.max() - north.min()
        u_range = up.max() - up.min()

        self.e_min = east.min() - margin * max(e_range, 1.0)
        self.e_max = east.max() + margin * max(e_range, 1.0)
        self.n_min = north.min() - margin * max(n_range, 1.0)
        self.n_max = north.max() + margin * max(n_range, 1.0)
        self.u_min = up.min() - margin * max(u_range, 1.0)
        self.u_max = up.max() + margin * max(u_range, 1.0)

    def normalize(self, values: np.ndarray, v_min: float, v_max: float) -> np.ndarray:
        """Normalize to [0, 1]."""
        return np.clip((values - v_min) / max(v_max - v_min, 1e-10), 0.0, 1.0)

    def encode(self, east: np.ndarray, north: np.ndarray, up: np.ndarray) -> np.ndarray:
        """
        Encode ENU positions as 3D binary geohash.

        Returns:
            bits: shape (N, precision*3) — concatenated [E_bits | N_bits | U_bits]
        """
        e_norm = self.normalize(east, self.e_min, self.e_max)
        n_norm = self.normalize(north, self.n_min, self.n_max)
        u_norm = self.normalize(up, self.u_min, self.u_max)

        e_bits = binary_geohash_encode(e_norm, self.precision)
        n_bits = binary_geohash_encode(n_norm, self.precision)
        u_bits = binary_geohash_encode(u_norm, self.precision)

        return np.concatenate([e_bits, n_bits, u_bits], axis=1)  # (N, 120)

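
# Hedged sketch: with 40 bits per axis the decoded midpoint lies within
# range / 2**41 of the input, so an encode→decode round trip is near-lossless.
# The sample values are illustrative assumptions.
def _demo_geohash_roundtrip():
    vals = np.array([0.123456789, 0.5, 0.999])
    bits = binary_geohash_encode(vals, precision=40)
    recon = binary_geohash_decode(bits, precision=40)
    assert np.max(np.abs(recon - vals)) < 1.0 / 2**40
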

# ============================================================
# 5. Feature Discretization
# ============================================================

@dataclass
class FeatureBins:
    """Configuration for discretizing continuous features into bins."""

    # COG: [0, 360) degrees, 2° bins for high resolution
    cog_edges: np.ndarray = field(default_factory=lambda: np.linspace(0, 360, 181))  # 180 bins

    # SOG: [0, 600] knots, 2-knot bins
    sog_edges: np.ndarray = field(default_factory=lambda: np.linspace(0, 600, 301))  # 300 bins

    # ROT: [-6, 6] deg/s, 0.1 deg/s bins
    rot_edges: np.ndarray = field(default_factory=lambda: np.linspace(-6, 6, 121))  # 120 bins

    # Altitude rate: [-6000, 6000] ft/min, 100 ft/min bins
    alt_rate_edges: np.ndarray = field(default_factory=lambda: np.linspace(-6000, 6000, 121))  # 120 bins

    @property
    def n_cog_bins(self): return len(self.cog_edges) - 1

    @property
    def n_sog_bins(self): return len(self.sog_edges) - 1

    @property
    def n_rot_bins(self): return len(self.rot_edges) - 1

    @property
    def n_alt_rate_bins(self): return len(self.alt_rate_edges) - 1

    def digitize(self, values: np.ndarray, edges: np.ndarray) -> np.ndarray:
        """Bin values. Returns indices in [0, n_bins-1], clipped."""
        indices = np.digitize(values, edges) - 1
        return np.clip(indices, 0, len(edges) - 2)

    def encode_cog(self, cog: np.ndarray) -> np.ndarray:
        return self.digitize(cog, self.cog_edges)

    def encode_sog(self, sog: np.ndarray) -> np.ndarray:
        return self.digitize(sog, self.sog_edges)

    def encode_rot(self, rot: np.ndarray) -> np.ndarray:
        rot_clipped = np.clip(rot, -6, 6)
        return self.digitize(rot_clipped, self.rot_edges)

    def encode_alt_rate(self, alt_rate: np.ndarray) -> np.ndarray:
        ar_clipped = np.clip(alt_rate, -6000, 6000)
        return self.digitize(ar_clipped, self.alt_rate_edges)

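
# Hedged sketch: the last 2° COG bin catches 359.9°, and out-of-range ROT
# values are clipped into the edge bin. Values are illustrative only.
def _demo_feature_bins():
    bins = FeatureBins()
    assert bins.encode_cog(np.array([359.9]))[0] == bins.n_cog_bins - 1
    assert bins.encode_rot(np.array([99.0]))[0] == bins.n_rot_bins - 1
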

# ============================================================
# 6. Uncertainty Score
# ============================================================

def compute_uncertainty(
    cog: np.ndarray,
    sog: np.ndarray,
    rot: np.ndarray,
    alt_rate: np.ndarray,
    window: int = 5
) -> np.ndarray:
    """
    Compute trajectory uncertainty score from recent state variance.
    High variance = high uncertainty (maneuvering aircraft).

    Returns:
        scores: shape (N,) — uncertainty scores (higher = more uncertain)
    """
    N = len(cog)
    scores = np.zeros(N)

    # Global variances (computed once) used to normalize the window variances
    sog_norm = np.var(sog) + 1e-10
    rot_norm = np.var(rot) + 1e-10
    alt_norm = np.var(alt_rate) + 1e-10

    for i in range(N):
        start = max(0, i - window + 1)
        w = slice(start, i + 1)

        # Circular variance for COG
        cog_rad = np.radians(cog[w])
        R_len = np.sqrt(np.mean(np.cos(cog_rad))**2 + np.mean(np.sin(cog_rad))**2)
        cog_var = 1 - R_len  # circular variance in [0, 1]

        # Regular variance for the others
        sog_var = np.var(sog[w]) if len(sog[w]) > 1 else 0
        rot_var = np.var(rot[w]) if len(rot[w]) > 1 else 0
        alt_var = np.var(alt_rate[w]) if len(alt_rate[w]) > 1 else 0

        # Normalize and combine (equal weights)
        scores[i] = cog_var + sog_var / sog_norm + rot_var / rot_norm + alt_var / alt_norm

    return scores


def discretize_uncertainty(scores: np.ndarray, n_bins: int = 16) -> np.ndarray:
    """Discretize uncertainty scores into quantile bins."""
    if len(np.unique(scores)) < n_bins:
        # Not enough unique values for quantile binning — fall back to uniform bins
        edges = np.linspace(scores.min(), scores.max() + 1e-10, n_bins + 1)
    else:
        edges = np.quantile(scores, np.linspace(0, 1, n_bins + 1))
        edges[-1] += 1e-10  # ensure the max value is included

    return np.clip(np.digitize(scores, edges) - 1, 0, n_bins - 1)

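
# Hedged sketch: a steady heading followed by a turn should score higher
# uncertainty during the maneuver (driven by the circular COG variance).
# Synthetic values, for illustration only.
def _demo_uncertainty():
    cog = np.concatenate([np.full(20, 90.0), np.linspace(90.0, 180.0, 20)])
    sog = np.full(40, 250.0)
    rot = np.concatenate([np.zeros(20), np.full(20, 1.0)])
    alt_rate = np.zeros(40)
    scores = compute_uncertainty(cog, sog, rot, alt_rate, window=5)
    assert scores[25:].mean() > scores[:15].mean()
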

# ============================================================
# 7. Temporal Features
# ============================================================

def extract_temporal_features(timestamps: np.ndarray) -> Dict[str, np.ndarray]:
    """
    Extract temporal features from Unix timestamps (interpreted as UTC).

    Returns dict with:
        'second_of_day': float seconds within the day [0, 86400)
        'hour': int hour of day [0, 23]
        'dow': int day of week [0, 6]
        'month': int month [0, 11]
        'dt': float seconds since previous point (0 for first)
        'fractional_second': float sub-second component [0, 1)
    """
    import datetime

    # Convert to datetime objects for calendar features
    # (timezone-aware replacement for the deprecated utcfromtimestamp)
    dts = [datetime.datetime.fromtimestamp(t, tz=datetime.timezone.utc) for t in timestamps]

    hours = np.array([d.hour for d in dts], dtype=np.int64)
    dows = np.array([d.weekday() for d in dts], dtype=np.int64)
    months = np.array([d.month - 1 for d in dts], dtype=np.int64)  # 0-indexed

    # Second of day (with fractional seconds)
    second_of_day = np.array([
        d.hour * 3600 + d.minute * 60 + d.second + d.microsecond / 1e6
        for d in dts
    ])

    # Delta-t between consecutive points
    dt = np.zeros(len(timestamps))
    dt[1:] = np.diff(timestamps)

    # Fractional second component
    fractional_second = timestamps - np.floor(timestamps)

    return {
        'second_of_day': second_of_day,
        'hour': hours,
        'dow': dows,
        'month': months,
        'dt': dt,
        'fractional_second': fractional_second,
    }


# ============================================================
# 8. Full Trajectory Processor
# ============================================================

class TrajectoryProcessor:
    """
    Complete pipeline: raw ADS-B → model-ready features.
    """

    def __init__(
        self,
        resample_dt: float = 5.0,       # resample interval in seconds
        geohash_precision: int = 40,    # bits per axis
        n_uncertainty_bins: int = 16,
        feature_bins: Optional[FeatureBins] = None,
        min_trajectory_len: int = 20,   # minimum points after processing
    ):
        self.resample_dt = resample_dt
        self.geohash_precision = geohash_precision
        self.n_uncertainty_bins = n_uncertainty_bins
        self.feature_bins = feature_bins or FeatureBins()
        self.min_trajectory_len = min_trajectory_len
        self.geohash_encoder = GeohashEncoder(precision=geohash_precision)

        # Fit state
        self._fitted = False

    def resample_trajectory(
        self, timestamps: np.ndarray, lats: np.ndarray, lons: np.ndarray, alts: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Resample trajectory to fixed time intervals via linear interpolation."""
        t_start = timestamps[0]
        t_end = timestamps[-1]

        n_points = int((t_end - t_start) / self.resample_dt) + 1
        t_new = np.linspace(t_start, t_start + (n_points - 1) * self.resample_dt, n_points)

        lats_new = np.interp(t_new, timestamps, lats)
        lons_new = np.interp(t_new, timestamps, lons)
        alts_new = np.interp(t_new, timestamps, alts)

        return t_new, lats_new, lons_new, alts_new

    def process_trajectory(
        self,
        timestamps: np.ndarray,
        lats: np.ndarray,
        lons: np.ndarray,
        alts: np.ndarray,
        metadata: Optional[Dict] = None
    ) -> Optional[Dict[str, np.ndarray]]:
        """
        Process a single trajectory from raw ADS-B to model features.

        Returns None if the trajectory is too short or invalid,
        otherwise a dict with all features needed by the model.
        """
        # Sort by time
        sort_idx = np.argsort(timestamps)
        timestamps = timestamps[sort_idx]
        lats = lats[sort_idx]
        lons = lons[sort_idx]
        alts = alts[sort_idx]

        # Resample to fixed interval
        timestamps, lats, lons, alts = self.resample_trajectory(timestamps, lats, lons, alts)

        if len(timestamps) < self.min_trajectory_len:
            return None

        # Convert to ENU (origin = first point)
        converter = ENUConverter(lats[0], lons[0], alts[0])
        east, north, up = converter.to_enu(lats, lons, alts)

        # Derive features via 3-point derivative on ENU
        features = derive_features_enu(east, north, up, timestamps)

        # Binary geohash encoding on ENU positions.
        # If the encoder is not yet fitted, store a placeholder
        # (re-encoded after fitting — see build_dataset).
        if self.geohash_encoder.e_min is not None:
            geohash_bits = self.geohash_encoder.encode(east, north, up)  # (N, 120)
        else:
            geohash_bits = np.zeros((len(east), self.geohash_precision * 3), dtype=np.int64)

        # Discretize kinematic features
        cog_bins = self.feature_bins.encode_cog(features['COG'])
        sog_bins = self.feature_bins.encode_sog(features['SOG'])
        rot_bins = self.feature_bins.encode_rot(features['ROT'])
        alt_rate_bins = self.feature_bins.encode_alt_rate(features['alt_rate'])

        # Uncertainty — multiple methods (project-local `uncertainty` module)
        from uncertainty import (
            compute_all_uncertainties, discretize_scores, UncertaintyConfig
        )
        uncert_config = UncertaintyConfig(
            use_kinematic_variance=True,
            use_prediction_residual=True,
            use_spatial_density=True,
            use_flight_phase_entropy=True,
            use_temporal_irregularity=False,
            n_bins=self.n_uncertainty_bins,
            window=5,
        )
        raw_uncert = compute_all_uncertainties(
            east, north, up, timestamps,
            features['COG'], features['SOG'], features['ROT'], features['alt_rate'],
            config=uncert_config,
        )
        # Discretize each method into bins → stack into an (N, n_methods) array
        uncert_methods = sorted(raw_uncert.keys())
        uncert_bins_multi = np.stack([
            discretize_scores(raw_uncert[m], self.n_uncertainty_bins)
            for m in uncert_methods
        ], axis=1)  # (N, n_methods)

        # Also keep the legacy single method for backwards compatibility
        if 'kinematic_var' in raw_uncert:
            uncert_bins = discretize_scores(raw_uncert['kinematic_var'], self.n_uncertainty_bins)
        else:
            uncert_bins = uncert_bins_multi[:, 0]

        # Temporal features
        temporal = extract_temporal_features(timestamps)

        return {
            # Raw (for evaluation/debugging)
            'timestamps': timestamps,
            'lats': lats,
            'lons': lons,
            'alts': alts,
            'east': east,
            'north': north,
            'up': up,

            # Continuous features
            'COG': features['COG'],
            'SOG': features['SOG'],
            'ROT': features['ROT'],
            'alt_rate': features['alt_rate'],
            'vx': features['vx'],
            'vy': features['vy'],
            'vz': features['vz'],

            # Geohash (binary, 120 bits per timestep)
            'geohash_bits': geohash_bits,

            # Discretized features (bin indices)
            'cog_bins': cog_bins,
            'sog_bins': sog_bins,
            'rot_bins': rot_bins,
            'alt_rate_bins': alt_rate_bins,

            # Uncertainty (bin indices)
            'uncert_bins': uncert_bins,              # (N,) legacy single method
            'uncert_bins_multi': uncert_bins_multi,  # (N, n_methods) multi-method
            'uncert_method_names': uncert_methods,   # list of method names

            # Temporal
            'hour': temporal['hour'],
            'dow': temporal['dow'],
            'month': temporal['month'],
            'second_of_day': temporal['second_of_day'],
            'dt': temporal['dt'],

            # ENU origin (for decoding predictions back to lat/lon)
            'enu_origin': (converter.origin_lat, converter.origin_lon, converter.origin_alt),

            # Metadata
            'metadata': metadata or {},
        }

    def fit_geohash(self, all_east: np.ndarray, all_north: np.ndarray, all_up: np.ndarray):
        """Fit geohash normalization bounds from all training trajectories."""
        self.geohash_encoder.fit(all_east, all_north, all_up)
        self._fitted = True

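
# Hedged smoke test: run the full processor on a synthetic straight-line
# flight. Assumes the project-local `uncertainty` module imported inside
# process_trajectory is available; all values are illustrative.
def _demo_processor():
    n = 150
    ts = 1_700_000_000.0 + 4.0 * np.arange(n)  # one point every 4 s
    lats = 40.0 + 0.0005 * np.arange(n)
    lons = np.full(n, -105.0)
    alts = 10000.0 + 2.0 * np.arange(n)
    proc = TrajectoryProcessor(resample_dt=5.0)
    result = proc.process_trajectory(ts, lats, lons, alts)
    assert result is not None
    assert result['geohash_bits'].shape == (len(result['east']), 120)
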

# ============================================================
# 9. PyTorch Dataset with Sliding Window
# ============================================================

@dataclass
class PromptTokens:
    """Prompt token IDs for metadata encoding."""
    # Special tokens
    BOS: int = 0
    EOS: int = 1
    PAD: int = 2

    # Task tokens
    PREDICT: int = 3
    CLASSIFY: int = 4
    DETECT_ANOMALY: int = 5

    # Aircraft category
    HEAVY: int = 6
    LARGE: int = 7
    SMALL: int = 8
    ROTORCRAFT: int = 9
    GLIDER: int = 10
    UAV: int = 11
    AIRCRAFT_UNKNOWN: int = 12

    # Flight phase
    CLIMB: int = 13
    CRUISE: int = 14
    DESCENT: int = 15
    APPROACH: int = 16
    GROUND: int = 17
    PHASE_UNKNOWN: int = 18

    # Region
    CONUS: int = 19
    EUROPE: int = 20
    ASIA: int = 21
    REGION_OTHER: int = 22

    VOCAB_SIZE: int = 23


class AirTrackDataset(Dataset):
    """
    Sliding-window dataset for next-state prediction.

    Each sample is a window of `seq_len` consecutive states.
    The model predicts state t+1 from states 1..t for all t.
    """

    def __init__(
        self,
        trajectories: List[Dict[str, np.ndarray]],
        seq_len: int = 128,
        stride: int = 64,
        task: str = 'predict',  # 'predict' or 'classify'
    ):
        self.seq_len = seq_len
        self.stride = stride
        self.task = task

        # Build an index of (trajectory_idx, start, end) for all valid windows
        self.windows = []
        self.trajectories = trajectories

        for traj_idx, traj in enumerate(trajectories):
            n_points = len(traj['timestamps'])
            # Need seq_len + 1 points (seq_len inputs + 1 target for the last position)
            if n_points < seq_len + 1:
                # Use the entire trajectory if it is at least the minimum length.
                # Note: this yields windows shorter than seq_len, which the
                # default collate function cannot batch with full-length windows.
                if n_points >= 20:
                    self.windows.append((traj_idx, 0, n_points))
                continue

            for start in range(0, n_points - seq_len, stride):
                end = start + seq_len + 1  # +1 for the next-state target
                if end <= n_points:
                    self.windows.append((traj_idx, start, end))

        # Prompt tokens
        self.prompt_tokens = PromptTokens()

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        traj_idx, start, end = self.windows[idx]
        traj = self.trajectories[traj_idx]

        # Slice the window
        sl = slice(start, end)

        # Geohash bits: (window_len, 120)
        geohash_bits = torch.from_numpy(traj['geohash_bits'][sl]).float()

        # Discretized features
        cog_bins = torch.from_numpy(traj['cog_bins'][sl]).long()
        sog_bins = torch.from_numpy(traj['sog_bins'][sl]).long()
        rot_bins = torch.from_numpy(traj['rot_bins'][sl]).long()
        alt_rate_bins = torch.from_numpy(traj['alt_rate_bins'][sl]).long()

        # Uncertainty bins (single + multi)
        uncert_bins = torch.from_numpy(traj['uncert_bins'][sl]).long()
        if 'uncert_bins_multi' in traj:
            uncert_bins_multi = torch.from_numpy(traj['uncert_bins_multi'][sl]).long()
        else:
            uncert_bins_multi = uncert_bins.unsqueeze(-1)

        # Temporal features
        hour = torch.from_numpy(traj['hour'][sl]).long()
        dow = torch.from_numpy(traj['dow'][sl]).long()
        month = torch.from_numpy(traj['month'][sl]).long()

        # Second-of-day as a continuous feature (for sinusoidal encoding)
        second_of_day = torch.from_numpy(traj['second_of_day'][sl]).float()

        # Delta-t between points
        dt = torch.from_numpy(traj['dt'][sl]).float()

        # Prompt tokens (fixed for the prediction task)
        task_token = self.prompt_tokens.PREDICT if self.task == 'predict' else self.prompt_tokens.CLASSIFY
        prompt = torch.tensor([
            self.prompt_tokens.BOS,
            task_token,
            self.prompt_tokens.AIRCRAFT_UNKNOWN,  # default; override with metadata
            self.prompt_tokens.PHASE_UNKNOWN,
            self.prompt_tokens.REGION_OTHER,
        ], dtype=torch.long)

        # Continuous ENU positions (for evaluation / regression head)
        east = torch.from_numpy(traj['east'][sl]).float()
        north = torch.from_numpy(traj['north'][sl]).float()
        up = torch.from_numpy(traj['up'][sl]).float()

        return {
            'geohash_bits': geohash_bits,
            'cog_bins': cog_bins,
            'sog_bins': sog_bins,
            'rot_bins': rot_bins,
            'alt_rate_bins': alt_rate_bins,
            'uncert_bins': uncert_bins,
            'uncert_bins_multi': uncert_bins_multi,
            'hour': hour,
            'dow': dow,
            'month': month,
            'second_of_day': second_of_day,
            'dt': dt,
            'prompt': prompt,
            'east': east,
            'north': north,
            'up': up,
        }

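
# Hedged usage sketch: batching windows with a DataLoader. This assumes every
# trajectory has at least seq_len + 1 points; shorter trajectories produce
# variable-length windows (see __init__ above) that the default collate
# function cannot stack.
def _demo_dataloader(dataset: 'AirTrackDataset'):
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    # geohash_bits: (batch, seq_len + 1, 120); bin tensors: (batch, seq_len + 1)
    print(batch['geohash_bits'].shape, batch['cog_bins'].shape)
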

# ============================================================
# 10. Data Loading Utilities
# ============================================================

def load_traffic_sample(name: str = 'quickstart') -> List[Dict]:
    """
    Load sample data from the `traffic` library.

    Available collections: 'quickstart' (238 flights), 'switzerland', 'savan'.
    Individual flights: 'landing_denver', calibration flights, etc.
    """
    import pandas as pd
    import traffic.data.samples as samples

    data = getattr(samples, name)
    trajectories = []

    # Handle both Traffic (collection) and Flight (single) objects
    flights = data if hasattr(data, '__iter__') else [data]

    for flight in flights:
        df = flight.data

        if df is None or len(df) < 20:
            continue

        # Extract required columns — handle tz-aware and PyArrow timestamps
        ts_series = pd.to_datetime(df['timestamp'])
        if ts_series.dt.tz is not None:
            ts_series = ts_series.dt.tz_convert('UTC').dt.tz_localize(None)
        timestamps = ts_series.values.astype('int64').astype(np.float64) / 1e9
        lats = df['latitude'].values.astype(np.float64)
        lons = df['longitude'].values.astype(np.float64)

        # Altitude: prefer the barometric 'altitude' column, fall back to 'baro_altitude'
        if 'altitude' in df.columns:
            alts = df['altitude'].values.astype(np.float64)
        elif 'baro_altitude' in df.columns:
            alts = df['baro_altitude'].values.astype(np.float64)
        else:
            alts = np.zeros(len(df))

        # Drop rows with NaNs in any required column
        valid = ~(np.isnan(lats) | np.isnan(lons) | np.isnan(alts) | np.isnan(timestamps))
        if valid.sum() < 20:
            continue

        trajectories.append({
            'timestamps': timestamps[valid],
            'lats': lats[valid],
            'lons': lons[valid],
            'alts': alts[valid],
            'callsign': flight.callsign if hasattr(flight, 'callsign') else 'UNKNOWN',
            'icao24': flight.icao24 if hasattr(flight, 'icao24') else 'UNKNOWN',
        })

    return trajectories


def build_dataset(
    raw_trajectories: List[Dict],
    processor: TrajectoryProcessor,
    seq_len: int = 128,
    stride: int = 64,
    fit_geohash: bool = True,
) -> AirTrackDataset:
    """
    Process raw trajectories and build the PyTorch dataset.

    Args:
        raw_trajectories: list of dicts with 'timestamps', 'lats', 'lons', 'alts'
        processor: TrajectoryProcessor instance
        seq_len: sliding window size
        stride: sliding window stride
        fit_geohash: if True, fit geohash bounds from this data
    """
    # First pass: convert to ENU and collect bounds for geohash fitting
    processed = []
    all_east, all_north, all_up = [], [], []

    for raw in raw_trajectories:
        result = processor.process_trajectory(
            raw['timestamps'], raw['lats'], raw['lons'], raw['alts'],
            metadata={k: v for k, v in raw.items() if k not in ['timestamps', 'lats', 'lons', 'alts']}
        )
        if result is not None:
            processed.append(result)
            all_east.append(result['east'])
            all_north.append(result['north'])
            all_up.append(result['up'])

    if fit_geohash and processed:
        # Fit geohash bounds from all trajectories
        all_e = np.concatenate(all_east)
        all_n = np.concatenate(all_north)
        all_u = np.concatenate(all_up)
        processor.fit_geohash(all_e, all_n, all_u)

        # Second pass: re-encode geohash with the fitted bounds
        for traj in processed:
            traj['geohash_bits'] = processor.geohash_encoder.encode(
                traj['east'], traj['north'], traj['up']
            )

    print(f"Processed {len(processed)}/{len(raw_trajectories)} trajectories")

    dataset = AirTrackDataset(processed, seq_len=seq_len, stride=stride)
    print(f"Created dataset with {len(dataset)} windows")

    return dataset


if __name__ == '__main__':
    # Quick test with traffic sample data
    print("Loading traffic sample data...")
    raw_trajs = load_traffic_sample()
    print(f"Loaded {len(raw_trajs)} raw trajectories")

    print("\nProcessing trajectories...")
    processor = TrajectoryProcessor(resample_dt=5.0)
    dataset = build_dataset(raw_trajs, processor, seq_len=64, stride=32)

    print(f"\nDataset size: {len(dataset)}")
    if len(dataset) > 0:
        sample = dataset[0]
        print("\nSample keys and shapes:")
        for k, v in sample.items():
            if isinstance(v, torch.Tensor):
                print(f"  {k}: {v.shape} ({v.dtype})")
            else:
                print(f"  {k}: {type(v)}")