| |
| """ |
Experiment F: Zero-shot scene generalization.

Leave-one-scene-out evaluation on T1 (scene recognition). For each of the 8
scenes S_k, train on the remaining 7 scenes across all train+test
volunteers, then evaluate on scene S_k only (all volunteers). Since the
held-out scene was never seen during training, its samples should be
distributed over the remaining 7 classes -- so we report (a) the fraction
of held-out samples that are classified into the single most frequent
remaining class (the "dominant neighbor") and (b) macro-F1 on the 7 seen
scenes, obtained when training and evaluating on mixed scenes.

Simpler protocol (the one this script implements): train the 8-class
classifier but WITHOUT scene S_k in the training set. Evaluate on the full
test set (all 8 scenes), split into seen-scene and held-out-scene subsets.
Measure which classes the held-out scene gets misclassified as -- this
reveals scene similarity and generalization behavior.
| """ |
|
|
| import os |
| import sys |
| import json |
| import time |
| import argparse |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from sklearn.metrics import accuracy_score, f1_score, confusion_matrix |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from data.dataset import ( |
| MultimodalSceneDataset, collate_fn, TRAIN_VOLS, TEST_VOLS, SCENE_LABELS, |
| NUM_CLASSES, |
| ) |
| from nets.models import build_model |
| from tasks.train_exp1 import set_seed, apply_augmentation |
|
|
|
|
def filter_dataset_by_scene(ds, excluded_scene):
    """Return the positions of all samples that do not belong to a scene.

    Args:
        ds: Dataset object exposing a ``sample_info`` sequence.
        excluded_scene: Scene identifier to leave out (e.g. ``"s3"``).

    Returns:
        List of integer indices into ``ds.sample_info`` whose entry does not
        contain the marker ``"/<excluded_scene>"``.

    NOTE(review): the check is a plain ``in`` containment test, so this
    presumably expects each ``sample_info`` entry to be a path-like string
    with the scene as a path component -- confirm against the dataset class.
    A scene name that is a prefix of another (``s1`` vs ``s10``) would match
    both.
    """
    scene_marker = f"/{excluded_scene}"
    return [
        position
        for position, entry in enumerate(ds.sample_info)
        if scene_marker not in entry
    ]
|
|
|
|
def _evaluate(model, loader, device, criterion=None):
    """Run ``model`` over ``loader`` without gradients and collect predictions.

    Args:
        model: Module called as ``model(x, mask)`` that returns logits; the
            predicted class is the argmax over dim 1.
        loader: Iterable yielding ``(x, y, mask, _)`` batches.
        device: Device the batch tensors are moved to before the forward pass.
        criterion: Optional loss callable ``criterion(logits, y)``. When
            omitted no loss is computed and ``nan`` is returned for it.

    Returns:
        Tuple ``(preds, targets, mean_loss)`` where ``preds`` and ``targets``
        are flat lists of per-sample labels and ``mean_loss`` is the
        batch-size-weighted average of the per-batch loss.
    """
    preds, targets = [], []
    loss_sum, n_samples = 0.0, 0
    with torch.no_grad():
        for x, y, mask, _ in loader:
            x, y, mask = x.to(device), y.to(device), mask.to(device)
            logits = model(x, mask)
            if criterion is not None:
                loss_sum += criterion(logits, y).item() * y.size(0)
            n_samples += y.size(0)
            preds.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(y.cpu().numpy())
    if criterion is None:
        return preds, targets, float('nan')
    return preds, targets, loss_sum / max(n_samples, 1)


def run_experiment(args):
    """Train with one scene held out and report seen/unseen test behavior.

    The model keeps all ``NUM_CLASSES`` outputs but is trained only on samples
    that are not from ``args.held_out_scene``. After every epoch it is
    evaluated on the test samples from the seen scenes (accuracy, macro-F1,
    loss) and on the test samples from the held-out scene (histogram of the
    predicted classes). The epoch with the best seen-scene macro-F1 is kept.

    Args:
        args: Namespace providing ``seed``, ``modalities`` (comma-separated),
            ``held_out_scene``, ``downsample``, ``batch_size``, ``model``,
            ``fusion``, ``hidden_dim``, ``lr``, ``weight_decay``, ``tag``,
            ``output_dir``, ``epochs``, ``augment`` and ``patience``.

    Returns:
        The results dictionary that is also written to ``results.json``.
        ``best_metrics`` is ``None`` if no epoch improved on the initial
        score (including ``epochs == 0``).

    Side effects:
        Creates ``<output_dir>/<exp_name>/`` and writes ``results.json`` and,
        when a best state exists, ``model_best.pt`` into it.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    modalities = args.modalities.split(',')
    held_out = args.held_out_scene
    assert held_out in SCENE_LABELS, f"Unknown scene: {held_out}"
    held_out_label = SCENE_LABELS[held_out]
    print(f"Held-out scene: {held_out} (= class {SCENE_LABELS[held_out]})")

    # Test data is built with the statistics obtained from the train set.
    print("Loading train data...")
    full_train = MultimodalSceneDataset(TRAIN_VOLS, modalities, args.downsample)
    stats = full_train.get_stats()
    print("Loading test data...")
    full_test = MultimodalSceneDataset(TEST_VOLS, modalities, args.downsample,
                                       stats=stats)

    # Training only ever sees the scenes that are not held out.
    train_idx = filter_dataset_by_scene(full_train, held_out)
    print(f"Train size (7 seen scenes): {len(train_idx)}/{len(full_train)}")

    # Split the test set into seen-scene and held-out-scene samples. The
    # complement is taken against a set so the split stays linear in size.
    test_seen_idx = filter_dataset_by_scene(full_test, held_out)
    seen_lookup = set(test_seen_idx)
    test_unseen_idx = [i for i in range(len(full_test))
                       if i not in seen_lookup]
    print(f"Test seen: {len(test_seen_idx)} unseen: {len(test_unseen_idx)}")

    train_sub = torch.utils.data.Subset(full_train, train_idx)
    test_seen_sub = torch.utils.data.Subset(full_test, test_seen_idx)
    test_unseen_sub = torch.utils.data.Subset(full_test, test_unseen_idx)

    train_loader = DataLoader(train_sub, batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_fn)
    test_seen_loader = DataLoader(test_seen_sub, batch_size=args.batch_size,
                                  shuffle=False, collate_fn=collate_fn)
    test_unseen_loader = DataLoader(test_unseen_sub, batch_size=args.batch_size,
                                    shuffle=False, collate_fn=collate_fn)

    # The output layer still has NUM_CLASSES units, so the model is able to
    # predict the held-out class even though it never trains on it.
    model = build_model(
        args.model, args.fusion, full_train.feat_dim,
        full_train.modality_dims, NUM_CLASSES,
        hidden_dim=args.hidden_dim, proj_dim=0, late_agg='mean',
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Params: {n_params:,}")

    # The held-out class gets zero weight and is also the ignore_index, so
    # targets of that class contribute nothing to the loss.
    class_weights = full_train.get_class_weights().clone().to(device)
    class_weights[held_out_label] = 0.0
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1,
                                    ignore_index=held_out_label)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6,
    )

    exp_name = f"zs_{args.model}_{'-'.join(modalities)}_hold_{held_out}_seed{args.seed}"
    if args.tag:
        exp_name += f"_{args.tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)

    best_seen_f1 = 0.0
    best_state = None
    best_epoch = 0
    # Defined up front so the results dict can always be built, even when no
    # epoch beats the initial score or the epoch loop never runs.
    best_metrics = None
    patience_counter = 0
    seen_labels = [c for c in range(NUM_CLASSES) if c != held_out_label]

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        model.train()
        tr_loss, n = 0.0, 0
        for x, y, mask, _ in train_loader:
            x, y, mask = x.to(device), y.to(device), mask.to(device)
            if args.augment:
                x = apply_augmentation(x, mask, 0.1, 0.1)
            optimizer.zero_grad()
            logits = model(x, mask)
            loss = criterion(logits, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            tr_loss += loss.item() * y.size(0)
            n += y.size(0)
        tr_loss /= max(n, 1)

        model.eval()
        seen_preds, seen_ys, seen_loss = _evaluate(
            model, test_seen_loader, device, criterion)
        # No loss for the unseen split: every target there is the ignored
        # class, so only the predictions carry information.
        uns_preds, _, _ = _evaluate(model, test_unseen_loader, device)

        seen_acc = accuracy_score(seen_ys, seen_preds)
        seen_f1 = f1_score(seen_ys, seen_preds, average='macro',
                           labels=seen_labels, zero_division=0)
        # Explicit integer dtype keeps bincount valid for an empty split.
        uns_pred_counts = np.bincount(np.asarray(uns_preds, dtype=np.int64),
                                      minlength=NUM_CLASSES)
        n_unseen = max(len(uns_preds), 1)
        dominant = int(np.argmax(uns_pred_counts))
        dominant_frac = float(uns_pred_counts[dominant] / n_unseen)
        held_out_pred_frac = float(uns_pred_counts[held_out_label] / n_unseen)

        # NOTE(review): LR scheduling and best-epoch selection below both use
        # metrics computed on TEST_VOLS data; there is no separate validation
        # split, so the reported seen-scene scores are optimistically biased.
        scheduler.step(seen_loss)

        print(f"  E{epoch:3d} | tr {tr_loss:.4f}  te {seen_loss:.4f} | "
              f"seen_acc {seen_acc:.3f} f1 {seen_f1:.3f} | "
              f"unseen -> {dominant} ({dominant_frac:.2f}) "
              f"held_out_predicted_frac {held_out_pred_frac:.3f} | "
              f"{time.time()-t0:.1f}s")

        if seen_f1 > best_seen_f1:
            best_seen_f1 = seen_f1
            # Snapshot on CPU so later optimizer steps cannot alter it.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            best_epoch = epoch
            patience_counter = 0
            best_metrics = {
                'seen_acc': float(seen_acc),
                'seen_f1': float(seen_f1),
                'unseen_dominant_class': int(dominant),
                'unseen_dominant_frac': float(dominant_frac),
                'unseen_pred_hist': uns_pred_counts.tolist(),
                'n_unseen': len(uns_preds),
                'held_out_pred_frac': float(held_out_pred_frac),
            }
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f"  Early stop (best epoch {best_epoch})")
                break

    if best_state is not None:
        torch.save(best_state, os.path.join(out_dir, 'model_best.pt'))

    results = {
        'experiment': exp_name,
        'model': args.model,
        'modalities': modalities,
        'held_out_scene': held_out,
        'held_out_label': SCENE_LABELS[held_out],
        'seed': args.seed,
        'best_epoch': best_epoch,
        'best_metrics': best_metrics,
        'train_size': len(train_sub),
        'test_seen_size': len(test_seen_sub),
        'test_unseen_size': len(test_unseen_sub),
        'args': vars(args),
    }
    with open(os.path.join(out_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)
    print(f"Saved: {out_dir}/results.json")
    return results
|
|
|
|
def main():
    """Parse the command-line options and launch the experiment.

    ``--held_out_scene`` and ``--output_dir`` are required; every other
    option has a default. The parsed namespace is handed to
    ``run_experiment`` unchanged.
    """
    # (flag, add_argument keyword arguments), kept in the original
    # registration order so the generated help text is unchanged.
    option_specs = (
        ('--model', {'type': str, 'default': 'transformer'}),
        ('--fusion', {'type': str, 'default': 'early'}),
        ('--modalities', {'type': str, 'default': 'mocap,emg,imu'}),
        ('--held_out_scene', {'type': str, 'required': True,
                              'help': 'One of s1..s8'}),
        ('--epochs', {'type': int, 'default': 60}),
        ('--batch_size', {'type': int, 'default': 16}),
        ('--lr', {'type': float, 'default': 1e-3}),
        ('--weight_decay', {'type': float, 'default': 1e-4}),
        ('--hidden_dim', {'type': int, 'default': 128}),
        ('--downsample', {'type': int, 'default': 5}),
        ('--patience', {'type': int, 'default': 12}),
        ('--augment', {'action': 'store_true'}),
        ('--seed', {'type': int, 'default': 42}),
        ('--output_dir', {'type': str, 'required': True}),
        ('--tag', {'type': str, 'default': ''}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    parsed = parser.parse_args()
    run_experiment(parsed)
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if "__main__" == __name__:
    main()
|
|