File size: 5,549 Bytes
0c01cdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
ATFMTraj Data Loading and Preprocessing for LLM4AirTrack.

Loads ENU-transformed ADS-B trajectories from petchthwr/ATFMTraj.
Creates sliding-window samples: [context_window] -> [prediction_horizon].
Computes kinematic features: directional vectors, polar components, speed proxies.
"""

import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from huggingface_hub import hf_hub_download
from typing import Tuple, Optional


def download_atfm_dataset(airport="RKSIa", cache_dir="./data/ATFMTraj"):
    """Fetch the six ATFMTraj TSV splits (TRAIN/TEST x X/Y/Z) from the HF Hub.

    Files already present under ``cache_dir/airport`` are left untouched;
    only missing ones are downloaded. Returns the local airport directory.
    """
    airport_dir = os.path.join(cache_dir, airport)
    # Creating the leaf directory also creates cache_dir if needed.
    os.makedirs(airport_dir, exist_ok=True)
    for split in ("TRAIN", "TEST"):
        for axis in ("X", "Y", "Z"):
            name = f"{airport}_{split}_{axis}.tsv"
            if os.path.exists(os.path.join(airport_dir, name)):
                continue  # already cached locally
            print(f"Downloading {airport}/{name}...")
            hf_hub_download(
                repo_id="petchthwr/ATFMTraj",
                filename=f"{airport}/{name}",
                repo_type="dataset",
                local_dir=cache_dir,
            )
    return airport_dir


def load_atfm_raw(airport="RKSIa", mode="TRAIN", cache_dir="./data/ATFMTraj"):
    """Load raw ATFMTraj data. Returns (N, T_max, 3) ENU + (N,) labels.

    Each per-axis TSV row is ``label \\t v_1 \\t ... \\t v_Tmax`` with NaN
    padding; the label column (identical across axes) is read once from X.
    """
    airport_dir = os.path.join(cache_dir, airport)
    channels = []
    labels = None
    for axis_name in ("X", "Y", "Z"):
        tsv_path = os.path.join(airport_dir, f"{airport}_{mode}_{axis_name}.tsv")
        frame = pd.read_csv(tsv_path, sep="\t", header=None, na_values="NaN")
        values = frame.values
        if labels is None:
            labels = values[:, 0]  # first column carries the class label
        channels.append(values[:, 1:])
    # Stack X/Y/Z as the trailing coordinate axis -> (N, T_max, 3).
    return np.stack(channels, axis=-1), labels.astype(int)


def compute_kinematic_features(trajectory, dt=1.0):
    """Expand an ENU trajectory (T, 3) into 9-dim kinematic features (T, 9).

    Features per timestep: position (x, y, z), unit velocity direction
    (ux, uy, uz), and ground-plane polar components (r, sin(theta), cos(theta)).
    ``dt`` scales the finite-difference velocity; note direction vectors are
    normalized, so the returned features do not depend on its value.
    """
    # Column-wise central differences (one-sided at the edges).
    velocity = np.gradient(trajectory, axis=0) / dt
    dx, dy, dz = velocity[:, 0], velocity[:, 1], velocity[:, 2]
    # Epsilon keeps the normalization finite for stationary points.
    speed = np.sqrt(dx**2 + dy**2 + dz**2) + 1e-8
    direction = velocity / speed[:, None]
    x, y = trajectory[:, 0], trajectory[:, 1]
    r = np.sqrt(x**2 + y**2) + 1e-8
    theta = np.arctan2(y, x)
    polar = np.stack([r, np.sin(theta), np.cos(theta)], axis=-1)
    return np.concatenate([trajectory, direction, polar], axis=-1)


def create_trajectory_windows(data, labels, context_len=60, pred_len=30, stride=15):
    """Create sliding-window samples from variable-length trajectories.

    Args:
        data: (N, T_max, 3) array of ENU trajectories, NaN-padded.
        labels: (N,) per-trajectory class labels.
        context_len: input window length in timesteps.
        pred_len: prediction horizon in timesteps.
        stride: step between consecutive window starts.

    Returns:
        (contexts, targets, labels) with shapes (W, context_len, 9) float32,
        (W, pred_len, 3) float32, (W,) int64 — consistently shaped even
        when no trajectory yields a window (W == 0).
    """
    total_len = context_len + pred_len
    contexts, targets, sample_labels = [], [], []
    for i in range(len(data)):
        traj = data[i]
        # A timestep is valid only if ALL coordinates are finite; the previous
        # X-only check could leak NaNs from Y/Z into features and targets.
        valid_mask = ~np.isnan(traj).any(axis=-1)
        traj_valid = traj[valid_mask]
        valid_len = len(traj_valid)
        if valid_len < total_len:
            continue  # trajectory too short for even one window
        for start in range(0, valid_len - total_len + 1, stride):
            ctx_raw = traj_valid[start:start + context_len]
            tgt = traj_valid[start + context_len:start + total_len]
            contexts.append(compute_kinematic_features(ctx_raw))
            targets.append(tgt)
            sample_labels.append(labels[i])
    if not contexts:
        # Preserve the (W, T, F) layout so downstream shape[-1] lookups and
        # label reductions do not break on an empty result.
        return (
            np.zeros((0, context_len, 9), dtype=np.float32),
            np.zeros((0, pred_len, 3), dtype=np.float32),
            np.zeros((0,), dtype=np.int64),
        )
    return (
        np.array(contexts, dtype=np.float32),
        np.array(targets, dtype=np.float32),
        np.array(sample_labels, dtype=np.int64),
    )


class AirTrackDataset(Dataset):
    """Torch Dataset yielding {context, target, label} dicts per window."""

    def __init__(self, contexts, targets, labels):
        # Convert the numpy arrays to tensors once; __getitem__ then
        # only indexes, which keeps per-sample access cheap.
        self.contexts = torch.from_numpy(contexts)
        self.targets = torch.from_numpy(targets)
        self.labels = torch.from_numpy(labels)

    def __len__(self):
        return self.contexts.shape[0]

    def __getitem__(self, idx):
        return {
            "context": self.contexts[idx],
            "target": self.targets[idx],
            "label": self.labels[idx],
        }


def prepare_dataloaders(airport="RKSIa", context_len=60, pred_len=30, stride=15,
                        batch_size=32, cache_dir="./data/ATFMTraj", max_trajectories=None,
                        num_workers=2):
    """Full pipeline: download -> load -> window -> dataloader.

    Args:
        airport: ATFMTraj airport code (e.g. "RKSIa").
        context_len: input window length in timesteps.
        pred_len: prediction horizon in timesteps.
        stride: step between consecutive window starts.
        batch_size: DataLoader batch size.
        cache_dir: local cache root for downloaded TSVs.
        max_trajectories: if truthy, cap the number of trajectories per split
            (useful for quick experiments).
        num_workers: DataLoader worker processes (was hard-coded to 2).

    Returns:
        (train_loader, test_loader, meta) where meta describes feature/class
        counts and window totals.
    """
    download_atfm_dataset(airport, cache_dir)
    train_data, train_labels = load_atfm_raw(airport, "TRAIN", cache_dir)
    test_data, test_labels = load_atfm_raw(airport, "TEST", cache_dir)
    if max_trajectories:
        train_data, train_labels = train_data[:max_trajectories], train_labels[:max_trajectories]
        test_data, test_labels = test_data[:max_trajectories], test_labels[:max_trajectories]

    train_ctx, train_tgt, train_lbl = create_trajectory_windows(train_data, train_labels, context_len, pred_len, stride)
    test_ctx, test_tgt, test_lbl = create_trajectory_windows(test_data, test_labels, context_len, pred_len, stride)

    all_labels = np.concatenate([train_lbl, test_lbl])
    # Guard the empty case: .max() on a zero-length array raises ValueError.
    # NOTE(review): n_classes assumes labels are dense 0..max — confirm for
    # each airport's label scheme.
    n_classes = int(all_labels.max()) + 1 if all_labels.size else 0

    train_ds = AirTrackDataset(train_ctx, train_tgt, train_lbl)
    test_ds = AirTrackDataset(test_ctx, test_tgt, test_lbl)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)

    return train_loader, test_loader, {
        "airport": airport, "context_len": context_len, "pred_len": pred_len,
        "n_features": train_ctx.shape[-1], "n_classes": n_classes,
        "n_train_windows": len(train_ds), "n_test_windows": len(test_ds),
    }