| |
| """ |
| Unified T1 scene recognition training script. |
| Supports 7 published baselines + SyncFuse, plus the extra `syncfuse_ime` / `transformer_*` variants listed in METHOD_MODALITIES. |
| |
| Usage: |
| python3 train_baselines_t1.py --method stgcn --seed 42 |
| python3 train_baselines_t1.py --method ctrgcn --seed 42 |
| python3 train_baselines_t1.py --method limu_bert --seed 42 |
| python3 train_baselines_t1.py --method emg_cnn --seed 42 |
| python3 train_baselines_t1.py --method actionsense --seed 42 |
| python3 train_baselines_t1.py --method mult --seed 42 |
| python3 train_baselines_t1.py --method perceiver --seed 42 |
| python3 train_baselines_t1.py --method syncfuse --seed 42 \ |
| --mod_dropout_p 0.3 --use_xmod_shift --use_learned_late \ |
| --pretrained_dir /path/to/pretrained |
| """ |
| import os |
| import sys |
| import json |
| import time |
| import random |
| import argparse |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from sklearn.metrics import accuracy_score, f1_score, confusion_matrix |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from data.dataset import get_dataloaders, NUM_CLASSES |
| from nets.baselines_published.baselines import ( |
| STGCN, CTRGCN, LIMUBert, EMGCNN, ActionSenseLSTM, MulT, PerceiverIO, |
| ) |
| from nets.baselines_published.syncfuse import SyncFuse |
|
|
|
|
| |
| |
| |
|
|
# Sensor streams consumed by each value of --method.
# Key order matters: the argparse `choices` list is built from these keys.
METHOD_MODALITIES = dict(
    # Single-modality baselines.
    stgcn=['mocap'],
    ctrgcn=['mocap'],
    limu_bert=['imu'],
    emg_cnn=['emg'],
    # Multimodal baselines.
    actionsense=['mocap', 'emg', 'eyetrack', 'imu'],
    mult=['mocap', 'emg', 'imu'],
    perceiver=['mocap', 'emg', 'eyetrack', 'imu'],
    # SyncFuse on all four streams, and the `_ime` variant without eyetrack.
    syncfuse=['mocap', 'emg', 'eyetrack', 'imu'],
    syncfuse_ime=['mocap', 'emg', 'imu'],
    # SyncFuse-backed transformer variants (same stream sets as above).
    transformer_late=['mocap', 'emg', 'eyetrack', 'imu'],
    transformer_late_ime=['mocap', 'emg', 'imu'],
    # IMU-only transformer variant.
    transformer_imu=['imu'],
)
|
|
|
|
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
|
|
|
|
def _load_pretrained_backbones(model, modality_dims, args):
    """Load per-modality checkpoints from ``args.pretrained_dir`` into ``model``.

    For every modality name in ``modality_dims`` the file
    ``<pretrained_dir>/transformer_<name>_early/model_best.pt`` is looked up.
    The paths that exist are handed to ``model.load_pretrained``; modalities
    without a checkpoint are reported instead of being skipped silently.
    Returns ``model`` so the call can be used in a ``return`` statement.
    """
    if not args.pretrained_dir:
        return model
    candidates = {
        name: os.path.join(args.pretrained_dir,
                           f'transformer_{name}_early/model_best.pt')
        for name in modality_dims
    }
    found = {name: path for name, path in candidates.items()
             if os.path.exists(path)}
    missing = [name for name in candidates if name not in found]
    if missing:
        # A mistyped --pretrained_dir would otherwise go unnoticed.
        print(f"Warning: no pretrained checkpoint found for {missing} in "
              f"{args.pretrained_dir}; those modalities are not initialised "
              f"from a checkpoint")
    if found:
        model.load_pretrained(found, freeze=args.freeze_pretrained)
    return model


def build_model(method, modality_dims, num_classes, args):
    """Construct the requested baseline or SyncFuse variant.

    Args:
        method: One of the keys of ``METHOD_MODALITIES``.
        modality_dims: Mapping from modality name to its feature dimension.
        num_classes: Number of output classes.
        args: Parsed CLI namespace; reads ``hidden_dim``, ``n_joints``,
            ``use_xmod_shift``, ``use_learned_late``, ``pretrained_dir`` and
            ``freeze_pretrained`` depending on the method.

    Returns:
        The freshly constructed model (with pretrained backbones loaded for
        the methods that support it when ``args.pretrained_dir`` is set).

    Raises:
        ValueError: If ``method`` is not recognised.
    """
    if method == 'stgcn':
        return STGCN(modality_dims['mocap'], num_classes,
                     hidden=args.hidden_dim, n_joints=args.n_joints)
    if method == 'ctrgcn':
        return CTRGCN(modality_dims['mocap'], num_classes,
                      hidden=args.hidden_dim, n_joints=args.n_joints)
    if method == 'limu_bert':
        return LIMUBert(modality_dims['imu'], num_classes,
                        hidden=args.hidden_dim, n_layers=4, n_heads=4)
    if method == 'emg_cnn':
        # Fixed width: this baseline does not follow --hidden_dim.
        return EMGCNN(modality_dims['emg'], num_classes, hidden=64)
    if method == 'actionsense':
        return ActionSenseLSTM(modality_dims, num_classes,
                               hidden=args.hidden_dim)
    if method == 'mult':
        return MulT(modality_dims, num_classes, d_model=args.hidden_dim,
                    n_layers=2, n_heads=4)
    if method == 'perceiver':
        return PerceiverIO(modality_dims, num_classes,
                           latent_dim=args.hidden_dim, n_latents=32,
                           n_layers=3, n_heads=4)
    if method in ('syncfuse', 'syncfuse_ime'):
        model = SyncFuse(modality_dims, num_classes, hidden=args.hidden_dim,
                         n_heads=4, n_layers=2,
                         use_xmod_shift=args.use_xmod_shift,
                         use_learned_late=args.use_learned_late)
        return _load_pretrained_backbones(model, modality_dims, args)
    if method == 'transformer_imu':
        # SyncFuse with cross-modal shift and learned late fusion disabled;
        # this variant never loads pretrained backbones.
        return SyncFuse(modality_dims, num_classes, hidden=args.hidden_dim,
                        n_heads=4, n_layers=2,
                        use_xmod_shift=False,
                        use_learned_late=False)
    if method in ('transformer_late', 'transformer_late_ime'):
        # Same stripped-down SyncFuse, optionally with pretrained backbones.
        model = SyncFuse(modality_dims, num_classes, hidden=args.hidden_dim,
                         n_heads=4, n_layers=2,
                         use_xmod_shift=False,
                         use_learned_late=False)
        return _load_pretrained_backbones(model, modality_dims, args)
    raise ValueError(f"Unknown method: {method}")
|
|
|
|
| |
| |
| |
|
|
def train_one_epoch(model, loader, criterion, optimizer, device, args):
    """Run one optimisation pass over ``loader`` and report loss and accuracy.

    Args:
        model: Network to train. Called as ``model(x, mask)``; the
            SyncFuse-based methods additionally receive ``mod_dropout_p``
            and ``training_time``.
        loader: Iterable of ``(x, y, mask, extra)`` batches; the fourth
            element is ignored.
        criterion: Loss applied to ``(logits, y)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        args: Namespace providing ``method`` and ``mod_dropout_p``.

    Returns:
        ``(loss, accuracy)`` where ``loss`` is the sample-weighted mean of
        the per-batch losses. ``(0., 0.)`` when the loader yields no samples,
        matching the empty-loader convention of ``evaluate``.
    """
    model.train()
    syncfuse_methods = ('syncfuse', 'syncfuse_ime')
    plain_transformer_methods = ('transformer_late', 'transformer_late_ime',
                                 'transformer_imu')
    # The set of trainable parameters does not change within an epoch, so
    # collect it once instead of once per batch.
    trainable = [p for p in model.parameters() if p.requires_grad]

    total_loss, n, all_preds, all_labels = 0., 0, [], []
    for x, y, mask, _ in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        optimizer.zero_grad()
        if args.method in syncfuse_methods:
            logits = model(x, mask, mod_dropout_p=args.mod_dropout_p,
                           training_time=True)
        elif args.method in plain_transformer_methods:
            # These variants train with modality dropout switched off.
            logits = model(x, mask, mod_dropout_p=0.0, training_time=False)
        else:
            # Every published baseline shares the plain (x, mask) call.
            logits = model(x, mask)
        loss = criterion(logits, y)
        loss.backward()
        if trainable:
            # Clip the global gradient norm to 1.0 before stepping.
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
        optimizer.step()
        batch_size = y.size(0)
        total_loss += loss.item() * batch_size
        n += batch_size
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(y.cpu().numpy())

    if n == 0:
        # Nothing was seen: avoid scoring empty prediction lists.
        return 0., 0.
    return total_loss / n, accuracy_score(all_labels, all_preds)
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, criterion, device, args):
    """Score ``model`` on ``loader`` without tracking gradients.

    Batches are ``(x, y, mask, extra)`` tuples; the fourth element is unused.
    Returns ``(loss, accuracy, macro_f1, confusion_matrix)`` where ``loss`` is
    the sample-weighted mean of the per-batch losses and the confusion matrix
    covers labels ``0 .. NUM_CLASSES - 1``. An empty loader yields zeros and
    an all-zero matrix.
    """
    model.eval()
    # SyncFuse-backed models take an explicit training_time flag.
    needs_training_flag = args.method in (
        'syncfuse', 'syncfuse_ime',
        'transformer_late', 'transformer_late_ime',
        'transformer_imu',
    )

    loss_sum = 0.
    seen = 0
    predictions = []
    targets = []
    for x, y, mask, _ in loader:
        x = x.to(device)
        y = y.to(device)
        mask = mask.to(device)
        if needs_training_flag:
            logits = model(x, mask, training_time=False)
        else:
            logits = model(x, mask)
        batch_loss = criterion(logits, y)
        loss_sum += batch_loss.item() * y.size(0)
        seen += y.size(0)
        predictions.extend(logits.argmax(dim=1).cpu().numpy())
        targets.extend(y.cpu().numpy())

    if seen == 0:
        return 0., 0., 0., np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)

    accuracy = accuracy_score(targets, predictions)
    macro_f1 = f1_score(targets, predictions, average='macro', zero_division=0)
    matrix = confusion_matrix(targets, predictions,
                              labels=list(range(NUM_CLASSES)))
    return loss_sum / seen, accuracy, macro_f1, matrix
|
|
|
|
| |
| |
| |
|
|
def run(args):
    """Train one method, select the best epoch, test it and save the results.

    The model is trained for up to ``args.epochs`` epochs with early stopping
    on validation macro-F1 (``args.patience`` epochs without improvement).
    The best weights are written to ``<output_dir>/<exp_name>/model_best.pt``,
    reloaded for the final test pass, and a summary is written to
    ``results.json`` in the same directory.

    Args:
        args: Parsed CLI namespace (see ``main``).

    Returns:
        The results dictionary that was written to ``results.json``.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    modalities = METHOD_MODALITIES[args.method]
    print(f"Method: {args.method} | Modalities: {modalities} | Seed: {args.seed}")

    # Data
    train_loader, val_loader, test_loader, info = get_dataloaders(
        modalities, batch_size=args.batch_size, downsample=args.downsample,
    )
    if info['val_size'] == 0:
        # Without a validation split, model selection falls back to the test
        # set; make that visible because it biases the reported test scores.
        print("Warning: no validation split; using the test set for model "
              "selection, so the reported test metrics are optimistic.")
        val_loader = test_loader
    print(f"Train={info['train_size']} Test={info['test_size']} "
          f"feat_dim={info['feat_dim']} mod_dims={info['modality_dims']}")

    # Model
    model = build_model(args.method, info['modality_dims'], info['num_classes'],
                        args).to(device)
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Params: {trainable:,}/{total:,}")

    # Loss, optimizer (trainable parameters only) and LR schedule
    class_weights = info['class_weights'].to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights,
                                    label_smoothing=args.label_smoothing)
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=7, min_lr=1e-6,
    )

    # Output location
    exp_name = f"{args.method}_seed{args.seed}"
    if args.tag:
        exp_name += f"_{args.tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)
    ckpt_path = os.path.join(out_dir, 'model_best.pt')

    # Training with early stopping on validation macro-F1
    best_val_f1, best_epoch, patience_counter = -1.0, 0, 0
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion,
                                          optimizer, device, args)
        va_loss, va_acc, va_f1, _ = evaluate(model, val_loader, criterion,
                                             device, args)
        # The LR schedule follows validation loss, selection follows F1.
        scheduler.step(va_loss)
        print(f" E{epoch:3d} | tr {tr_loss:.4f}/{tr_acc:.3f} | "
              f"va {va_loss:.4f}/{va_acc:.3f} f1 {va_f1:.3f} | "
              f"{time.time()-t0:.1f}s")
        if va_f1 > best_val_f1:
            best_val_f1 = va_f1
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" Early stop at epoch {epoch} (best {best_epoch})")
                break

    # Test with the best checkpoint of THIS run. If none was written (e.g.
    # --epochs 0), do not pick up a stale file left by an earlier run.
    if best_epoch > 0:
        model.load_state_dict(torch.load(ckpt_path, map_location=device,
                                         weights_only=True))
    else:
        print("Warning: no checkpoint was saved during this run; "
              "testing the current model weights.")
    te_loss, te_acc, te_f1, te_cm = evaluate(model, test_loader, criterion,
                                             device, args)
    print(f"\n== Test == loss {te_loss:.4f} acc {te_acc:.3f} f1 {te_f1:.3f}")

    # Persist the summary
    results = {
        'method': args.method,
        'modalities': modalities,
        'seed': args.seed,
        'best_epoch': best_epoch,
        'best_val_f1': float(best_val_f1),
        'test_acc': float(te_acc),
        'test_f1': float(te_f1),
        'n_params': trainable,
        'n_params_total': total,
        'confusion_matrix': te_cm.tolist(),
        'args': vars(args),
    }
    # ensure_ascii=False may emit non-ASCII text, so fix the file encoding.
    with open(os.path.join(out_dir, 'results.json'), 'w',
              encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Saved: {out_dir}/results.json")
    return results
|
|
|
|
def main():
    """Parse the command-line options and pass them to ``run``."""
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        # General training options.
        ('--method', dict(type=str, required=True,
                          choices=list(METHOD_MODALITIES.keys()))),
        ('--epochs', dict(type=int, default=80)),
        ('--batch_size', dict(type=int, default=16)),
        ('--lr', dict(type=float, default=1e-3)),
        ('--weight_decay', dict(type=float, default=1e-4)),
        ('--hidden_dim', dict(type=int, default=128)),
        ('--downsample', dict(type=int, default=5)),
        ('--patience', dict(type=int, default=15)),
        ('--label_smoothing', dict(type=float, default=0.1)),
        ('--seed', dict(type=int, default=42)),
        ('--output_dir', dict(type=str, required=True)),
        ('--tag', dict(type=str, default='')),
        # Joint count passed to the stgcn / ctrgcn models.
        ('--n_joints', dict(type=int, default=52)),
        # SyncFuse and pretrained-backbone options.
        ('--mod_dropout_p', dict(type=float, default=0.3)),
        ('--use_xmod_shift', dict(action='store_true')),
        ('--use_learned_late', dict(action='store_true')),
        ('--pretrained_dir', dict(type=str, default='')),
        ('--freeze_pretrained', dict(
            action='store_true',
            help='Freeze loaded pretrained backbones (default: fine-tune them)')),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    run(parser.parse_args())


if __name__ == '__main__':
    main()
|
|