#!/usr/bin/env python3
r"""
Unified T1 scene recognition training script.

Supports 12 ``--method`` values (see ``METHOD_MODALITIES``):
  * 7 published baselines: stgcn, ctrgcn, limu_bert, emg_cnn, actionsense,
    mult, perceiver;
  * SyncFuse: syncfuse (4-mod) and syncfuse_ime (3-mod IME variant);
  * plain Transformer references built on the SyncFuse class with all extras
    off: transformer_late (4-mod), transformer_late_ime (3-mod IME) and
    transformer_imu (IMU-only diagnostic).

Usage (``--output_dir`` is required):
    python3 train_baselines_t1.py --method stgcn --seed 42 --output_dir runs/
    python3 train_baselines_t1.py --method ctrgcn --seed 42 --output_dir runs/
    python3 train_baselines_t1.py --method limu_bert --seed 42 --output_dir runs/
    python3 train_baselines_t1.py --method emg_cnn --seed 42 --output_dir runs/
    python3 train_baselines_t1.py --method actionsense --seed 42 --output_dir runs/
    python3 train_baselines_t1.py --method mult --seed 42 --output_dir runs/
    python3 train_baselines_t1.py --method perceiver --seed 42 --output_dir runs/
    python3 train_baselines_t1.py --method syncfuse --seed 42 --output_dir runs/ \
        --mod_dropout_p 0.3 --use_xmod_shift --use_learned_late \
        --pretrained_dir /path/to/pretrained
"""
import os
import sys
import json
import time
import random
import argparse

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.dataset import get_dataloaders, NUM_CLASSES  # noqa: E402
from nets.baselines_published.baselines import (  # noqa: E402
    STGCN, CTRGCN, LIMUBert, EMGCNN, ActionSenseLSTM, MulT, PerceiverIO,
)
from nets.baselines_published.syncfuse import SyncFuse  # noqa: E402


# ---------------------------------------------------------------------------
# Modality configurations per method
# ---------------------------------------------------------------------------
METHOD_MODALITIES = {
    # Single-modality baselines
    'stgcn': ['mocap'],
    'ctrgcn': ['mocap'],
    'limu_bert': ['imu'],
    'emg_cnn': ['emg'],
    # Multi-modality baselines
    'actionsense': ['mocap', 'emg', 'eyetrack', 'imu'],  # drop pressure due to sparse coverage
    'mult': ['mocap', 'emg', 'imu'],                     # MulT is 3-modal
    'perceiver': ['mocap', 'emg', 'eyetrack', 'imu'],
    # Our method (4-mod)
    'syncfuse': ['mocap', 'emg', 'eyetrack', 'imu'],
    # Our method, 3-mod IME variant for direct comparison with tab:scene-published
    'syncfuse_ime': ['mocap', 'emg', 'imu'],
    # Plain Transformer+Late head (matches tab:scene-published setup) under
    # both 3-mod (IME) and 4-mod protocols, for fair re-evaluation
    'transformer_late': ['mocap', 'emg', 'eyetrack', 'imu'],  # 4-mod
    'transformer_late_ime': ['mocap', 'emg', 'imu'],          # 3-mod IME
    # Single-modality IMU-only Transformer (diagnostic)
    'transformer_imu': ['imu'],
}

# Methods whose model is a SyncFuse instance. Their forward() is called with the
# extra ``training_time`` (and, during training, ``mod_dropout_p``) keywords.
# Kept in one place so train_one_epoch() and evaluate() cannot drift apart.
_SYNCFUSE_FULL_METHODS = ('syncfuse', 'syncfuse_ime')
_SYNCFUSE_PLAIN_METHODS = ('transformer_late', 'transformer_late_ime', 'transformer_imu')
_SYNCFUSE_CLASS_METHODS = _SYNCFUSE_FULL_METHODS + _SYNCFUSE_PLAIN_METHODS


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _load_pretrained_backbones(model, modality_dims, args):
    """Initialise SyncFuse per-modality branches from single-modality checkpoints.

    For every modality in ``modality_dims`` this looks for
    ``<args.pretrained_dir>/transformer_<modality>_early/model_best.pt``.
    Modalities without a checkpoint are skipped, but reported, so that a
    mistyped ``--pretrained_dir`` does not go unnoticed. Does nothing when
    ``--pretrained_dir`` is empty.
    """
    if not args.pretrained_dir:
        return
    pt_paths = {}
    for m_name in modality_dims:
        p = os.path.join(args.pretrained_dir, f'transformer_{m_name}_early/model_best.pt')
        if os.path.exists(p):
            pt_paths[m_name] = p
    missing = [m_name for m_name in modality_dims if m_name not in pt_paths]
    if missing:
        print(f"WARNING: no pretrained checkpoint for {missing} under {args.pretrained_dir}")
    if pt_paths:
        model.load_pretrained(pt_paths, freeze=args.freeze_pretrained)


def build_model(method, modality_dims, num_classes, args):
    """Construct the requested baseline, SyncFuse, or Transformer reference model.

    Args:
        method: a key of ``METHOD_MODALITIES``.
        modality_dims: per-modality input dims as reported by ``get_dataloaders``.
        num_classes: number of output classes.
        args: parsed CLI namespace. Reads ``hidden_dim``, ``n_joints``,
            ``use_xmod_shift``, ``use_learned_late``, ``pretrained_dir`` and
            ``freeze_pretrained``.

    Returns:
        The un-moved (CPU) ``nn.Module``.

    Raises:
        ValueError: if ``method`` is not recognised.
    """
    hidden = args.hidden_dim
    if method == 'stgcn':
        return STGCN(modality_dims['mocap'], num_classes, hidden=hidden, n_joints=args.n_joints)
    if method == 'ctrgcn':
        return CTRGCN(modality_dims['mocap'], num_classes, hidden=hidden, n_joints=args.n_joints)
    if method == 'limu_bert':
        return LIMUBert(modality_dims['imu'], num_classes, hidden=hidden, n_layers=4, n_heads=4)
    if method == 'emg_cnn':
        # Fixed width of 64; does not follow --hidden_dim.
        return EMGCNN(modality_dims['emg'], num_classes, hidden=64)
    if method == 'actionsense':
        return ActionSenseLSTM(modality_dims, num_classes, hidden=hidden)
    if method == 'mult':
        return MulT(modality_dims, num_classes, d_model=hidden, n_layers=2, n_heads=4)
    if method == 'perceiver':
        return PerceiverIO(modality_dims, num_classes, latent_dim=hidden,
                           n_latents=32, n_layers=3, n_heads=4)
    if method in _SYNCFUSE_FULL_METHODS:
        model = SyncFuse(modality_dims, num_classes, hidden=hidden, n_heads=4, n_layers=2,
                         use_xmod_shift=args.use_xmod_shift,
                         use_learned_late=args.use_learned_late)
        _load_pretrained_backbones(model, modality_dims, args)
        return model
    if method == 'transformer_imu':
        # SyncFuse with single IMU branch + no extras + no pretrain = matches
        # the "Transformer (ours) IMU early" row in tab:scene-published.
        return SyncFuse(modality_dims, num_classes, hidden=hidden, n_heads=4, n_layers=2,
                        use_xmod_shift=False, use_learned_late=False)
    if method in ('transformer_late', 'transformer_late_ime'):
        # Reuse SyncFuse class with all extras OFF == per-modality Transformer
        # branches + simple late mean fusion + optional pretrained init.
        model = SyncFuse(modality_dims, num_classes, hidden=hidden, n_heads=4, n_layers=2,
                         use_xmod_shift=False, use_learned_late=False)
        _load_pretrained_backbones(model, modality_dims, args)
        return model
    raise ValueError(f"Unknown method: {method}")


# ---------------------------------------------------------------------------
# Train / Eval loop
# ---------------------------------------------------------------------------
def train_one_epoch(model, loader, criterion, optimizer, device, args):
    """Run one training epoch.

    Full SyncFuse variants train with modality dropout (``args.mod_dropout_p``).
    The plain Transformer references reuse SyncFuse with dropout disabled.
    All other baselines are called as ``model(x, mask)``.

    Returns:
        (mean loss, accuracy) over the samples seen, or ``(0., 0.)`` when the
        loader yields no samples.
    """
    model.train()
    total_loss, n, all_preds, all_labels = 0., 0, [], []
    # requires_grad does not change during an epoch, so collect the
    # parameters to clip once instead of once per batch.
    trainable = [p for p in model.parameters() if p.requires_grad]
    for x, y, mask, _ in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        optimizer.zero_grad()
        if args.method in _SYNCFUSE_FULL_METHODS:
            logits = model(x, mask, mod_dropout_p=args.mod_dropout_p, training_time=True)
        elif args.method in _SYNCFUSE_PLAIN_METHODS:
            # Plain references: no modality dropout, even during training.
            logits = model(x, mask, mod_dropout_p=0.0, training_time=False)
        else:
            # Published baselines: x holds only this method's modalities
            # (the loader was built from METHOD_MODALITIES[method]).
            logits = model(x, mask)
        loss = criterion(logits, y)
        loss.backward()
        if trainable:
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
        optimizer.step()
        total_loss += loss.item() * y.size(0)
        n += y.size(0)
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(y.cpu().numpy())
    if n == 0:
        # Mirror evaluate(): avoid calling accuracy_score on empty lists.
        return 0., 0.
    return total_loss / n, accuracy_score(all_labels, all_preds)


@torch.no_grad()
def evaluate(model, loader, criterion, device, args):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        (mean loss, accuracy, macro-F1, confusion matrix of shape
        ``(NUM_CLASSES, NUM_CLASSES)``). If the loader is empty this is
        ``(0., 0., 0., zeros)``.
    """
    model.eval()
    total_loss, n, all_preds, all_labels = 0., 0, [], []
    for x, y, mask, _ in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        if args.method in _SYNCFUSE_CLASS_METHODS:
            logits = model(x, mask, training_time=False)
        else:
            logits = model(x, mask)
        loss = criterion(logits, y)
        total_loss += loss.item() * y.size(0)
        n += y.size(0)
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(y.cpu().numpy())
    if n == 0:
        return 0., 0., 0., np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)
    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(NUM_CLASSES)))
    return total_loss / n, acc, f1, cm


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def run(args):
    """Train ``args.method`` with early stopping, then evaluate on the test split.

    The checkpoint with the highest validation macro-F1 is saved to
    ``<output_dir>/<method>_seed<seed>[_<tag>]/model_best.pt`` and reloaded for
    the final test evaluation. Metrics are written to ``results.json`` in the
    same directory.

    Returns:
        dict: the results that were written to ``results.json``.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    modalities = METHOD_MODALITIES[args.method]
    print(f"Method: {args.method} | Modalities: {modalities} | Seed: {args.seed}")

    train_loader, val_loader, test_loader, info = get_dataloaders(
        modalities, batch_size=args.batch_size, downsample=args.downsample,
    )
    # Without a validation split, the test split drives model selection and the
    # LR schedule. Reported test metrics are then optimistically biased, so make
    # that visible on stdout and in results.json.
    val_is_test = info['val_size'] == 0
    if val_is_test:
        val_loader = test_loader
        print("WARNING: validation split is empty; using the test split for model selection")
    print(f"Train={info['train_size']} Test={info['test_size']} "
          f"feat_dim={info['feat_dim']} mod_dims={info['modality_dims']}")

    model = build_model(args.method, info['modality_dims'], info['num_classes'], args).to(device)
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Params: {trainable:,}/{total:,}")

    class_weights = info['class_weights'].to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=args.label_smoothing)
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=7, min_lr=1e-6,
    )

    exp_name = f"{args.method}_seed{args.seed}"
    if args.tag:
        exp_name += f"_{args.tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)
    best_path = os.path.join(out_dir, 'model_best.pt')

    # Select model by MAX val F1 (more robust than min val_loss when val == 25-sample test).
    best_val_f1, best_epoch, patience_counter = -1.0, 0, 0
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, args)
        va_loss, va_acc, va_f1, _ = evaluate(model, val_loader, criterion, device, args)
        # The LR schedule follows val loss, even though selection uses val F1.
        scheduler.step(va_loss)
        print(f" E{epoch:3d} | tr {tr_loss:.4f}/{tr_acc:.3f} | "
              f"va {va_loss:.4f}/{va_acc:.3f} f1 {va_f1:.3f} | "
              f"{time.time()-t0:.1f}s")
        if va_f1 > best_val_f1:
            best_val_f1, best_epoch, patience_counter = va_f1, epoch, 0
            torch.save(model.state_dict(), best_path)
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" Early stop at epoch {epoch} (best {best_epoch})")
                break

    # Final test eval on the best checkpoint. If no epoch ran (e.g. --epochs 0),
    # no checkpoint exists, so evaluate the current weights instead of crashing.
    if best_epoch > 0:
        model.load_state_dict(torch.load(best_path, map_location=device, weights_only=True))
    else:
        print("WARNING: no checkpoint was saved; evaluating current weights")
    te_loss, te_acc, te_f1, te_cm = evaluate(model, test_loader, criterion, device, args)
    print(f"\n== Test == loss {te_loss:.4f} acc {te_acc:.3f} f1 {te_f1:.3f}")

    results = {
        'method': args.method,
        'modalities': modalities,
        'seed': args.seed,
        'best_epoch': best_epoch,
        'best_val_f1': float(best_val_f1),
        'val_is_test': val_is_test,
        'test_acc': float(te_acc),
        'test_f1': float(te_f1),
        'n_params': trainable,
        'n_params_total': total,
        'confusion_matrix': te_cm.tolist(),
        'args': vars(args),
    }
    # ensure_ascii=False may emit non-ASCII characters (e.g. in paths), so the
    # file encoding must not depend on the platform locale.
    with open(os.path.join(out_dir, 'results.json'), 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Saved: {out_dir}/results.json")
    return results


def main():
    """Parse CLI arguments and run a single training job."""
    p = argparse.ArgumentParser()
    p.add_argument('--method', type=str, required=True, choices=list(METHOD_MODALITIES.keys()))
    p.add_argument('--epochs', type=int, default=80)
    p.add_argument('--batch_size', type=int, default=16)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--weight_decay', type=float, default=1e-4)
    p.add_argument('--hidden_dim', type=int, default=128)
    p.add_argument('--downsample', type=int, default=5)
    p.add_argument('--patience', type=int, default=15)
    p.add_argument('--label_smoothing', type=float, default=0.1)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--output_dir', type=str, required=True)
    p.add_argument('--tag', type=str, default='')
    # Method-specific
    p.add_argument('--n_joints', type=int, default=52)
    # SyncFuse specific
    p.add_argument('--mod_dropout_p', type=float, default=0.3)
    p.add_argument('--use_xmod_shift', action='store_true')
    p.add_argument('--use_learned_late', action='store_true')
    p.add_argument('--pretrained_dir', type=str, default='')
    p.add_argument('--freeze_pretrained', action='store_true',
                   help='Freeze loaded pretrained backbones (default: fine-tune them)')
    args = p.parse_args()
    run(args)


if __name__ == '__main__':
    main()