Jdice27 commited on
Commit
0c01cdc
·
verified ·
1 Parent(s): bdd8cea

Add data module

Browse files
Files changed (1) hide show
  1. llm4airtrack/data.py +131 -0
llm4airtrack/data.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ATFMTraj Data Loading and Preprocessing for LLM4AirTrack.
3
+
4
+ Loads ENU-transformed ADS-B trajectories from petchthwr/ATFMTraj.
5
+ Creates sliding-window samples: [context_window] -> [prediction_horizon].
6
+ Computes kinematic features: directional vectors, polar components, speed proxies.
7
+ """
8
+
9
+ import os
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ from torch.utils.data import Dataset, DataLoader
14
+ from huggingface_hub import hf_hub_download
15
+ from typing import Tuple, Optional
16
+
17
+
18
+ def download_atfm_dataset(airport="RKSIa", cache_dir="./data/ATFMTraj"):
19
+ """Download ATFMTraj TSV files from HuggingFace Hub."""
20
+ os.makedirs(cache_dir, exist_ok=True)
21
+ airport_dir = os.path.join(cache_dir, airport)
22
+ os.makedirs(airport_dir, exist_ok=True)
23
+ for mode in ["TRAIN", "TEST"]:
24
+ for var in ["X", "Y", "Z"]:
25
+ fname = f"{airport}_{mode}_{var}.tsv"
26
+ fpath = os.path.join(airport_dir, fname)
27
+ if not os.path.exists(fpath):
28
+ print(f"Downloading {airport}/{fname}...")
29
+ hf_hub_download(
30
+ repo_id="petchthwr/ATFMTraj",
31
+ filename=f"{airport}/{fname}",
32
+ repo_type="dataset",
33
+ local_dir=cache_dir,
34
+ )
35
+ return airport_dir
36
+
37
+
38
+ def load_atfm_raw(airport="RKSIa", mode="TRAIN", cache_dir="./data/ATFMTraj"):
39
+ """Load raw ATFMTraj data. Returns (N, T_max, 3) ENU + (N,) labels."""
40
+ airport_dir = os.path.join(cache_dir, airport)
41
+ data, labels = [], None
42
+ for var in ['X', 'Y', 'Z']:
43
+ df = pd.read_csv(
44
+ os.path.join(airport_dir, f"{airport}_{mode}_{var}.tsv"),
45
+ sep='\t', header=None, na_values='NaN'
46
+ )
47
+ if labels is None:
48
+ labels = df.values[:, 0]
49
+ data.append(df.values[:, 1:])
50
+ return np.stack(data, axis=-1), labels.astype(int)
51
+
52
+
53
+ def compute_kinematic_features(trajectory, dt=1.0):
54
+ """
55
+ Compute 9-dim kinematic features from ENU (x,y,z):
56
+ Position (x,y,z) + Direction (ux,uy,uz) + Polar (r, sinθ, cosθ)
57
+ """
58
+ x, y, z = trajectory[:, 0], trajectory[:, 1], trajectory[:, 2]
59
+ dx, dy, dz = np.gradient(x)/dt, np.gradient(y)/dt, np.gradient(z)/dt
60
+ speed = np.sqrt(dx**2 + dy**2 + dz**2) + 1e-8
61
+ ux, uy, uz = dx/speed, dy/speed, dz/speed
62
+ r = np.sqrt(x**2 + y**2) + 1e-8
63
+ theta = np.arctan2(y, x)
64
+ return np.stack([x, y, z, ux, uy, uz, r, np.sin(theta), np.cos(theta)], axis=-1)
65
+
66
+
67
+ def create_trajectory_windows(data, labels, context_len=60, pred_len=30, stride=15):
68
+ """Create sliding-window samples from variable-length trajectories."""
69
+ total_len = context_len + pred_len
70
+ contexts, targets, sample_labels = [], [], []
71
+ for i in range(len(data)):
72
+ traj = data[i]
73
+ valid_mask = ~np.isnan(traj[:, 0])
74
+ valid_len = np.sum(valid_mask)
75
+ if valid_len < total_len:
76
+ continue
77
+ traj_valid = traj[valid_mask]
78
+ for start in range(0, valid_len - total_len + 1, stride):
79
+ ctx_raw = traj_valid[start:start + context_len]
80
+ tgt = traj_valid[start + context_len:start + total_len]
81
+ ctx = compute_kinematic_features(ctx_raw)
82
+ contexts.append(ctx)
83
+ targets.append(tgt)
84
+ sample_labels.append(labels[i])
85
+ return (
86
+ np.array(contexts, dtype=np.float32),
87
+ np.array(targets, dtype=np.float32),
88
+ np.array(sample_labels, dtype=np.int64),
89
+ )
90
+
91
+
92
+ class AirTrackDataset(Dataset):
93
+ """PyTorch Dataset for aircraft trajectory prediction."""
94
+ def __init__(self, contexts, targets, labels):
95
+ self.contexts = torch.from_numpy(contexts)
96
+ self.targets = torch.from_numpy(targets)
97
+ self.labels = torch.from_numpy(labels)
98
+
99
+ def __len__(self):
100
+ return len(self.contexts)
101
+
102
+ def __getitem__(self, idx):
103
+ return {"context": self.contexts[idx], "target": self.targets[idx], "label": self.labels[idx]}
104
+
105
+
106
+ def prepare_dataloaders(airport="RKSIa", context_len=60, pred_len=30, stride=15,
107
+ batch_size=32, cache_dir="./data/ATFMTraj", max_trajectories=None):
108
+ """Full pipeline: download -> load -> window -> dataloader."""
109
+ download_atfm_dataset(airport, cache_dir)
110
+ train_data, train_labels = load_atfm_raw(airport, "TRAIN", cache_dir)
111
+ test_data, test_labels = load_atfm_raw(airport, "TEST", cache_dir)
112
+ if max_trajectories:
113
+ train_data, train_labels = train_data[:max_trajectories], train_labels[:max_trajectories]
114
+ test_data, test_labels = test_data[:max_trajectories], test_labels[:max_trajectories]
115
+
116
+ train_ctx, train_tgt, train_lbl = create_trajectory_windows(train_data, train_labels, context_len, pred_len, stride)
117
+ test_ctx, test_tgt, test_lbl = create_trajectory_windows(test_data, test_labels, context_len, pred_len, stride)
118
+
119
+ all_labels = np.concatenate([train_lbl, test_lbl])
120
+ n_classes = int(all_labels.max()) + 1
121
+
122
+ train_ds = AirTrackDataset(train_ctx, train_tgt, train_lbl)
123
+ test_ds = AirTrackDataset(test_ctx, test_tgt, test_lbl)
124
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
125
+ test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
126
+
127
+ return train_loader, test_loader, {
128
+ "airport": airport, "context_len": context_len, "pred_len": pred_len,
129
+ "n_features": train_ctx.shape[-1], "n_classes": n_classes,
130
+ "n_train_windows": len(train_ds), "n_test_windows": len(test_ds),
131
+ }