Jdice27 committed
Commit 744a6a7 · verified · 1 Parent(s): 1773f5e

Upload data_pipeline.py with huggingface_hub

Files changed (1)
  1. data_pipeline.py +171 -582
data_pipeline.py CHANGED
@@ -6,13 +6,21 @@ Converts raw ADS-B (lat, lon, alt, timestamp) to model-ready tensors.
6
  Pipeline:
7
  1. Load trajectories from `traffic` library or raw CSV
8
  2. Resample to fixed time interval (default 5s)
9
- 3. Convert lat/lon/alt to ENU (East-North-Up) coordinates using first point as origin
10
  4. Compute velocity via 3-point central derivative on ENU positions
11
  5. Derive COG, SOG from x-y ground velocity; ROT from COG; altitude rate from z velocity
12
- 6. Binary geohash encoding (40-bit per axis, following LLM4STP)
13
  7. Discretize features into bins
14
  8. Compute uncertainty scores
15
  9. Build sliding-window PyTorch Dataset
 
 
16
  """
17
 
18
  import numpy as np
@@ -29,10 +37,8 @@ from dataclasses import dataclass, field
29
 
30
  class ENUConverter:
31
  """
32
- Convert WGS84 (lat, lon, alt) to local East-North-Up (ENU) coordinates.
33
- Origin is set to the first point of each trajectory.
34
-
35
- Uses pyproj for geodetically correct transformations.
36
  """
37
 
38
  def __init__(self, origin_lat: float, origin_lon: float, origin_alt: float = 0.0):
@@ -40,18 +46,15 @@ class ENUConverter:
40
  self.origin_lon = origin_lon
41
  self.origin_alt = origin_alt
42
 
43
- # ECEF transformer
44
  self.ecef = pyproj.Proj(proj='geocent', ellps='WGS84', datum='WGS84')
45
  self.lla = pyproj.Proj(proj='latlong', ellps='WGS84', datum='WGS84')
46
  self.transformer_to_ecef = pyproj.Transformer.from_proj(self.lla, self.ecef, always_xy=True)
47
  self.transformer_to_lla = pyproj.Transformer.from_proj(self.ecef, self.lla, always_xy=True)
48
 
49
- # Origin in ECEF
50
  self.x0, self.y0, self.z0 = self.transformer_to_ecef.transform(
51
  origin_lon, origin_lat, origin_alt
52
  )
53
 
54
- # Rotation matrix (ECEF -> ENU)
55
  lat_r = np.radians(origin_lat)
56
  lon_r = np.radians(origin_lon)
57
  self.R = np.array([
@@ -60,108 +63,50 @@ class ENUConverter:
60
  [ np.cos(lat_r)*np.cos(lon_r), np.cos(lat_r)*np.sin(lon_r), np.sin(lat_r)]
61
  ])
62
 
63
- def to_enu(self, lats: np.ndarray, lons: np.ndarray, alts: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
64
- """Convert arrays of lat/lon/alt to ENU (meters)."""
65
- # To ECEF
66
  x, y, z = self.transformer_to_ecef.transform(lons, lats, alts)
67
-
68
- # Offset from origin
69
  dx = x - self.x0
70
  dy = y - self.y0
71
  dz = z - self.z0
72
-
73
- # Rotate to ENU
74
- ecef_delta = np.stack([dx, dy, dz], axis=0) # (3, N)
75
- enu = self.R @ ecef_delta # (3, N)
76
-
77
- east = enu[0] # meters
78
- north = enu[1] # meters
79
- up = enu[2] # meters
80
-
81
- return east, north, up
82
 
83
- def from_enu(self, east: np.ndarray, north: np.ndarray, up: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
84
- """Convert ENU back to lat/lon/alt."""
85
  enu = np.stack([east, north, up], axis=0)
86
  ecef_delta = self.R.T @ enu
87
-
88
  x = ecef_delta[0] + self.x0
89
  y = ecef_delta[1] + self.y0
90
  z = ecef_delta[2] + self.z0
91
-
92
  lons, lats, alts = self.transformer_to_lla.transform(x, y, z)
93
  return lats, lons, alts
94
 
95
 
96
  # ============================================================
97
- # 2. Three-Point Central Derivative
98
  # ============================================================
99
 
100
- def three_point_derivative(values: np.ndarray, dt: np.ndarray) -> np.ndarray:
101
  """
102
- Compute derivative using 3-point central difference.
103
-
104
- For interior points (i=1..N-2):
105
- f'(i) = (f(i+1) - f(i-1)) / (t(i+1) - t(i-1))
106
-
107
- For endpoints:
108
- f'(0) = (f(1) - f(0)) / (t(1) - t(0)) # forward difference
109
- f'(N-1) = (f(N-1) - f(N-2)) / (t(N-1) - t(N-2)) # backward difference
110
-
111
- Args:
112
- values: shape (N,) — the signal to differentiate
113
- dt: shape (N,) — cumulative time from start (seconds)
114
-
115
- Returns:
116
- derivative: shape (N,) — rate of change per second
117
  """
118
  N = len(values)
119
  deriv = np.zeros(N)
120
-
121
  if N < 2:
122
  return deriv
123
 
124
- # Forward difference for first point
125
- dt_fwd = dt[1] - dt[0]
126
  if dt_fwd > 0:
127
  deriv[0] = (values[1] - values[0]) / dt_fwd
128
 
129
- # Central difference for interior points
130
- for i in range(1, N - 1):
131
- dt_span = dt[i + 1] - dt[i - 1]
132
- if dt_span > 0:
133
- deriv[i] = (values[i + 1] - values[i - 1]) / dt_span
134
-
135
- # Backward difference for last point
136
- dt_bwd = dt[-1] - dt[-2]
137
- if dt_bwd > 0:
138
- deriv[-1] = (values[-1] - values[-2]) / dt_bwd
139
-
140
- return deriv
141
-
142
-
143
- def three_point_derivative_vectorized(values: np.ndarray, dt: np.ndarray) -> np.ndarray:
144
- """Vectorized version of 3-point central derivative."""
145
- N = len(values)
146
- deriv = np.zeros(N)
147
-
148
- if N < 2:
149
- return deriv
150
-
151
- # Forward difference for first point
152
- dt_fwd = dt[1] - dt[0]
153
- if dt_fwd > 0:
154
- deriv[0] = (values[1] - values[0]) / dt_fwd
155
-
156
- # Central difference for interior points (vectorized)
157
  if N > 2:
158
- dt_span = dt[2:] - dt[:-2] # (N-2,)
159
  mask = dt_span > 0
160
- val_diff = values[2:] - values[:-2] # (N-2,)
161
  deriv[1:-1] = np.where(mask, val_diff / np.maximum(dt_span, 1e-10), 0.0)
162
 
163
- # Backward difference for last point
164
- dt_bwd = dt[-1] - dt[-2]
165
  if dt_bwd > 0:
166
  deriv[-1] = (values[-1] - values[-2]) / dt_bwd
167
 
@@ -172,172 +117,115 @@ def three_point_derivative_vectorized(values: np.ndarray, dt: np.ndarray) -> np.
172
  # 3. Feature Derivation from ENU positions
173
  # ============================================================
174
 
175
- def derive_features_enu(
176
- east: np.ndarray,
177
- north: np.ndarray,
178
- up: np.ndarray,
179
- timestamps: np.ndarray
180
- ) -> Dict[str, np.ndarray]:
181
  """
182
- Derive COG, SOG, ROT, and altitude rate from ENU positions
183
- using 3-point central derivatives.
184
-
185
- Args:
186
- east, north, up: ENU coordinates in meters, shape (N,)
187
- timestamps: Unix timestamps in seconds, shape (N,)
188
 
189
- Returns:
190
- dict with keys: 'vx', 'vy', 'vz', 'COG', 'SOG', 'ROT', 'alt_rate'
191
- Each is shape (N,)
 
192
  """
193
- # Cumulative time from start
194
  t = timestamps - timestamps[0]
195
 
196
- # 3-point derivative on ENU positions → velocities (m/s)
197
- vx = three_point_derivative_vectorized(east, t) # East velocity (m/s)
198
- vy = three_point_derivative_vectorized(north, t) # North velocity (m/s)
199
- vz = three_point_derivative_vectorized(up, t) # Up velocity (m/s)
200
 
201
- # SOG from ground plane velocity: sqrt(vx² + vy²), convert m/s → knots
202
  sog_ms = np.sqrt(vx**2 + vy**2)
203
- sog_knots = sog_ms * 1.94384 # m/s to knots
204
 
205
- # COG from ground velocity components: atan2(vx, vy) degrees from North
206
- # atan2(East, North) gives bearing from North, clockwise
207
  cog_deg = np.degrees(np.arctan2(vx, vy)) % 360
208
 
209
- # ROT: derivative of COG (degrees/second)
210
- # Need to handle circular wraparound — unwrap COG first
211
  cog_unwrapped = np.unwrap(np.radians(cog_deg))
212
- rot_rad_s = three_point_derivative_vectorized(cog_unwrapped, t)
213
  rot_deg_s = np.degrees(rot_rad_s)
214
 
215
- # Altitude rate: vz converted to ft/min
216
- alt_rate_ftmin = vz * 196.85 # m/s → ft/min (1 m/s = 196.85 ft/min)
217
 
218
  return {
219
- 'vx': vx,
220
- 'vy': vy,
221
- 'vz': vz,
222
- 'COG': cog_deg,
223
- 'SOG': sog_knots,
224
- 'ROT': rot_deg_s,
225
- 'alt_rate': alt_rate_ftmin,
226
  }
227
 
228
 
229
  # ============================================================
230
- # 4. Binary Geohash Encoding (following LLM4STP, 40-bit precision)
231
  # ============================================================
232
 
233
- def binary_geohash_encode(
234
- values: np.ndarray,
235
- precision: int = 40,
236
- v_min: float = 0.0,
237
- v_max: float = 1.0
238
- ) -> np.ndarray:
239
- """
240
- Encode normalized values as binary geohash via successive bisection.
241
- Matches LLM4STP's num2bits() implementation.
242
-
243
- Args:
244
- values: shape (N,) — normalized to [v_min, v_max]
245
- precision: number of bits
246
- v_min, v_max: range bounds
247
-
248
- Returns:
249
- bits: shape (N, precision) — binary encoding (0/1)
250
- """
251
  N = len(values)
252
  bits = np.zeros((N, precision), dtype=np.int64)
253
-
254
  _min = np.full(N, v_min)
255
  _max = np.full(N, v_max)
256
-
257
  for p in range(precision):
258
  mid = (_min + _max) / 2
259
  mask = values > mid
260
  bits[:, p] = mask.astype(np.int64)
261
  _min = np.where(mask, mid, _min)
262
  _max = np.where(mask, _max, mid)
263
-
264
  return bits
265
 
266
 
267
- def binary_geohash_decode(
268
- bits: np.ndarray,
269
- precision: int = 40,
270
- v_min: float = 0.0,
271
- v_max: float = 1.0
272
- ) -> np.ndarray:
273
- """Decode binary geohash back to values."""
274
  N = bits.shape[0]
275
  _min = np.full(N, v_min)
276
  _max = np.full(N, v_max)
277
-
278
  for p in range(precision):
279
  mid = (_min + _max) / 2
280
  mask = bits[:, p].astype(bool)
281
  _min = np.where(mask, mid, _min)
282
  _max = np.where(mask, _max, mid)
283
-
284
  return (_min + _max) / 2
285
 
286
 
287
  class GeohashEncoder:
288
  """
289
- 3D geohash encoder for aviation.
290
- Encodes (east, north, up) ENU coordinates as binary geohash.
291
 
292
- Following LLM4STP: 40 bits per axis, coordinates normalized to [0,1].
293
- Extended to 3 axes (E, N, U) for aviation 3D trajectory.
294
-
295
- Total encoding: 40*3 = 120 bits per timestep.
296
  """
297
 
298
- def __init__(self, precision: int = 40):
299
  self.precision = precision
300
- # Normalization bounds set from training data
301
- self.e_min = None
302
- self.e_max = None
303
- self.n_min = None
304
- self.n_max = None
305
- self.u_min = None
306
- self.u_max = None
307
-
308
- def fit(self, east: np.ndarray, north: np.ndarray, up: np.ndarray, margin: float = 0.05):
309
- """Set normalization bounds from training data with margin."""
310
- e_range = east.max() - east.min()
311
- n_range = north.max() - north.min()
312
- u_range = up.max() - up.min()
313
-
314
- self.e_min = east.min() - margin * max(e_range, 1.0)
315
- self.e_max = east.max() + margin * max(e_range, 1.0)
316
- self.n_min = north.min() - margin * max(n_range, 1.0)
317
- self.n_max = north.max() + margin * max(n_range, 1.0)
318
- self.u_min = up.min() - margin * max(u_range, 1.0)
319
- self.u_max = up.max() + margin * max(u_range, 1.0)
320
-
321
- def normalize(self, values: np.ndarray, v_min: float, v_max: float) -> np.ndarray:
322
- """Normalize to [0, 1]."""
323
  return np.clip((values - v_min) / max(v_max - v_min, 1e-10), 0.0, 1.0)
324
 
325
- def encode(self, east: np.ndarray, north: np.ndarray, up: np.ndarray) -> np.ndarray:
326
- """
327
- Encode ENU positions as 3D binary geohash.
328
-
329
- Returns:
330
- bits: shape (N, precision*3) — concatenated [E_bits | N_bits | U_bits]
331
- """
332
- e_norm = self.normalize(east, self.e_min, self.e_max)
333
- n_norm = self.normalize(north, self.n_min, self.n_max)
334
- u_norm = self.normalize(up, self.u_min, self.u_max)
335
-
336
  e_bits = binary_geohash_encode(e_norm, self.precision)
337
  n_bits = binary_geohash_encode(n_norm, self.precision)
338
  u_bits = binary_geohash_encode(u_norm, self.precision)
339
-
340
  return np.concatenate([e_bits, n_bits, u_bits], axis=1) # (N, 120)
 
 
 
 
 
 
 
341
 
342
 
343
  # ============================================================
@@ -346,413 +234,187 @@ class GeohashEncoder:
346
 
347
  @dataclass
348
  class FeatureBins:
349
- """Configuration for discretizing continuous features into bins."""
350
-
351
- # COG: [0, 360) degrees, 2° bins for high resolution
352
- cog_edges: np.ndarray = field(default_factory=lambda: np.linspace(0, 360, 181)) # 180 bins
353
-
354
- # SOG: [0, 600] knots, 2-knot bins
355
- sog_edges: np.ndarray = field(default_factory=lambda: np.linspace(0, 600, 301)) # 300 bins
356
-
357
- # ROT: [-6, 6] deg/s, 0.1 deg/s bins
358
- rot_edges: np.ndarray = field(default_factory=lambda: np.linspace(-6, 6, 121)) # 120 bins
359
-
360
- # Altitude rate: [-6000, 6000] ft/min, 100 ft/min bins
361
  alt_rate_edges: np.ndarray = field(default_factory=lambda: np.linspace(-6000, 6000, 121)) # 120 bins
362
 
363
  @property
364
  def n_cog_bins(self): return len(self.cog_edges) - 1
365
-
366
  @property
367
  def n_sog_bins(self): return len(self.sog_edges) - 1
368
-
369
  @property
370
  def n_rot_bins(self): return len(self.rot_edges) - 1
371
-
372
  @property
373
  def n_alt_rate_bins(self): return len(self.alt_rate_edges) - 1
374
 
375
- def digitize(self, values: np.ndarray, edges: np.ndarray) -> np.ndarray:
376
- """Bin values. Returns indices in [0, n_bins-1], clipped."""
377
- indices = np.digitize(values, edges) - 1
378
- return np.clip(indices, 0, len(edges) - 2)
379
 
380
- def encode_cog(self, cog: np.ndarray) -> np.ndarray:
381
- return self.digitize(cog, self.cog_edges)
382
-
383
- def encode_sog(self, sog: np.ndarray) -> np.ndarray:
384
- return self.digitize(sog, self.sog_edges)
385
-
386
- def encode_rot(self, rot: np.ndarray) -> np.ndarray:
387
- rot_clipped = np.clip(rot, -6, 6)
388
- return self.digitize(rot_clipped, self.rot_edges)
389
-
390
- def encode_alt_rate(self, alt_rate: np.ndarray) -> np.ndarray:
391
- ar_clipped = np.clip(alt_rate, -6000, 6000)
392
- return self.digitize(ar_clipped, self.alt_rate_edges)
393
 
394
 
395
  # ============================================================
396
- # 6. Uncertainty Score
397
  # ============================================================
398
 
399
- def compute_uncertainty(
400
- cog: np.ndarray,
401
- sog: np.ndarray,
402
- rot: np.ndarray,
403
- alt_rate: np.ndarray,
404
- window: int = 5
405
- ) -> np.ndarray:
406
- """
407
- Compute trajectory uncertainty score from recent state variance.
408
- High variance = high uncertainty (maneuvering aircraft).
409
-
410
- Returns:
411
- scores: shape (N,) — uncertainty scores (higher = more uncertain)
412
- """
413
- N = len(cog)
414
- scores = np.zeros(N)
415
-
416
- for i in range(N):
417
- start = max(0, i - window + 1)
418
- w = slice(start, i + 1)
419
-
420
- # Circular variance for COG
421
- cog_rad = np.radians(cog[w])
422
- R_len = np.sqrt(np.mean(np.cos(cog_rad))**2 + np.mean(np.sin(cog_rad))**2)
423
- cog_var = 1 - R_len # circular variance [0, 1]
424
-
425
- # Regular variance for others
426
- sog_var = np.var(sog[w]) if len(sog[w]) > 1 else 0
427
- rot_var = np.var(rot[w]) if len(rot[w]) > 1 else 0
428
- alt_var = np.var(alt_rate[w]) if len(alt_rate[w]) > 1 else 0
429
-
430
- # Normalize and combine (equal weights)
431
- scores[i] = cog_var + sog_var / max(np.var(sog) + 1e-10, 1e-10) + \
432
- rot_var / max(np.var(rot) + 1e-10, 1e-10) + \
433
- alt_var / max(np.var(alt_rate) + 1e-10, 1e-10)
434
-
435
- return scores
436
-
437
-
438
- def discretize_uncertainty(scores: np.ndarray, n_bins: int = 16) -> np.ndarray:
439
- """Discretize uncertainty scores into quantile bins."""
440
- if len(np.unique(scores)) < n_bins:
441
- # Not enough unique values for quantile binning
442
- edges = np.linspace(scores.min(), scores.max() + 1e-10, n_bins + 1)
443
- else:
444
- edges = np.quantile(scores, np.linspace(0, 1, n_bins + 1))
445
- edges[-1] += 1e-10 # ensure max value is included
446
-
447
- return np.clip(np.digitize(scores, edges) - 1, 0, n_bins - 1)
448
-
449
-
450
- # ============================================================
451
- # 7. Temporal Features
452
- # ============================================================
453
-
454
- def extract_temporal_features(timestamps: np.ndarray) -> Dict[str, np.ndarray]:
455
- """
456
- Extract temporal features from Unix timestamps.
457
-
458
- Returns dict with:
459
- 'second_of_day': float seconds within the day [0, 86400)
460
- 'hour': int hour of day [0, 23]
461
- 'dow': int day of week [0, 6]
462
- 'month': int month [0, 11]
463
- 'dt': float seconds since previous point (0 for first)
464
- 'fractional_second': float sub-second component [0, 1)
465
- """
466
  import datetime
467
 
468
- # Convert to datetime objects for calendar features
469
  dts = [datetime.datetime.utcfromtimestamp(t) for t in timestamps]
470
-
471
  hours = np.array([d.hour for d in dts], dtype=np.int64)
472
  dows = np.array([d.weekday() for d in dts], dtype=np.int64)
473
- months = np.array([d.month - 1 for d in dts], dtype=np.int64) # 0-indexed
474
 
475
- # Second of day (with fractional seconds)
476
  second_of_day = np.array([
477
  d.hour * 3600 + d.minute * 60 + d.second + d.microsecond / 1e6
478
  for d in dts
479
- ])
480
 
481
- # Delta-t between consecutive points
482
- dt = np.zeros(len(timestamps))
483
  dt[1:] = np.diff(timestamps)
484
 
485
- # Fractional second component
486
- fractional_second = timestamps - np.floor(timestamps)
487
-
488
  return {
489
  'second_of_day': second_of_day,
490
- 'hour': hours,
491
- 'dow': dows,
492
- 'month': months,
493
  'dt': dt,
494
- 'fractional_second': fractional_second,
495
  }
496
 
497
 
498
  # ============================================================
499
- # 8. Full Trajectory Processor
500
  # ============================================================
501
 
502
  class TrajectoryProcessor:
503
- """
504
- Complete pipeline: raw ADS-B → model-ready features.
505
- """
506
-
507
- def __init__(
508
- self,
509
- resample_dt: float = 5.0, # resample interval in seconds
510
- geohash_precision: int = 40, # bits per axis
511
- n_uncertainty_bins: int = 16,
512
- feature_bins: Optional[FeatureBins] = None,
513
- min_trajectory_len: int = 20, # minimum points after processing
514
- ):
515
  self.resample_dt = resample_dt
516
  self.geohash_precision = geohash_precision
517
  self.n_uncertainty_bins = n_uncertainty_bins
518
  self.feature_bins = feature_bins or FeatureBins()
519
  self.min_trajectory_len = min_trajectory_len
520
  self.geohash_encoder = GeohashEncoder(precision=geohash_precision)
521
-
522
- # Fit state
523
  self._fitted = False
524
 
525
- def resample_trajectory(
526
- self, timestamps: np.ndarray, lats: np.ndarray, lons: np.ndarray, alts: np.ndarray
527
- ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
528
- """Resample trajectory to fixed time intervals via linear interpolation."""
529
- t_start = timestamps[0]
530
- t_end = timestamps[-1]
531
-
532
  n_points = int((t_end - t_start) / self.resample_dt) + 1
 
 
533
  t_new = np.linspace(t_start, t_start + (n_points - 1) * self.resample_dt, n_points)
534
-
535
- lats_new = np.interp(t_new, timestamps, lats)
536
- lons_new = np.interp(t_new, timestamps, lons)
537
- alts_new = np.interp(t_new, timestamps, alts)
538
-
539
- return t_new, lats_new, lons_new, alts_new
540
-
541
- def process_trajectory(
542
- self,
543
- timestamps: np.ndarray,
544
- lats: np.ndarray,
545
- lons: np.ndarray,
546
- alts: np.ndarray,
547
- metadata: Optional[Dict] = None
548
- ) -> Optional[Dict[str, np.ndarray]]:
549
- """
550
- Process a single trajectory from raw ADS-B to model features.
551
-
552
- Returns None if trajectory is too short or invalid.
553
- Returns dict with all features needed for the model.
554
- """
555
- # Sort by time
556
  sort_idx = np.argsort(timestamps)
557
- timestamps = timestamps[sort_idx]
558
- lats = lats[sort_idx]
559
- lons = lons[sort_idx]
560
- alts = alts[sort_idx]
561
 
562
- # Resample to fixed interval
563
  timestamps, lats, lons, alts = self.resample_trajectory(timestamps, lats, lons, alts)
564
-
565
  if len(timestamps) < self.min_trajectory_len:
566
  return None
567
 
568
- # Convert to ENU (origin = first point)
569
  converter = ENUConverter(lats[0], lons[0], alts[0])
570
  east, north, up = converter.to_enu(lats, lons, alts)
571
 
572
- # Derive features via 3-point derivative on ENU
573
  features = derive_features_enu(east, north, up, timestamps)
574
 
575
- # Binary geohash encoding on ENU positions
576
- # If encoder not yet fitted, store placeholder (will be re-encoded after fitting)
577
  if self.geohash_encoder.e_min is not None:
578
- geohash_bits = self.geohash_encoder.encode(east, north, up) # (N, 120)
579
  else:
580
  geohash_bits = np.zeros((len(east), self.geohash_precision * 3), dtype=np.int64)
581
 
582
- # Discretize kinematic features
583
  cog_bins = self.feature_bins.encode_cog(features['COG'])
584
  sog_bins = self.feature_bins.encode_sog(features['SOG'])
585
  rot_bins = self.feature_bins.encode_rot(features['ROT'])
586
  alt_rate_bins = self.feature_bins.encode_alt_rate(features['alt_rate'])
587
 
588
- # Uncertainty — multiple methods
589
- from uncertainty import (
590
- compute_all_uncertainties, discretize_scores, UncertaintyConfig
591
- )
592
- uncert_config = UncertaintyConfig(
593
- use_kinematic_variance=True,
594
- use_prediction_residual=True,
595
- use_spatial_density=True,
596
- use_flight_phase_entropy=True,
597
- use_temporal_irregularity=False,
598
- n_bins=self.n_uncertainty_bins,
599
- window=5,
600
- )
601
  raw_uncert = compute_all_uncertainties(
602
  east, north, up, timestamps,
603
  features['COG'], features['SOG'], features['ROT'], features['alt_rate'],
604
  config=uncert_config,
605
  )
606
- # Discretize each method into bins → stack into (N, n_methods) array
607
  uncert_methods = sorted(raw_uncert.keys())
608
  uncert_bins_multi = np.stack([
609
  discretize_scores(raw_uncert[m], self.n_uncertainty_bins)
610
  for m in uncert_methods
611
- ], axis=1) # (N, n_methods)
612
-
613
- # Also keep legacy single-method for backwards compat
614
- if 'kinematic_var' in raw_uncert:
615
- uncert_bins = discretize_scores(raw_uncert['kinematic_var'], self.n_uncertainty_bins)
616
- else:
617
- uncert_bins = uncert_bins_multi[:, 0]
618
 
619
- # Temporal features
620
  temporal = extract_temporal_features(timestamps)
621
 
622
  return {
623
- # Raw (for evaluation/debugging)
624
- 'timestamps': timestamps,
625
- 'lats': lats,
626
- 'lons': lons,
627
- 'alts': alts,
628
- 'east': east,
629
- 'north': north,
630
- 'up': up,
631
-
632
- # Continuous features
633
- 'COG': features['COG'],
634
- 'SOG': features['SOG'],
635
- 'ROT': features['ROT'],
636
- 'alt_rate': features['alt_rate'],
637
- 'vx': features['vx'],
638
- 'vy': features['vy'],
639
- 'vz': features['vz'],
640
-
641
- # Geohash (binary, 120 bits per timestep)
642
  'geohash_bits': geohash_bits,
643
-
644
- # Discretized features (bin indices)
645
- 'cog_bins': cog_bins,
646
- 'sog_bins': sog_bins,
647
- 'rot_bins': rot_bins,
648
- 'alt_rate_bins': alt_rate_bins,
649
-
650
- # Uncertainty (bin indices)
651
- 'uncert_bins': uncert_bins, # (N,) legacy single method
652
- 'uncert_bins_multi': uncert_bins_multi, # (N, n_methods) multi-method
653
- 'uncert_method_names': uncert_methods, # list of method names
654
-
655
- # Temporal
656
- 'hour': temporal['hour'],
657
- 'dow': temporal['dow'],
658
- 'month': temporal['month'],
659
- 'second_of_day': temporal['second_of_day'],
660
- 'dt': temporal['dt'],
661
-
662
- # ENU converter (for decoding predictions back to lat/lon)
663
  'enu_origin': (converter.origin_lat, converter.origin_lon, converter.origin_alt),
664
-
665
- # Metadata
666
  'metadata': metadata or {},
667
  }
668
 
669
- def fit_geohash(self, all_east: np.ndarray, all_north: np.ndarray, all_up: np.ndarray):
670
- """Fit geohash normalization bounds from all training trajectories."""
671
  self.geohash_encoder.fit(all_east, all_north, all_up)
672
  self._fitted = True
673
 
674
 
675
  # ============================================================
676
- # 9. PyTorch Dataset with Sliding Window
677
  # ============================================================
678
 
679
  @dataclass
680
  class PromptTokens:
681
- """Prompt token IDs for metadata encoding."""
682
- # Special tokens
683
- BOS: int = 0
684
- EOS: int = 1
685
- PAD: int = 2
686
-
687
- # Task tokens
688
- PREDICT: int = 3
689
- CLASSIFY: int = 4
690
- DETECT_ANOMALY: int = 5
691
-
692
- # Aircraft category
693
- HEAVY: int = 6
694
- LARGE: int = 7
695
- SMALL: int = 8
696
- ROTORCRAFT: int = 9
697
- GLIDER: int = 10
698
- UAV: int = 11
699
- AIRCRAFT_UNKNOWN: int = 12
700
-
701
- # Flight phase
702
- CLIMB: int = 13
703
- CRUISE: int = 14
704
- DESCENT: int = 15
705
- APPROACH: int = 16
706
- GROUND: int = 17
707
- PHASE_UNKNOWN: int = 18
708
-
709
- # Region
710
- CONUS: int = 19
711
- EUROPE: int = 20
712
- ASIA: int = 21
713
- REGION_OTHER: int = 22
714
-
715
  VOCAB_SIZE: int = 23
716
 
717
 
 
 
 
 
718
  class AirTrackDataset(Dataset):
719
- """
720
- Sliding-window dataset for next-state prediction.
721
-
722
- Each sample is a window of `seq_len` consecutive states.
723
- The model predicts state[t+1] from state[1:t] for all t.
724
- """
725
-
726
- def __init__(
727
- self,
728
- trajectories: List[Dict[str, np.ndarray]],
729
- seq_len: int = 128,
730
- stride: int = 64,
731
- task: str = 'predict', # 'predict' or 'classify'
732
- ):
733
  self.seq_len = seq_len
734
  self.stride = stride
735
  self.task = task
736
-
737
- # Build index of (trajectory_idx, start_pos) for all valid windows
738
  self.windows = []
739
  self.trajectories = trajectories
740
 
741
  for traj_idx, traj in enumerate(trajectories):
742
  n_points = len(traj['timestamps'])
743
- # Need seq_len + 1 points (seq_len inputs + 1 target for last position)
744
  if n_points < seq_len + 1:
745
- # Use entire trajectory if it's at least min length
746
  if n_points >= 20:
747
  self.windows.append((traj_idx, 0, n_points))
748
  continue
749
-
750
  for start in range(0, n_points - seq_len, stride):
751
- end = start + seq_len + 1 # +1 for next-state target
752
  if end <= n_points:
753
  self.windows.append((traj_idx, start, end))
754
 
755
- # Prompt tokens
756
  self.prompt_tokens = PromptTokens()
757
 
758
  def __len__(self):
@@ -761,99 +423,56 @@ class AirTrackDataset(Dataset):
761
  def __getitem__(self, idx):
762
  traj_idx, start, end = self.windows[idx]
763
  traj = self.trajectories[traj_idx]
764
-
765
- # Slice the window
766
  sl = slice(start, end)
767
 
768
- # Geohash bits: (window_len, 120)
769
- geohash_bits = torch.from_numpy(traj['geohash_bits'][sl]).float()
770
-
771
- # Discretized features
772
- cog_bins = torch.from_numpy(traj['cog_bins'][sl]).long()
773
- sog_bins = torch.from_numpy(traj['sog_bins'][sl]).long()
774
- rot_bins = torch.from_numpy(traj['rot_bins'][sl]).long()
775
- alt_rate_bins = torch.from_numpy(traj['alt_rate_bins'][sl]).long()
776
-
777
- # Uncertainty bins (single + multi)
778
- uncert_bins = torch.from_numpy(traj['uncert_bins'][sl]).long()
779
- if 'uncert_bins_multi' in traj:
780
- uncert_bins_multi = torch.from_numpy(traj['uncert_bins_multi'][sl]).long()
781
- else:
782
- uncert_bins_multi = uncert_bins.unsqueeze(-1)
783
-
784
- # Temporal features
785
- hour = torch.from_numpy(traj['hour'][sl]).long()
786
- dow = torch.from_numpy(traj['dow'][sl]).long()
787
- month = torch.from_numpy(traj['month'][sl]).long()
788
-
789
- # Second-of-day as continuous feature (for sinusoidal encoding)
790
- second_of_day = torch.from_numpy(traj['second_of_day'][sl]).float()
791
-
792
- # Delta-t between points
793
- dt = torch.from_numpy(traj['dt'][sl]).float()
794
-
795
- # Prompt tokens (fixed for prediction task)
796
  task_token = self.prompt_tokens.PREDICT if self.task == 'predict' else self.prompt_tokens.CLASSIFY
797
  prompt = torch.tensor([
798
- self.prompt_tokens.BOS,
799
- task_token,
800
- self.prompt_tokens.AIRCRAFT_UNKNOWN, # default; override with metadata
801
  self.prompt_tokens.PHASE_UNKNOWN,
802
  self.prompt_tokens.REGION_OTHER,
803
  ], dtype=torch.long)
804
 
805
- # Continuous ENU positions (for evaluation / regression head)
806
- east = torch.from_numpy(traj['east'][sl]).float()
807
- north = torch.from_numpy(traj['north'][sl]).float()
808
- up = torch.from_numpy(traj['up'][sl]).float()
809
-
810
  return {
811
- 'geohash_bits': geohash_bits,
812
- 'cog_bins': cog_bins,
813
- 'sog_bins': sog_bins,
814
- 'rot_bins': rot_bins,
815
- 'alt_rate_bins': alt_rate_bins,
816
- 'uncert_bins': uncert_bins,
817
- 'uncert_bins_multi': uncert_bins_multi,
818
- 'hour': hour,
819
- 'dow': dow,
820
- 'month': month,
821
- 'second_of_day': second_of_day,
822
- 'dt': dt,
823
  'prompt': prompt,
824
- 'east': east,
825
- 'north': north,
826
- 'up': up,
827
  }
828
 
829
 
830
  # ============================================================
831
- # 10. Data Loading Utilities
832
  # ============================================================
833
 
834
- def load_traffic_sample(name: str = 'quickstart') -> List[Dict]:
835
- """
836
- Load sample data from the `traffic` library.
837
-
- """
841
  import traffic.data.samples as samples
 
842
 
843
  data = getattr(samples, name)
844
  trajectories = []
845
-
846
- # Handle both Traffic (collection) and Flight (single) objects
847
  flights = data if hasattr(data, '__iter__') else [data]
848
 
849
  for flight in flights:
850
- df = flight.data
851
-
 
 
852
  if df is None or len(df) < 20:
853
  continue
854
 
855
- # Extract required columns — handle tz-aware and PyArrow timestamps
856
- import pandas as pd
857
  ts_series = pd.to_datetime(df['timestamp'])
858
  if ts_series.dt.tz is not None:
859
  ts_series = ts_series.dt.tz_convert('UTC').dt.tz_localize(None)
@@ -861,7 +480,6 @@ def load_traffic_sample(name: str = 'quickstart') -> List[Dict]:
861
  lats = df['latitude'].values.astype(np.float64)
862
  lons = df['longitude'].values.astype(np.float64)
863
 
864
- # Altitude: try barometric first, then geometric
865
  if 'altitude' in df.columns:
866
  alts = df['altitude'].values.astype(np.float64)
867
  elif 'baro_altitude' in df.columns:
@@ -869,41 +487,21 @@ def load_traffic_sample(name: str = 'quickstart') -> List[Dict]:
869
  else:
870
  alts = np.zeros(len(df))
871
 
872
- # Handle NaNs
873
  valid = ~(np.isnan(lats) | np.isnan(lons) | np.isnan(alts) | np.isnan(timestamps))
874
  if valid.sum() < 20:
875
  continue
876
 
877
  trajectories.append({
878
  'timestamps': timestamps[valid],
879
- 'lats': lats[valid],
880
- 'lons': lons[valid],
881
- 'alts': alts[valid],
882
- 'callsign': flight.callsign if hasattr(flight, 'callsign') else 'UNKNOWN',
883
- 'icao24': flight.icao24 if hasattr(flight, 'icao24') else 'UNKNOWN',
884
  })
885
 
886
  return trajectories
887
 
888
 
889
- def build_dataset(
890
- raw_trajectories: List[Dict],
891
- processor: TrajectoryProcessor,
892
- seq_len: int = 128,
893
- stride: int = 64,
894
- fit_geohash: bool = True,
895
- ) -> AirTrackDataset:
896
- """
897
- Process raw trajectories and build PyTorch dataset.
898
-
899
- Args:
900
- raw_trajectories: list of dicts with 'timestamps', 'lats', 'lons', 'alts'
901
- processor: TrajectoryProcessor instance
902
- seq_len: sliding window size
903
- stride: sliding window stride
904
- fit_geohash: if True, fit geohash bounds from this data
905
- """
906
- # First pass: convert to ENU and collect bounds for geohash fitting
907
  processed = []
908
  all_east, all_north, all_up = [], [], []
909
 
@@ -919,28 +517,22 @@ def build_dataset(
919
  all_up.append(result['up'])
920
 
921
  if fit_geohash and processed:
922
- # Fit geohash bounds from all trajectories
923
  all_e = np.concatenate(all_east)
924
  all_n = np.concatenate(all_north)
925
  all_u = np.concatenate(all_up)
926
  processor.fit_geohash(all_e, all_n, all_u)
927
-
928
- # Re-encode geohash with fitted bounds
929
  for traj in processed:
930
  traj['geohash_bits'] = processor.geohash_encoder.encode(
931
  traj['east'], traj['north'], traj['up']
932
  )
933
 
934
  print(f"Processed {len(processed)}/{len(raw_trajectories)} trajectories")
935
-
936
  dataset = AirTrackDataset(processed, seq_len=seq_len, stride=stride)
937
  print(f"Created dataset with {len(dataset)} windows")
938
-
939
  return dataset
940
 
941
 
942
  if __name__ == '__main__':
943
- # Quick test with traffic sample data
944
  print("Loading traffic sample data...")
945
  raw_trajs = load_traffic_sample()
946
  print(f"Loaded {len(raw_trajs)} raw trajectories")
@@ -949,12 +541,9 @@ if __name__ == '__main__':
949
  processor = TrajectoryProcessor(resample_dt=5.0)
950
  dataset = build_dataset(raw_trajs, processor, seq_len=64, stride=32)
951
 
952
- print(f"\nDataset size: {len(dataset)}")
953
  if len(dataset) > 0:
954
  sample = dataset[0]
955
- print("\nSample keys and shapes:")
956
  for k, v in sample.items():
957
  if isinstance(v, torch.Tensor):
958
  print(f" {k}: {v.shape} ({v.dtype})")
959
- else:
960
- print(f" {k}: {type(v)}")
 
6
  Pipeline:
7
  1. Load trajectories from `traffic` library or raw CSV
8
  2. Resample to fixed time interval (default 5s)
9
+ 3. Convert lat/lon/alt to ENU (East-North-Up) using first lat/lon point as origin
10
  4. Compute velocity via 3-point central derivative on ENU positions
11
  5. Derive COG, SOG from x-y ground velocity; ROT from COG; altitude rate from z velocity
12
+ 6. Binary geohash encoding (40-bit per axis, following LLM4STP approach)
13
  7. Discretize features into bins
14
  8. Compute uncertainty scores
15
  9. Build sliding-window PyTorch Dataset
16
+
17
+ KEY DESIGN CHOICES:
18
+ - Time: sub-second (fractional) resolution via float64 Unix timestamps + sinusoidal encoding
19
+ - Geohash: 40-bit binary per axis with PER-TRAJECTORY normalization for max spatial resolution
20
+ 40 bits → 2^40 ≈ 10^12 levels. For a 500km trajectory range, that's ~0.5μm resolution.
21
+ Tighter than any geohash resolution level.
22
+ - COG/SOG: derived from ENU velocity components via 3-point central derivative (not from raw lat/lon)
23
+ - ENU origin: first (lat, lon) of each trajectory
24
  """
25
 
26
  import numpy as np
 
37
 
38
  class ENUConverter:
39
  """
40
+ Convert WGS84 (lat, lon, alt) to local East-North-Up (ENU).
41
+ Origin = first point of each trajectory.
 
 
42
  """
43
 
44
  def __init__(self, origin_lat: float, origin_lon: float, origin_alt: float = 0.0):
 
46
  self.origin_lon = origin_lon
47
  self.origin_alt = origin_alt
48
 
 
49
  self.ecef = pyproj.Proj(proj='geocent', ellps='WGS84', datum='WGS84')
50
  self.lla = pyproj.Proj(proj='latlong', ellps='WGS84', datum='WGS84')
51
  self.transformer_to_ecef = pyproj.Transformer.from_proj(self.lla, self.ecef, always_xy=True)
52
  self.transformer_to_lla = pyproj.Transformer.from_proj(self.ecef, self.lla, always_xy=True)
53
 
 
54
  self.x0, self.y0, self.z0 = self.transformer_to_ecef.transform(
55
  origin_lon, origin_lat, origin_alt
56
  )
57
 
 
58
  lat_r = np.radians(origin_lat)
59
  lon_r = np.radians(origin_lon)
60
  self.R = np.array([
 
63
  [ np.cos(lat_r)*np.cos(lon_r), np.cos(lat_r)*np.sin(lon_r), np.sin(lat_r)]
64
  ])
65
 
66
+ def to_enu(self, lats, lons, alts):
 
 
67
  x, y, z = self.transformer_to_ecef.transform(lons, lats, alts)
 
 
68
  dx = x - self.x0
69
  dy = y - self.y0
70
  dz = z - self.z0
71
+ ecef_delta = np.stack([dx, dy, dz], axis=0)
72
+ enu = self.R @ ecef_delta
73
+ return enu[0], enu[1], enu[2]
 
 
74
 
75
+ def from_enu(self, east, north, up):
 
76
  enu = np.stack([east, north, up], axis=0)
77
  ecef_delta = self.R.T @ enu
 
78
  x = ecef_delta[0] + self.x0
79
  y = ecef_delta[1] + self.y0
80
  z = ecef_delta[2] + self.z0
 
81
  lons, lats, alts = self.transformer_to_lla.transform(x, y, z)
82
  return lats, lons, alts
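A minimal round-trip sketch of the converter above (origin and test point are made-up values):

    import numpy as np
    conv = ENUConverter(origin_lat=47.45, origin_lon=8.55, origin_alt=430.0)
    e, n, u = conv.to_enu(np.array([47.46]), np.array([8.56]), np.array([900.0]))
    lat, lon, alt = conv.from_enu(e, n, u)
    # lat, lon, alt recover the inputs to within numerical precision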
83
 
84
 
85
  # ============================================================
86
+ # 2. Three-Point Central Derivative (vectorized)
87
  # ============================================================
88
 
89
+ def three_point_derivative(values: np.ndarray, t: np.ndarray) -> np.ndarray:
90
  """
91
+ 3-point central derivative. Interior: (f(i+1) - f(i-1)) / (t(i+1) - t(i-1))
92
+ Endpoints: forward/backward difference.
 
 
 
93
  """
94
  N = len(values)
95
  deriv = np.zeros(N)
 
96
  if N < 2:
97
  return deriv
98
 
99
+ dt_fwd = t[1] - t[0]
 
100
  if dt_fwd > 0:
101
  deriv[0] = (values[1] - values[0]) / dt_fwd
102
 
 
 
103
  if N > 2:
104
+ dt_span = t[2:] - t[:-2]
105
  mask = dt_span > 0
106
+ val_diff = values[2:] - values[:-2]
107
  deriv[1:-1] = np.where(mask, val_diff / np.maximum(dt_span, 1e-10), 0.0)
108
 
109
+ dt_bwd = t[-1] - t[-2]
 
110
  if dt_bwd > 0:
111
  deriv[-1] = (values[-1] - values[-2]) / dt_bwd
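The central difference is exact for quadratics, which gives a handy sanity check (arbitrary sample values):

    import numpy as np
    t = np.array([0.0, 1.0, 2.0, 3.0])
    d = three_point_derivative(t ** 2, t)
    # interior points are exact: d[1] == 2.0, d[2] == 4.0
    # endpoints fall back to one-sided differences: d[0] == 1.0, d[3] == 5.0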
112
 
 
117
  # 3. Feature Derivation from ENU positions
118
  # ============================================================
119
 
120
+ def derive_features_enu(east, north, up, timestamps):
 
 
121
  """
122
+ Derive COG, SOG, ROT, alt_rate from ENU positions using 3-point central derivatives.
 
 
123
 
124
+ COG = atan2(vx_east, vy_north) → bearing from North, clockwise [0, 360)
125
+ SOG = sqrt(vx² + vy²) converted to knots
126
+ ROT = d(COG)/dt via 3-point derivative on unwrapped COG
127
+ alt_rate = vz converted to ft/min
128
  """
 
129
  t = timestamps - timestamps[0]
130
 
131
+ vx = three_point_derivative(east, t) # East velocity m/s
132
+ vy = three_point_derivative(north, t) # North velocity m/s
133
+ vz = three_point_derivative(up, t) # Up velocity m/s
 
134
 
 
135
  sog_ms = np.sqrt(vx**2 + vy**2)
136
+ sog_knots = sog_ms * 1.94384 # m/s → knots
137
 
138
+ # COG: atan2(East, North) gives bearing from North, clockwise
 
139
  cog_deg = np.degrees(np.arctan2(vx, vy)) % 360
140
 
141
+ # ROT: derivative of unwrapped COG
 
142
  cog_unwrapped = np.unwrap(np.radians(cog_deg))
143
+ rot_rad_s = three_point_derivative(cog_unwrapped, t)
144
  rot_deg_s = np.degrees(rot_rad_s)
145
 
146
+ # Altitude rate: m/s → ft/min
147
+ alt_rate_ftmin = vz * 196.85
148
 
149
  return {
150
+ 'vx': vx, 'vy': vy, 'vz': vz,
151
+ 'COG': cog_deg, 'SOG': sog_knots,
152
+ 'ROT': rot_deg_s, 'alt_rate': alt_rate_ftmin,
 
 
 
 
153
  }
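As an illustration of these conventions, a hypothetical track moving due east at 100 m/s with points 1 s apart:

    import numpy as np
    feats = derive_features_enu(east=np.array([0.0, 100.0, 200.0]), north=np.zeros(3),
                                up=np.zeros(3), timestamps=np.array([0.0, 1.0, 2.0]))
    # feats['COG'] ≈ 90 (due east), feats['SOG'] ≈ 194.4 knots, feats['ROT'] ≈ 0, feats['alt_rate'] ≈ 0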
154
 
155
 
156
  # ============================================================
157
+ # 4. Binary Geohash Encoding (40-bit per axis → 120 bits total)
158
  # ============================================================
159
 
160
+ def binary_geohash_encode(values, precision=40, v_min=0.0, v_max=1.0):
161
+ """Successive bisection encoding to binary. Matches LLM4STP num2bits()."""
 
 
162
  N = len(values)
163
  bits = np.zeros((N, precision), dtype=np.int64)
 
164
  _min = np.full(N, v_min)
165
  _max = np.full(N, v_max)
 
166
  for p in range(precision):
167
  mid = (_min + _max) / 2
168
  mask = values > mid
169
  bits[:, p] = mask.astype(np.int64)
170
  _min = np.where(mask, mid, _min)
171
  _max = np.where(mask, _max, mid)
 
172
  return bits
173
 
174
 
175
+ def binary_geohash_decode(bits, precision=40, v_min=0.0, v_max=1.0):
 
 
176
  N = bits.shape[0]
177
  _min = np.full(N, v_min)
178
  _max = np.full(N, v_max)
 
179
  for p in range(precision):
180
  mid = (_min + _max) / 2
181
  mask = bits[:, p].astype(bool)
182
  _min = np.where(mask, mid, _min)
183
  _max = np.where(mask, _max, mid)
 
184
  return (_min + _max) / 2
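Encode and decode are inverses up to quantization: after `precision` bisections the decoded midpoint lies within (v_max - v_min) / 2**(precision + 1) of the input. A small sketch with made-up values:

    import numpy as np
    vals = np.array([0.123456, 0.5, 0.987])
    bits = binary_geohash_encode(vals, precision=40)
    back = binary_geohash_decode(bits, precision=40)
    # np.max(np.abs(vals - back)) <= 1.0 / 2 ** 41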
185
 
186
 
187
  class GeohashEncoder:
188
  """
189
+ 3D geohash encoder for ENU coordinates.
190
+ 40 bits per axis × 3 axes = 120 bits per timestep.
191
 
192
+ Uses PER-TRAJECTORY normalization with a small margin so the full
193
+ bit range encodes just the spatial extent of each trajectory.
194
+ For a typical trajectory spanning ~200km, 40 bits gives ~0.2μm resolution.
 
195
  """
196
 
197
+ def __init__(self, precision=40):
198
  self.precision = precision
199
+ self.e_min = self.e_max = None
200
+ self.n_min = self.n_max = None
201
+ self.u_min = self.u_max = None
202
+
203
+ def fit(self, east, north, up, margin=0.05):
204
+ """Fit normalization bounds with margin."""
205
+ for attr, data in [('e', east), ('n', north), ('u', up)]:
206
+ drange = data.max() - data.min()
207
+ m = margin * max(drange, 100.0) # At least 100m range
208
+ setattr(self, f'{attr}_min', data.min() - m)
209
+ setattr(self, f'{attr}_max', data.max() + m)
210
+
211
+ def _normalize(self, values, v_min, v_max):
 
 
212
  return np.clip((values - v_min) / max(v_max - v_min, 1e-10), 0.0, 1.0)
213
 
214
+ def encode(self, east, north, up):
215
+ e_norm = self._normalize(east, self.e_min, self.e_max)
216
+ n_norm = self._normalize(north, self.n_min, self.n_max)
217
+ u_norm = self._normalize(up, self.u_min, self.u_max)
 
 
218
  e_bits = binary_geohash_encode(e_norm, self.precision)
219
  n_bits = binary_geohash_encode(n_norm, self.precision)
220
  u_bits = binary_geohash_encode(u_norm, self.precision)
 
221
  return np.concatenate([e_bits, n_bits, u_bits], axis=1) # (N, 120)
222
+
223
+ def get_bounds(self):
224
+ return {
225
+ 'e_min': self.e_min, 'e_max': self.e_max,
226
+ 'n_min': self.n_min, 'n_max': self.n_max,
227
+ 'u_min': self.u_min, 'u_max': self.u_max,
228
+ }
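Typical use of the encoder (a sketch; the arrays stand in for ENU positions in metres):

    import numpy as np
    enc = GeohashEncoder(precision=40)
    east, north, up = np.array([0.0, 1500.0, 3000.0]), np.array([0.0, 400.0, 900.0]), np.array([100.0, 300.0, 600.0])
    enc.fit(east, north, up, margin=0.05)
    bits = enc.encode(east, north, up)   # shape (3, 120): [E_bits | N_bits | U_bits]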
229
 
230
 
231
  # ============================================================
 
234
 
235
  @dataclass
236
  class FeatureBins:
237
+ """Feature discretization configuration."""
238
+ cog_edges: np.ndarray = field(default_factory=lambda: np.linspace(0, 360, 181)) # 180 bins, 2°
239
+ sog_edges: np.ndarray = field(default_factory=lambda: np.linspace(0, 600, 301)) # 300 bins, 2 kts
240
+ rot_edges: np.ndarray = field(default_factory=lambda: np.linspace(-6, 6, 121)) # 120 bins, 0.1°/s
 
 
241
  alt_rate_edges: np.ndarray = field(default_factory=lambda: np.linspace(-6000, 6000, 121)) # 120 bins
242
 
243
  @property
244
  def n_cog_bins(self): return len(self.cog_edges) - 1
 
245
  @property
246
  def n_sog_bins(self): return len(self.sog_edges) - 1
 
247
  @property
248
  def n_rot_bins(self): return len(self.rot_edges) - 1
 
249
  @property
250
  def n_alt_rate_bins(self): return len(self.alt_rate_edges) - 1
251
 
252
+ def _digitize(self, values, edges):
253
+ return np.clip(np.digitize(values, edges) - 1, 0, len(edges) - 2)
 
 
254
 
255
+ def encode_cog(self, cog): return self._digitize(cog, self.cog_edges)
256
+ def encode_sog(self, sog): return self._digitize(sog, self.sog_edges)
257
+ def encode_rot(self, rot): return self._digitize(np.clip(rot, -6, 6), self.rot_edges)
258
+ def encode_alt_rate(self, ar): return self._digitize(np.clip(ar, -6000, 6000), self.alt_rate_edges)
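For instance, with the default 2° COG bins (a small sketch):

    import numpy as np
    fb = FeatureBins()
    fb.encode_cog(np.array([0.0, 1.9, 90.0, 359.9]))   # -> array([0, 0, 45, 179]); fb.n_cog_bins == 180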
 
 
259
 
260
 
261
  # ============================================================
262
+ # 6. Temporal Features (sub-second precision)
263
  # ============================================================
264
 
265
+ def extract_temporal_features(timestamps):
266
+ """Extract temporal features preserving fractional second precision."""
 
 
 
267
  import datetime
268
 
 
269
  dts = [datetime.datetime.utcfromtimestamp(t) for t in timestamps]
 
270
  hours = np.array([d.hour for d in dts], dtype=np.int64)
271
  dows = np.array([d.weekday() for d in dts], dtype=np.int64)
272
+ months = np.array([d.month - 1 for d in dts], dtype=np.int64)
273
 
274
+ # Second of day with fractional seconds (sub-second precision)
275
  second_of_day = np.array([
276
  d.hour * 3600 + d.minute * 60 + d.second + d.microsecond / 1e6
277
  for d in dts
278
+ ], dtype=np.float64)
279
 
280
+ dt = np.zeros(len(timestamps), dtype=np.float64)
 
281
  dt[1:] = np.diff(timestamps)
282
 
 
 
 
283
  return {
284
  'second_of_day': second_of_day,
285
+ 'hour': hours, 'dow': dows, 'month': months,
 
 
286
  'dt': dt,
 
287
  }
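A small sketch of the sub-second handling (made-up timestamps 2.5 s apart):

    import numpy as np
    tf = extract_temporal_features(np.array([1700000000.0, 1700000002.5, 1700000005.0]))
    # tf['dt'] == [0.0, 2.5, 2.5], and tf['second_of_day'] keeps the 0.5 s fraction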
288
 
289
 
290
  # ============================================================
291
+ # 7. Full Trajectory Processor
292
  # ============================================================
293
 
294
  class TrajectoryProcessor:
295
+ def __init__(self, resample_dt=5.0, geohash_precision=40, n_uncertainty_bins=16,
296
+ feature_bins=None, min_trajectory_len=20):
 
 
297
  self.resample_dt = resample_dt
298
  self.geohash_precision = geohash_precision
299
  self.n_uncertainty_bins = n_uncertainty_bins
300
  self.feature_bins = feature_bins or FeatureBins()
301
  self.min_trajectory_len = min_trajectory_len
302
  self.geohash_encoder = GeohashEncoder(precision=geohash_precision)
 
 
303
  self._fitted = False
304
 
305
+ def resample_trajectory(self, timestamps, lats, lons, alts):
306
+ t_start, t_end = timestamps[0], timestamps[-1]
 
 
 
 
 
307
  n_points = int((t_end - t_start) / self.resample_dt) + 1
308
+ if n_points < 2:
309
+ return timestamps, lats, lons, alts
310
  t_new = np.linspace(t_start, t_start + (n_points - 1) * self.resample_dt, n_points)
311
+ return t_new, np.interp(t_new, timestamps, lats), np.interp(t_new, timestamps, lons), np.interp(t_new, timestamps, alts)
312
+
313
+ def process_trajectory(self, timestamps, lats, lons, alts, metadata=None):
 
 
 
314
  sort_idx = np.argsort(timestamps)
315
+ timestamps, lats, lons, alts = timestamps[sort_idx], lats[sort_idx], lons[sort_idx], alts[sort_idx]
 
 
 
316
 
 
317
  timestamps, lats, lons, alts = self.resample_trajectory(timestamps, lats, lons, alts)
 
318
  if len(timestamps) < self.min_trajectory_len:
319
  return None
320
 
321
+ # ENU conversion origin = first point
322
  converter = ENUConverter(lats[0], lons[0], alts[0])
323
  east, north, up = converter.to_enu(lats, lons, alts)
324
 
325
+ # Derive kinematics from ENU via 3-point derivative
326
  features = derive_features_enu(east, north, up, timestamps)
327
 
328
+ # Binary geohash
 
329
  if self.geohash_encoder.e_min is not None:
330
+ geohash_bits = self.geohash_encoder.encode(east, north, up)
331
  else:
332
  geohash_bits = np.zeros((len(east), self.geohash_precision * 3), dtype=np.int64)
333
 
334
+ # Discretize kinematics
335
  cog_bins = self.feature_bins.encode_cog(features['COG'])
336
  sog_bins = self.feature_bins.encode_sog(features['SOG'])
337
  rot_bins = self.feature_bins.encode_rot(features['ROT'])
338
  alt_rate_bins = self.feature_bins.encode_alt_rate(features['alt_rate'])
339
 
340
+ # Uncertainty
341
+ from uncertainty import compute_all_uncertainties, discretize_scores, UncertaintyConfig
342
+ uncert_config = UncertaintyConfig(n_bins=self.n_uncertainty_bins, window=5)
 
 
 
343
  raw_uncert = compute_all_uncertainties(
344
  east, north, up, timestamps,
345
  features['COG'], features['SOG'], features['ROT'], features['alt_rate'],
346
  config=uncert_config,
347
  )
 
348
  uncert_methods = sorted(raw_uncert.keys())
349
  uncert_bins_multi = np.stack([
350
  discretize_scores(raw_uncert[m], self.n_uncertainty_bins)
351
  for m in uncert_methods
352
+ ], axis=1)
353
+ uncert_bins = uncert_bins_multi[:, 0] if uncert_bins_multi.shape[1] > 0 else np.zeros(len(east), dtype=np.int64)
 
 
 
 
 
354
 
 
355
  temporal = extract_temporal_features(timestamps)
356
 
357
  return {
358
+ 'timestamps': timestamps, 'lats': lats, 'lons': lons, 'alts': alts,
359
+ 'east': east, 'north': north, 'up': up,
360
+ 'COG': features['COG'], 'SOG': features['SOG'],
361
+ 'ROT': features['ROT'], 'alt_rate': features['alt_rate'],
362
+ 'vx': features['vx'], 'vy': features['vy'], 'vz': features['vz'],
 
 
 
363
  'geohash_bits': geohash_bits,
364
+ 'cog_bins': cog_bins, 'sog_bins': sog_bins,
365
+ 'rot_bins': rot_bins, 'alt_rate_bins': alt_rate_bins,
366
+ 'uncert_bins': uncert_bins, 'uncert_bins_multi': uncert_bins_multi,
367
+ 'uncert_method_names': uncert_methods,
368
+ 'hour': temporal['hour'], 'dow': temporal['dow'], 'month': temporal['month'],
369
+ 'second_of_day': temporal['second_of_day'], 'dt': temporal['dt'],
 
 
 
370
  'enu_origin': (converter.origin_lat, converter.origin_lon, converter.origin_alt),
 
 
371
  'metadata': metadata or {},
372
  }
373
 
374
+ def fit_geohash(self, all_east, all_north, all_up):
 
375
  self.geohash_encoder.fit(all_east, all_north, all_up)
376
  self._fitted = True
377
 
378
 
379
  # ============================================================
380
+ # 8. Prompt Tokens
381
  # ============================================================
382
 
383
  @dataclass
384
  class PromptTokens:
385
+ BOS: int = 0; EOS: int = 1; PAD: int = 2
386
+ PREDICT: int = 3; CLASSIFY: int = 4; DETECT_ANOMALY: int = 5
387
+ HEAVY: int = 6; LARGE: int = 7; SMALL: int = 8
388
+ ROTORCRAFT: int = 9; GLIDER: int = 10; UAV: int = 11; AIRCRAFT_UNKNOWN: int = 12
389
+ CLIMB: int = 13; CRUISE: int = 14; DESCENT: int = 15
390
+ APPROACH: int = 16; GROUND: int = 17; PHASE_UNKNOWN: int = 18
391
+ CONUS: int = 19; EUROPE: int = 20; ASIA: int = 21; REGION_OTHER: int = 22
 
 
392
  VOCAB_SIZE: int = 23
393
 
394
 
395
+ # ============================================================
396
+ # 9. PyTorch Dataset
397
+ # ============================================================
398
+
399
  class AirTrackDataset(Dataset):
400
+ def __init__(self, trajectories, seq_len=128, stride=64, task='predict'):
 
 
 
401
  self.seq_len = seq_len
402
  self.stride = stride
403
  self.task = task
 
 
404
  self.windows = []
405
  self.trajectories = trajectories
406
 
407
  for traj_idx, traj in enumerate(trajectories):
408
  n_points = len(traj['timestamps'])
 
409
  if n_points < seq_len + 1:
 
410
  if n_points >= 20:
411
  self.windows.append((traj_idx, 0, n_points))
412
  continue
 
413
  for start in range(0, n_points - seq_len, stride):
414
+ end = start + seq_len + 1
415
  if end <= n_points:
416
  self.windows.append((traj_idx, start, end))
417
 
 
418
  self.prompt_tokens = PromptTokens()
419
 
420
  def __len__(self):
 
423
  def __getitem__(self, idx):
424
  traj_idx, start, end = self.windows[idx]
425
  traj = self.trajectories[traj_idx]
 
 
426
  sl = slice(start, end)
427
 
 
 
428
  task_token = self.prompt_tokens.PREDICT if self.task == 'predict' else self.prompt_tokens.CLASSIFY
429
  prompt = torch.tensor([
430
+ self.prompt_tokens.BOS, task_token,
431
+ self.prompt_tokens.AIRCRAFT_UNKNOWN,
 
432
  self.prompt_tokens.PHASE_UNKNOWN,
433
  self.prompt_tokens.REGION_OTHER,
434
  ], dtype=torch.long)
435
 
 
436
  return {
437
+ 'geohash_bits': torch.from_numpy(traj['geohash_bits'][sl]).float(),
438
+ 'cog_bins': torch.from_numpy(traj['cog_bins'][sl]).long(),
439
+ 'sog_bins': torch.from_numpy(traj['sog_bins'][sl]).long(),
440
+ 'rot_bins': torch.from_numpy(traj['rot_bins'][sl]).long(),
441
+ 'alt_rate_bins': torch.from_numpy(traj['alt_rate_bins'][sl]).long(),
442
+ 'uncert_bins': torch.from_numpy(traj['uncert_bins'][sl]).long(),
443
+ 'uncert_bins_multi': torch.from_numpy(traj['uncert_bins_multi'][sl]).long(),
444
+ 'hour': torch.from_numpy(traj['hour'][sl]).long(),
445
+ 'dow': torch.from_numpy(traj['dow'][sl]).long(),
446
+ 'month': torch.from_numpy(traj['month'][sl]).long(),
447
+ 'second_of_day': torch.from_numpy(traj['second_of_day'][sl]).float(),
448
+ 'dt': torch.from_numpy(traj['dt'][sl]).float(),
449
  'prompt': prompt,
450
+ 'east': torch.from_numpy(traj['east'][sl]).float(),
451
+ 'north': torch.from_numpy(traj['north'][sl]).float(),
452
+ 'up': torch.from_numpy(traj['up'][sl]).float(),
453
  }
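Downstream, windows batch with a standard DataLoader; a sketch assuming a dataset built by build_dataset below and that every window has the same length (trajectories shorter than seq_len + 1 are kept whole above, so mixed lengths would need padding or a custom collate_fn):

    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    # batch['geohash_bits']: (8, seq_len + 1, 120); batch['cog_bins']: (8, seq_len + 1); ...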
454
 
455
 
456
  # ============================================================
457
+ # 10. Data Loading
458
  # ============================================================
459
 
460
+ def load_traffic_sample(name='quickstart'):
 
 
461
  import traffic.data.samples as samples
462
+ import pandas as pd
463
 
464
  data = getattr(samples, name)
465
  trajectories = []
 
 
466
  flights = data if hasattr(data, '__iter__') else [data]
467
 
468
  for flight in flights:
469
+ try:
470
+ df = flight.data
471
+ except Exception:
472
+ continue
473
  if df is None or len(df) < 20:
474
  continue
475
 
 
 
476
  ts_series = pd.to_datetime(df['timestamp'])
477
  if ts_series.dt.tz is not None:
478
  ts_series = ts_series.dt.tz_convert('UTC').dt.tz_localize(None)
 
480
  lats = df['latitude'].values.astype(np.float64)
481
  lons = df['longitude'].values.astype(np.float64)
482
 
 
483
  if 'altitude' in df.columns:
484
  alts = df['altitude'].values.astype(np.float64)
485
  elif 'baro_altitude' in df.columns:
 
487
  else:
488
  alts = np.zeros(len(df))
489
 
 
490
  valid = ~(np.isnan(lats) | np.isnan(lons) | np.isnan(alts) | np.isnan(timestamps))
491
  if valid.sum() < 20:
492
  continue
493
 
494
  trajectories.append({
495
  'timestamps': timestamps[valid],
496
+ 'lats': lats[valid], 'lons': lons[valid], 'alts': alts[valid],
497
+ 'callsign': getattr(flight, 'callsign', 'UNKNOWN'),
498
+ 'icao24': getattr(flight, 'icao24', 'UNKNOWN'),
 
 
499
  })
500
 
501
  return trajectories
502
 
503
 
504
+ def build_dataset(raw_trajectories, processor, seq_len=128, stride=64, fit_geohash=True):
 
 
 
505
  processed = []
506
  all_east, all_north, all_up = [], [], []
507
 
 
517
  all_up.append(result['up'])
518
 
519
  if fit_geohash and processed:
 
520
  all_e = np.concatenate(all_east)
521
  all_n = np.concatenate(all_north)
522
  all_u = np.concatenate(all_up)
523
  processor.fit_geohash(all_e, all_n, all_u)
 
 
524
  for traj in processed:
525
  traj['geohash_bits'] = processor.geohash_encoder.encode(
526
  traj['east'], traj['north'], traj['up']
527
  )
528
 
529
  print(f"Processed {len(processed)}/{len(raw_trajectories)} trajectories")
 
530
  dataset = AirTrackDataset(processed, seq_len=seq_len, stride=stride)
531
  print(f"Created dataset with {len(dataset)} windows")
 
532
  return dataset
533
 
534
 
535
  if __name__ == '__main__':
 
536
  print("Loading traffic sample data...")
537
  raw_trajs = load_traffic_sample()
538
  print(f"Loaded {len(raw_trajs)} raw trajectories")
 
541
  processor = TrajectoryProcessor(resample_dt=5.0)
542
  dataset = build_dataset(raw_trajs, processor, seq_len=64, stride=32)
543
 
 
544
  if len(dataset) > 0:
545
  sample = dataset[0]
546
+ print("\nSample shapes:")
547
  for k, v in sample.items():
548
  if isinstance(v, torch.Tensor):
549
  print(f" {k}: {v.shape} ({v.dtype})")