# PULSE-code / experiments /tasks /train_baselines_t1.py
# velvet-pine-22's picture
# Upload folder using huggingface_hub
# b4b2877 verified
#!/usr/bin/env python3
"""
Unified T1 scene recognition training script.
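Supports 12 methods: 7 published baselines, SyncFuse (4-modality and 3-modality
IME variants) and plain-Transformer reference/diagnostic variants; see
METHOD_MODALITIES for the full list and the modalities each method uses.
Usage (every invocation also requires --output_dir DIR):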
python3 train_baselines_t1.py --method stgcn --seed 42
python3 train_baselines_t1.py --method ctrgcn --seed 42
python3 train_baselines_t1.py --method limu_bert --seed 42
python3 train_baselines_t1.py --method emg_cnn --seed 42
python3 train_baselines_t1.py --method actionsense --seed 42
python3 train_baselines_t1.py --method mult --seed 42
python3 train_baselines_t1.py --method perceiver --seed 42
python3 train_baselines_t1.py --method syncfuse --seed 42 \
--mod_dropout_p 0.3 --use_xmod_shift --use_learned_late \
--pretrained_dir /path/to/pretrained
"""
import os
import sys
import json
import time
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import get_dataloaders, NUM_CLASSES
from nets.baselines_published.baselines import (
STGCN, CTRGCN, LIMUBert, EMGCNN, ActionSenseLSTM, MulT, PerceiverIO,
)
from nets.baselines_published.syncfuse import SyncFuse
# ---------------------------------------------------------------------------
# Modality configurations per method
# ---------------------------------------------------------------------------
# Input modalities consumed by each method. Key order matters: main() offers
# list(METHOD_MODALITIES.keys()) as the --method choices, in this order.
METHOD_MODALITIES = dict(
    # --- single-modality published baselines ---
    stgcn=['mocap'],
    ctrgcn=['mocap'],
    limu_bert=['imu'],
    emg_cnn=['emg'],
    # --- multi-modality published baselines ---
    # ActionSense: pressure is dropped because of its sparse coverage.
    actionsense=['mocap', 'emg', 'eyetrack', 'imu'],
    # MulT is a 3-modal architecture.
    mult=['mocap', 'emg', 'imu'],
    perceiver=['mocap', 'emg', 'eyetrack', 'imu'],
    # --- SyncFuse (ours) ---
    syncfuse=['mocap', 'emg', 'eyetrack', 'imu'],      # 4-modality
    # 3-modality IME variant, for direct comparison with tab:scene-published.
    syncfuse_ime=['mocap', 'emg', 'imu'],
    # --- plain Transformer + late head (tab:scene-published setup) ---
    # Evaluated under both the 4-modality and 3-modality (IME) protocols so
    # the re-evaluation is fair.
    transformer_late=['mocap', 'emg', 'eyetrack', 'imu'],
    transformer_late_ime=['mocap', 'emg', 'imu'],
    # --- diagnostic: IMU-only Transformer ---
    transformer_imu=['imu'],
)
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU + every CUDA device) RNGs with ``seed``."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
def _pretrained_checkpoint_paths(pretrained_dir, modality_names):
    """Return ``{modality: path}`` for each existing
    ``<pretrained_dir>/transformer_<modality>_early/model_best.pt``.

    Modalities without such a file are omitted from the result.
    """
    paths = {}
    for name in modality_names:
        candidate = os.path.join(pretrained_dir,
                                 f'transformer_{name}_early/model_best.pt')
        if os.path.exists(candidate):
            paths[name] = candidate
    return paths


def _maybe_load_pretrained(model, modality_dims, args):
    """Initialise a SyncFuse ``model``'s branches from ``args.pretrained_dir``.

    Does nothing when ``args.pretrained_dir`` is empty. Modalities for which no
    checkpoint file exists are reported on stdout and are not passed to
    ``model.load_pretrained``. ``args.freeze_pretrained`` is forwarded as
    ``freeze``.
    """
    if not args.pretrained_dir:
        return
    paths = _pretrained_checkpoint_paths(args.pretrained_dir, modality_dims)
    missing = [name for name in modality_dims if name not in paths]
    if missing:
        # A requested-but-absent checkpoint used to be skipped silently, which
        # makes a mistyped --pretrained_dir indistinguishable from a real run.
        print(f"WARNING: no pretrained checkpoint under {args.pretrained_dir} "
              f"for modalities {missing}; those branches are not initialised "
              f"from a checkpoint")
    if paths:
        model.load_pretrained(paths, freeze=args.freeze_pretrained)


def build_model(method, modality_dims, num_classes, args):
    """Construct the requested baseline or SyncFuse variant.

    Args:
        method: a key of ``METHOD_MODALITIES``.
        modality_dims: mapping keyed by modality name, passed through to the
            model constructors (single-modality baselines index it by their
            own modality); values presumably are per-modality feature dims as
            reported by ``get_dataloaders`` -- confirm there.
        num_classes: number of output classes.
        args: parsed CLI namespace. Depending on ``method`` this reads
            ``hidden_dim``, ``n_joints``, ``use_xmod_shift``,
            ``use_learned_late``, ``pretrained_dir`` and ``freeze_pretrained``.

    Returns:
        The constructed model; the caller moves it to the target device.

    Raises:
        ValueError: if ``method`` is not recognised.
    """
    if method == 'stgcn':
        return STGCN(modality_dims['mocap'], num_classes,
                     hidden=args.hidden_dim, n_joints=args.n_joints)
    if method == 'ctrgcn':
        return CTRGCN(modality_dims['mocap'], num_classes,
                      hidden=args.hidden_dim, n_joints=args.n_joints)
    if method == 'limu_bert':
        return LIMUBert(modality_dims['imu'], num_classes,
                        hidden=args.hidden_dim, n_layers=4, n_heads=4)
    if method == 'emg_cnn':
        # EMG CNN keeps its own fixed width rather than --hidden_dim.
        return EMGCNN(modality_dims['emg'], num_classes, hidden=64)
    if method == 'actionsense':
        return ActionSenseLSTM(modality_dims, num_classes, hidden=args.hidden_dim)
    if method == 'mult':
        return MulT(modality_dims, num_classes, d_model=args.hidden_dim,
                    n_layers=2, n_heads=4)
    if method == 'perceiver':
        return PerceiverIO(modality_dims, num_classes,
                           latent_dim=args.hidden_dim, n_latents=32,
                           n_layers=3, n_heads=4)
    if method in ('syncfuse', 'syncfuse_ime'):
        model = SyncFuse(modality_dims, num_classes, hidden=args.hidden_dim,
                         n_heads=4, n_layers=2,
                         use_xmod_shift=args.use_xmod_shift,
                         use_learned_late=args.use_learned_late)
        _maybe_load_pretrained(model, modality_dims, args)
        return model
    if method == 'transformer_imu':
        # SyncFuse with single IMU branch + no extras + no pretrain = matches
        # the "Transformer (ours) IMU early" row in tab:scene-published.
        return SyncFuse(modality_dims, num_classes, hidden=args.hidden_dim,
                        n_heads=4, n_layers=2,
                        use_xmod_shift=False,
                        use_learned_late=False)
    if method in ('transformer_late', 'transformer_late_ime'):
        # Reuse SyncFuse class with all extras OFF == per-modality Transformer
        # branches + simple late mean fusion + optional pretrained init.
        model = SyncFuse(modality_dims, num_classes, hidden=args.hidden_dim,
                         n_heads=4, n_layers=2,
                         use_xmod_shift=False,
                         use_learned_late=False)
        _maybe_load_pretrained(model, modality_dims, args)
        return model
    raise ValueError(f"Unknown method: {method}")
# ---------------------------------------------------------------------------
# Train / Eval loop
# ---------------------------------------------------------------------------
def train_one_epoch(model, loader, criterion, optimizer, device, args):
    """Run one optimisation pass of ``model`` over ``loader``.

    Each batch is unpacked as ``(x, y, mask, _)`` and ``x``, ``y``, ``mask``
    are moved to ``device``. SyncFuse variants are called with
    ``args.mod_dropout_p`` and ``training_time=True``; the plain-Transformer
    variants (SyncFuse class with extras off) with ``mod_dropout_p=0.0`` and
    ``training_time=False``; every other method as ``model(x, mask)``.
    Gradients of parameters with ``requires_grad`` are clipped to a global
    norm of 1.0 before each optimizer step.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, or ``(0.0, 0.0)``
        when ``loader`` yields no samples (matching ``evaluate``'s empty case).
    """
    model.train()
    total_loss, n, all_preds, all_labels = 0., 0, [], []
    for x, y, mask, _ in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        optimizer.zero_grad()
        if args.method in ('syncfuse', 'syncfuse_ime'):
            logits = model(x, mask, mod_dropout_p=args.mod_dropout_p,
                           training_time=True)
        elif args.method in ('transformer_late', 'transformer_late_ime',
                             'transformer_imu'):
            # Same SyncFuse class, but with modality dropout disabled.
            logits = model(x, mask, mod_dropout_p=0.0, training_time=False)
        else:
            # Published baselines (single- and multi-modal) all take (x, mask).
            logits = model(x, mask)
        loss = criterion(logits, y)
        loss.backward()
        trainable = [p for p in model.parameters() if p.requires_grad]
        if trainable:
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
        optimizer.step()
        batch_size = y.size(0)
        total_loss += loss.item() * batch_size
        n += batch_size
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(y.cpu().numpy())
    if n == 0:
        # accuracy_score on empty lists yields NaN (plus a warning); report 0
        # like evaluate() does.
        return 0., 0.
    return total_loss / n, accuracy_score(all_labels, all_preds)
@torch.no_grad()
def evaluate(model, loader, criterion, device, args):
    """Score ``model`` on every batch of ``loader`` with autograd disabled.

    SyncFuse-based methods are invoked with ``training_time=False``; all other
    methods as ``model(x, mask)``.

    Returns:
        ``(mean_loss, accuracy, macro_f1, confusion_matrix)``; the confusion
        matrix covers labels ``0 .. NUM_CLASSES-1``. When the loader yields no
        samples, returns three zeros and an all-zero integer matrix.
    """
    model.eval()
    loss_sum, seen = 0., 0
    predictions, targets = [], []
    for batch_x, batch_y, batch_mask, _ in loader:
        batch_x, batch_y, batch_mask = (
            batch_x.to(device), batch_y.to(device), batch_mask.to(device))
        syncfuse_api = args.method in ('syncfuse', 'syncfuse_ime',
                                       'transformer_late', 'transformer_late_ime',
                                       'transformer_imu')
        logits = (model(batch_x, batch_mask, training_time=False)
                  if syncfuse_api else model(batch_x, batch_mask))
        batch_loss = criterion(logits, batch_y)
        batch_size = batch_y.size(0)
        loss_sum += batch_loss.item() * batch_size
        seen += batch_size
        predictions.extend(logits.argmax(dim=1).cpu().numpy())
        targets.extend(batch_y.cpu().numpy())
    if not seen:
        return 0., 0., 0., np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)
    accuracy = accuracy_score(targets, predictions)
    macro_f1 = f1_score(targets, predictions, average='macro', zero_division=0)
    matrix = confusion_matrix(targets, predictions,
                              labels=list(range(NUM_CLASSES)))
    return loss_sum / seen, accuracy, macro_f1, matrix
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _fit(model, train_loader, val_loader, criterion, optimizer, scheduler,
         device, args, ckpt_path):
    """Train for up to ``args.epochs`` epochs with early stopping on val macro-F1.

    The LR scheduler is stepped on validation loss. Whenever validation
    macro-F1 improves, ``model.state_dict()`` is saved to ``ckpt_path``;
    training stops after ``args.patience`` epochs without improvement.

    Returns:
        ``(best_epoch, best_val_f1, best_val_loss)``. ``best_epoch`` is 0 when
        no checkpoint was written (e.g. ``args.epochs < 1``).
    """
    # Select model by MAX val F1 (more robust than min val_loss when val == 25-sample test).
    best_val_f1, best_val_loss, best_epoch, patience_counter = -1.0, float('inf'), 0, 0
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion,
                                          optimizer, device, args)
        va_loss, va_acc, va_f1, _ = evaluate(model, val_loader, criterion,
                                             device, args)
        scheduler.step(va_loss)
        print(f" E{epoch:3d} | tr {tr_loss:.4f}/{tr_acc:.3f} | "
              f"va {va_loss:.4f}/{va_acc:.3f} f1 {va_f1:.3f} | "
              f"{time.time()-t0:.1f}s")
        if va_f1 > best_val_f1:
            best_val_f1, best_val_loss, best_epoch = va_f1, va_loss, epoch
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" Early stop at epoch {epoch} (best {best_epoch})")
                break
    return best_epoch, best_val_f1, best_val_loss


def run(args):
    """Train ``args.method`` on the T1 split, then evaluate its best checkpoint on test.

    Builds loaders for the method's modalities, trains with early stopping on
    validation macro-F1 (see ``_fit``), reloads the best checkpoint, evaluates
    it on the test split and writes ``results.json`` into
    ``<args.output_dir>/<method>_seed<seed>[_<tag>]/``.

    Returns:
        The results dict that was written to ``results.json``.

    Raises:
        RuntimeError: if this run wrote no checkpoint (e.g. ``--epochs 0``);
            otherwise a stale ``model_best.pt`` from an earlier run in the same
            directory would be evaluated silently.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    modalities = METHOD_MODALITIES[args.method]
    print(f"Method: {args.method} | Modalities: {modalities} | Seed: {args.seed}")
    train_loader, val_loader, test_loader, info = get_dataloaders(
        modalities, batch_size=args.batch_size, downsample=args.downsample,
    )
    # Without a validation split, checkpoint selection uses the test split.
    # Record this so results.json is not mistaken for a clean held-out score.
    val_is_test = info['val_size'] == 0
    if val_is_test:
        val_loader = test_loader
        print("WARNING: no validation split; checkpoint selection uses the test split")
    print(f"Train={info['train_size']} Test={info['test_size']} "
          f"feat_dim={info['feat_dim']} mod_dims={info['modality_dims']}")
    model = build_model(args.method, info['modality_dims'], info['num_classes'],
                        args).to(device)
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Params: {trainable:,}/{total:,}")
    class_weights = info['class_weights'].to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights,
                                    label_smoothing=args.label_smoothing)
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=7, min_lr=1e-6,
    )
    exp_name = f"{args.method}_seed{args.seed}"
    if args.tag:
        exp_name += f"_{args.tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)
    ckpt_path = os.path.join(out_dir, 'model_best.pt')
    best_epoch, best_val_f1, best_val_loss = _fit(
        model, train_loader, val_loader, criterion, optimizer, scheduler,
        device, args, ckpt_path,
    )
    if best_epoch == 0:
        raise RuntimeError(
            f"No checkpoint was written in this run (epochs={args.epochs}); "
            f"refusing to evaluate a possibly stale {ckpt_path}")
    # Final test eval on best
    model.load_state_dict(torch.load(ckpt_path, weights_only=True))
    te_loss, te_acc, te_f1, te_cm = evaluate(model, test_loader, criterion,
                                             device, args)
    print(f"\n== Test == loss {te_loss:.4f} acc {te_acc:.3f} f1 {te_f1:.3f}")
    results = {
        'method': args.method,
        'modalities': modalities,
        'seed': args.seed,
        'best_epoch': best_epoch,
        'best_val_f1': float(best_val_f1),
        'best_val_loss': float(best_val_loss),
        'val_is_test': val_is_test,
        'test_loss': float(te_loss),
        'test_acc': float(te_acc),
        'test_f1': float(te_f1),
        'n_params': trainable,
        'n_params_total': total,
        'confusion_matrix': te_cm.tolist(),
        'args': vars(args),
    }
    with open(os.path.join(out_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Saved: {out_dir}/results.json")
    return results
def main():
    """Parse command-line flags and launch one training run via ``run``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', type=str, required=True,
                        choices=list(METHOD_MODALITIES.keys()))
    # Plain typed options with defaults, registered in declaration order.
    for flag, cast, default in (
        ('--epochs', int, 80),
        ('--batch_size', int, 16),
        ('--lr', float, 1e-3),
        ('--weight_decay', float, 1e-4),
        ('--hidden_dim', int, 128),
        ('--downsample', int, 5),
        ('--patience', int, 15),
        ('--label_smoothing', float, 0.1),
        ('--seed', int, 42),
    ):
        parser.add_argument(flag, type=cast, default=default)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--tag', type=str, default='')
    # Method-specific
    parser.add_argument('--n_joints', type=int, default=52)
    # SyncFuse specific
    parser.add_argument('--mod_dropout_p', type=float, default=0.3)
    for switch in ('--use_xmod_shift', '--use_learned_late'):
        parser.add_argument(switch, action='store_true')
    parser.add_argument('--pretrained_dir', type=str, default='')
    parser.add_argument('--freeze_pretrained', action='store_true',
                        help='Freeze loaded pretrained backbones (default: fine-tune them)')
    run(parser.parse_args())


if __name__ == '__main__':
    main()