| import os |
| import sys |
| import argparse |
| import yaml |
| import datetime |
| import numpy as np |
| from pathlib import Path |
| from sklearn.metrics import f1_score |
|
|
| import torch |
| import torch.nn as nn |
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.data import DataLoader, DistributedSampler |
| from torch.cuda.amp import GradScaler, autocast |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'hiera')) |
|
|
| from hiera.hiera_mae import HieraClassifier |
| from data.downstream_dataset import fMRITaskDataset, fMRITaskDataset1, EmoFMRIDataset, HCPtaskDataset |
| from data.adni_dataset import ADNIDataset |
|
|
| from utils.utils import MetricLogger, load_config, log_to_file, count_parameters, save_checkpoint, load_checkpoint, LabelScaler |
| from utils.optim import create_optimizer, create_lr_scheduler |
| from utils.ddp import setup_distributed, set_seed, cleanup_distributed |
|
|
|
|
| def create_model(config): |
| """Create Hiera Classifier model from config""" |
| task_config = config['task'] |
| exp_config = config['experiment'] |
|
|
| model_config = config['model'] |
| pretrained_checkpoint_path = exp_config.get('pretrained_checkpoint', None) |
|
|
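| # Prefer the model hyperparameters saved alongside the pretraining run so the encoder matches the checkpoint |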
| if pretrained_checkpoint_path: |
| pretrain_config_path = Path(pretrained_checkpoint_path).parent.parent / 'config.yaml' |
| if os.path.exists(pretrain_config_path): |
| print(f"Loading model architecture from pretrained config: {pretrain_config_path}") |
| pretrain_config = load_config(pretrain_config_path) |
| model_config = pretrain_config['model'] |
| else: |
| print(f"Warning: Pretrained config not found at {pretrain_config_path}. Using finetune config for model architecture.") |
|
|
| model = HieraClassifier( |
| num_classes=task_config['num_classes'], |
| task_type=task_config['task_type'], |
| input_size=tuple(model_config['input_size']), |
| in_chans=model_config['in_chans'], |
| patch_kernel=tuple(model_config['patch_kernel']), |
| patch_stride=tuple(model_config['patch_stride']), |
| patch_padding=tuple(model_config['patch_padding']), |
| q_stride=tuple(model_config['q_stride']), |
| mask_unit_size=tuple(model_config['mask_unit_size']), |
| embed_dim=model_config['embed_dim'], |
| num_heads=model_config['num_heads'], |
| stages=tuple(model_config['stages']), |
| q_pool=model_config['q_pool'], |
| mlp_ratio=model_config['mlp_ratio'], |
| ) |
|
|
| |
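| # Load pretrained MAE encoder weights into the classifier backbone, if a checkpoint is available |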
| if pretrained_checkpoint_path: |
| if os.path.exists(pretrained_checkpoint_path): |
| model.load_pretrained_mae(pretrained_checkpoint_path) |
| else: |
| print(f"Warning: Pretrained checkpoint not found at {pretrained_checkpoint_path}. Model is randomly initialized.") |
| else: |
| print("Warning: No pretrained checkpoint specified. Model is randomly initialized.") |
|
|
| return model |
|
|
|
|
|
|
| def create_dataloaders(config, is_distributed, rank, world_size): |
| """Create train, validation, and test dataloaders""" |
| data_config = config['data'] |
| task_config = config['task'] |
|
|
| train_dataset = fMRITaskDataset( |
| data_root=data_config['data_root'], |
| datasets=data_config['datasets'], |
| split_suffixes=data_config['train_split_suffixes'], |
| crop_length=data_config['input_seq_len'], |
| label_csv_path=task_config['csv'], |
| task_type=task_config['task_type'] |
| ) |
|
|
| val_dataset = fMRITaskDataset( |
| data_root=data_config['data_root'], |
| datasets=data_config['datasets'], |
| split_suffixes=data_config['val_split_suffixes'], |
| crop_length=data_config['input_seq_len'], |
| label_csv_path=task_config['csv'], |
| task_type=task_config['task_type'] |
| ) |
|
|
|
|
| test_dataset = fMRITaskDataset( |
| data_root=data_config['data_root'], |
| datasets=data_config['datasets'], |
| split_suffixes=data_config.get('test_split_suffixes', ['test']), |
| crop_length=data_config['input_seq_len'], |
| label_csv_path=task_config['csv'], |
| task_type=task_config['task_type'] |
| ) |
| |
|
|
|
|
| |
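| # Under DDP each split is sharded across ranks; only the training sampler shuffles |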
| if is_distributed: |
| train_sampler = DistributedSampler( |
| train_dataset, |
| num_replicas=world_size, |
| rank=rank, |
| shuffle=True, |
| seed=config['experiment']['seed'] |
| ) |
| val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False) |
| test_sampler = DistributedSampler(test_dataset, num_replicas=world_size, rank=rank, shuffle=False) |
| else: |
| train_sampler = None |
| val_sampler = None |
| test_sampler = None |
|
|
| |
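| # drop_last=True on the training loader keeps per-step batch sizes identical across ranks |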
| train_loader = DataLoader( |
| train_dataset, |
| batch_size=data_config['batch_size'], |
| sampler=train_sampler, |
| shuffle=(train_sampler is None), |
| num_workers=data_config['num_workers'], |
| pin_memory=data_config['pin_memory'], |
| prefetch_factor=data_config.get('prefetch_factor', 2), |
| drop_last=True |
| ) |
|
|
| val_loader = DataLoader( |
| val_dataset, |
| batch_size=data_config['batch_size'], |
| sampler=val_sampler, |
| shuffle=False, |
| num_workers=data_config['num_workers'], |
| pin_memory=data_config['pin_memory'], |
| prefetch_factor=data_config.get('prefetch_factor', 2), |
| drop_last=False |
| ) |
|
|
| test_loader = DataLoader( |
| test_dataset, |
| batch_size=data_config['batch_size'], |
| sampler=test_sampler, |
| shuffle=False, |
| num_workers=data_config['num_workers'], |
| pin_memory=data_config['pin_memory'], |
| prefetch_factor=data_config.get('prefetch_factor', 2), |
| drop_last=False |
| ) |
|
|
| return train_loader, val_loader, test_loader, train_sampler |
|
|
|
|
| def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, epoch, config, |
| rank, world_size, label_scaler=None, log_file=None): |
| """Train for one epoch""" |
| model.train() |
|
|
| metric_logger = MetricLogger(delimiter=" ") |
| header = f'Epoch: [{epoch}]' |
|
|
| train_config = config['training'] |
| log_config = config['logging'] |
| task_config = config['task'] |
|
|
| accum_iter = train_config['accum_iter'] |
| use_amp = train_config['use_amp'] |
| clip_grad = train_config.get('clip_grad', None) |
|
|
| optimizer.zero_grad() |
|
|
| for data_iter_step, (samples, labels) in enumerate(metric_logger.log_every(train_loader, log_config['print_freq'], header)): |
| |
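| # The LR scheduler is stepped at the start of each accumulation cycle, i.e. once per optimizer update |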
| if data_iter_step % accum_iter == 0: |
| scheduler.step() |
|
|
| |
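| # Assumes one process per GPU with rank equal to the local device index (single-node launch) |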
| samples = samples.cuda(rank, non_blocking=True) |
| labels = labels.cuda(rank, non_blocking=True) |
|
|
|
|
| |
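| # Forward pass, under autocast when mixed precision is enabled |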
| with autocast(enabled=use_amp): |
| outputs = model(samples) |
|
|
| |
| if task_config['task_type'] == 'classification': |
| if labels.dim() > 1: |
| labels = labels.squeeze(-1)  # drop only the trailing label dim so a batch of one keeps its batch axis |
|
|
| loss = criterion(outputs, labels) |
| |
| _, predicted = outputs.max(1) |
| correct = predicted.eq(labels).sum().item() |
| accuracy = correct / labels.size(0) |
| else: |
| if label_scaler is not None: |
| target_for_loss = label_scaler.transform(labels) |
| else: |
| target_for_loss = labels |
| loss = criterion(outputs.view(-1), target_for_loss.view(-1))  # view(-1) keeps shapes aligned even for a batch of one |
| accuracy = 0.0 |
|
|
| loss = loss / accum_iter |
|
|
| |
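| # Backward pass: under AMP the loss is scaled so small gradients survive fp16; the optimizer steps every accum_iter iterations |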
| if use_amp: |
| scaler.scale(loss).backward() |
|
|
| if (data_iter_step + 1) % accum_iter == 0: |
| if clip_grad is not None: |
| scaler.unscale_(optimizer) |
| nn.utils.clip_grad_norm_(model.parameters(), clip_grad) |
| scaler.step(optimizer) |
| scaler.update() |
| optimizer.zero_grad() |
| else: |
| loss.backward() |
|
|
| if (data_iter_step + 1) % accum_iter == 0: |
| if clip_grad is not None: |
| nn.utils.clip_grad_norm_(model.parameters(), clip_grad) |
| optimizer.step() |
| optimizer.zero_grad() |
|
|
| |
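| # Log the per-batch loss (multiplied back by accum_iter) and abort if it is non-finite |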
| loss_value = loss.item() * accum_iter |
| if not np.isfinite(loss_value): |
| print(f"Loss is {loss_value}, stopping training") |
| sys.exit(1) |
|
|
| metric_logger.update(loss=loss_value) |
| metric_logger.update(lr=optimizer.param_groups[0]["lr"]) |
| if task_config['task_type'] == 'classification': |
| metric_logger.update(acc=accuracy) |
|
|
| |
| metric_logger.synchronize_between_processes() |
| print(f"Averaged stats: {metric_logger}") |
|
|
| return {k: meter.global_avg for k, meter in metric_logger.meters.items()} |
|
|
|
|
| @torch.no_grad() |
| def evaluate(model, data_loader, criterion, config, rank, epoch=None, label_scaler=None, mode='val'): |
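| """Evaluate on one split and return the process-averaged metrics as a dict.""" |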
|
|
| model.eval() |
| metric_logger = MetricLogger(delimiter=" ") |
| header = f'{mode.capitalize()} Epoch: [{epoch}]' if epoch is not None else f'{mode.capitalize()}:' |
| |
| task_type = config['task']['task_type'] |
|
|
| all_preds, all_targets = [], [] |
|
|
| for samples, labels in metric_logger.log_every(data_loader, 50, header): |
| samples = samples.cuda(rank, non_blocking=True) |
| labels = labels.cuda(rank, non_blocking=True) |
|
|
| outputs = model(samples) |
|
|
| if task_type == 'classification': |
| labels = labels.squeeze(-1).long() if labels.dim() > 1 else labels.long() |
| loss = criterion(outputs, labels) |
| |
| preds = outputs.argmax(1) |
| acc = (preds == labels).float().mean().item() |
| metric_logger.update(loss=loss.item(), acc=acc) |
| |
| all_preds.append(preds.cpu()) |
| all_targets.append(labels.cpu()) |
| |
| else: |
| # Fall back to the raw labels when no scaler is configured, so target_norm is always defined |
| target_norm = label_scaler.transform(labels) if label_scaler is not None else labels |
| loss = criterion(outputs.view(-1), target_norm.view(-1)) |
| |
| metric_logger.update(loss=loss.item()) |
| all_preds.append(outputs.detach().cpu().view(-1)) |
| all_targets.append(target_norm.detach().cpu().view(-1)) |
|
|
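| # Concatenate this rank's predictions and compute epoch-level metrics from its shard |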
| if len(all_preds) > 0: |
| all_preds = torch.cat(all_preds) |
| all_targets = torch.cat(all_targets) |
|
|
| if task_type == 'classification': |
| f1 = f1_score(all_targets.numpy(), all_preds.numpy(), average='weighted') |
| metric_logger.update(f1=f1) |
| else: |
| mse = torch.mean((all_preds - all_targets) ** 2).item() |
| mae = torch.mean(torch.abs(all_preds - all_targets)).item() |
| |
| ss_res = torch.sum((all_targets - all_preds) ** 2) |
| ss_tot = torch.sum((all_targets - all_targets.mean()) ** 2) |
| r2 = (1 - ss_res / (ss_tot + 1e-8)).item() |
| |
| vx = all_preds - all_preds.mean() |
| vy = all_targets - all_targets.mean() |
| corr = (torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)) + 1e-8)).item() |
| |
| metric_logger.update(mse=mse, mae=mae, r2=r2, corr=corr) |
|
|
| metric_logger.synchronize_between_processes() |
| |
| if rank == 0: |
| print(f"[{mode.upper()}] Global stats: {metric_logger}") |
|
|
| return {k: meter.global_avg for k, meter in metric_logger.meters.items()} |
|
|
|
|
| def main(): |
| """Main fine-tuning function""" |
| |
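| # CLI arguments; --resume and --output_dir override the corresponding config entries |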
| parser = argparse.ArgumentParser(description='Hiera MAE 4D fMRI Downstream Fine-tuning') |
| parser.add_argument('--config', type=str, default='configs/finetune_config.yaml', |
| help='Path to config file') |
| parser.add_argument('--resume', type=str, default=None, |
| help='Path to checkpoint to resume from') |
| parser.add_argument('--output_dir', type=str, default=None, |
| help='Output directory (overrides config)') |
| args = parser.parse_args() |
|
|
| |
| config = load_config(args.config) |
|
|
| |
| if args.resume is not None: |
| config['experiment']['resume'] = args.resume |
| if args.output_dir is not None: |
| config['experiment']['output_dir'] = args.output_dir |
|
|
| |
| is_distributed, rank, world_size, gpu = setup_distributed() |
|
|
| |
| set_seed(config['experiment']['seed'], rank) |
|
|
| |
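| # Rank 0 creates the output tree, snapshots the resolved config, and opens the text log |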
| if rank == 0: |
| output_dir = Path(config['experiment']['output_dir']) |
| checkpoint_dir = output_dir / 'checkpoints' |
| log_dir = output_dir / 'logs' |
|
|
| output_dir.mkdir(parents=True, exist_ok=True) |
| checkpoint_dir.mkdir(parents=True, exist_ok=True) |
| log_dir.mkdir(parents=True, exist_ok=True) |
|
|
| |
| with open(output_dir / 'config.yaml', 'w') as f: |
| yaml.dump(config, f, default_flow_style=False) |
|
|
| |
| log_file = output_dir / 'training_log.txt' |
| with open(log_file, 'w') as f: |
| f.write(f"Fine-tuning started at {datetime.datetime.now()}\n") |
| f.write("="*80 + "\n") |
| f.write(f"Config: {args.config}\n") |
| f.write(f"Output directory: {config['experiment']['output_dir']}\n") |
| f.write(f"Task type: {config['task']['task_type']}\n") |
| f.write("="*80 + "\n\n") |
| else: |
| checkpoint_dir = None |
| log_file = None |
|
|
| if is_distributed: |
| dist.barrier() |
|
|
| model = create_model(config) |
| model = model.cuda(gpu) |
|
|
| if rank == 0: |
| print("\nAnalyzing model architecture...") |
| count_parameters(model, verbose=True) |
|
|
| if is_distributed: |
| model = DDP(model, device_ids=[gpu], find_unused_parameters=True) |
|
|
| model_without_ddp = model.module if is_distributed else model |
|
|
| if rank == 0: |
| print("Creating dataloaders...") |
| train_loader, val_loader, test_loader, train_sampler = create_dataloaders( |
| config, is_distributed, rank, world_size |
| ) |
|
|
| label_scaler = None |
| if config['task']['task_type'] == 'regression': |
| # Normalization statistics come from the config, so every rank can read them directly |
| mean_val = config['task']['mean'] |
| scale_val = config['task']['std'] |
| if rank == 0: |
| print(f"Using label normalization from config. Mean: {mean_val:.4f}, Std: {scale_val:.4f}") |
|
|
| norm_mean = torch.tensor(mean_val, device=gpu, dtype=torch.float32) |
| norm_std = torch.tensor(scale_val, device=gpu, dtype=torch.float32) |
|
|
| if is_distributed: |
| dist.broadcast(norm_mean, src=0) |
| dist.broadcast(norm_std, src=0) |
|
|
| label_scaler = LabelScaler(norm_mean, norm_std) |
|
|
| if rank == 0: |
| print(f"Training samples: {len(train_loader.dataset)}") |
| print(f"Validation samples: {len(val_loader.dataset)}") |
| print(f"Test samples: {len(test_loader.dataset)}") |
| print(f"Batches per epoch: {len(train_loader)}") |
|
|
| |
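| # Task-specific loss: cross-entropy for classification, MSE on normalized targets for regression |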
| task_config = config['task'] |
| if task_config['task_type'] == 'classification': |
| criterion = nn.CrossEntropyLoss(label_smoothing=0.0) |
| else: |
| criterion = nn.MSELoss() |
|
|
| |
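| # Optional linear-probe mode: freeze everything except the classification head |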
| if config['training'].get('freeze_encoder', False): |
| if rank == 0: |
| print("Freezing encoder weights. Only the head will be trained.") |
| for name, param in model_without_ddp.named_parameters(): |
| if 'head' not in name: |
| param.requires_grad = False |
|
|
| |
| if rank == 0: |
| print("Trainable parameters:") |
| for name, param in model_without_ddp.named_parameters(): |
| if param.requires_grad: |
| print(name) |
|
|
| |
| optimizer = create_optimizer(model_without_ddp, config) |
| scheduler = create_lr_scheduler(optimizer, config, len(train_loader)) |
|
|
| |
| scaler = GradScaler() if config['training']['use_amp'] else None |
|
|
| |
| start_epoch = 0 |
| best_metric = 0.0 |
| best_loss = float('inf') |
|
|
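| # Resume training state if requested; otherwise initialize best-metric tracking for the task type |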
| if config['experiment'].get('resume', None) is not None: |
| start_epoch, best_metric, best_loss = load_checkpoint( |
| config['experiment']['resume'], |
| model_without_ddp, |
| optimizer, |
| scheduler, |
| scaler |
| ) |
| print(f"Resumed from epoch {start_epoch}. Best metric: {best_metric:.4f}, Best loss: {best_loss:.4f}") |
| else: |
| |
| if config['task']['task_type'] == 'classification': |
| best_metric = 0.0 |
| else: |
| best_metric = float('inf') |
|
|
| |
| if rank == 0: |
| print("Starting fine-tuning...") |
| print(f"Training from epoch {start_epoch} to {config['training']['epochs']}") |
|
|
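| # Main fine-tuning loop: train, periodically validate/test, and checkpoint |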
| for epoch in range(start_epoch, config['training']['epochs']): |
| if is_distributed and train_sampler is not None: |
| train_sampler.set_epoch(epoch) |
|
|
| |
| train_stats = train_one_epoch( |
| model, train_loader, criterion, optimizer, scheduler, scaler, |
| epoch, config, rank, world_size, label_scaler, log_file |
| ) |
|
|
| |
| if rank == 0: |
| log_msg = f"Epoch {epoch} Training - " |
| log_msg += " | ".join([f"{k}: {v:.4f}" for k, v in train_stats.items()]) |
| print(log_msg) |
| log_to_file(log_file, log_msg) |
|
|
| |
| if epoch % config['validation']['val_freq'] == 0 or epoch == config['training']['epochs'] - 1: |
| print(f"DEBUG: label_scaler type is {type(label_scaler)}, value is {label_scaler}") |
| val_stats = evaluate( |
| model, val_loader, criterion, config, rank, epoch, label_scaler, 'val' |
| ) |
| test_stats = evaluate(model, test_loader, criterion, config, rank, epoch, label_scaler, 'test') |
|
|
| |
| if rank == 0: |
| log_msg = f"Epoch {epoch} Validation - " |
| log_msg += " | ".join([f"{k}: {v:.4f}" for k, v in val_stats.items()]) |
| print(log_msg) |
| log_to_file(log_file, log_msg) |
|
|
| log_msg = f"Epoch {epoch} Test - " |
| log_msg += " | ".join([f"{k}: {v:.4f}" for k, v in test_stats.items()]) |
| print(log_msg) |
| log_to_file(log_file, log_msg) |
|
|
| |
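| # Track the best model: highest validation accuracy for classification, lowest validation loss for regression |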
| if rank == 0: |
| if task_config['task_type'] == 'classification': |
| |
| current_metric = val_stats.get('acc', 0.0) |
| is_best = current_metric > best_metric |
| if is_best: |
| best_metric = current_metric |
| best_loss = val_stats['loss'] |
| else: |
| |
| is_best = val_stats['loss'] < best_loss |
| if is_best: |
| best_loss = val_stats['loss'] |
| best_metric = -best_loss |
|
|
| checkpoint_state = { |
| 'epoch': epoch + 1, |
| 'model_state_dict': model_without_ddp.state_dict(), |
| 'optimizer_state_dict': optimizer.state_dict(), |
| 'scheduler_state_dict': scheduler.state_dict(), |
| 'best_metric': best_metric, |
| 'best_loss': best_loss, |
| 'config': config, |
| 'train_stats': train_stats, |
| 'val_stats': val_stats, |
| } |
|
|
| if scaler is not None: |
| checkpoint_state['scaler_state_dict'] = scaler.state_dict() |
|
|
| save_checkpoint( |
| checkpoint_state, |
| is_best, |
| checkpoint_dir, |
| filename=f'checkpoint_epoch_{epoch}.pth' |
| ) |
|
|
| checkpoint_msg = f"Checkpoint saved at epoch {epoch}" |
| print(checkpoint_msg) |
| log_to_file(log_file, checkpoint_msg) |
|
|
| if is_best: |
| if task_config['task_type'] == 'classification': |
| best_msg = f"New best validation accuracy: {best_metric:.4f}" |
| else: |
| best_msg = f"New best validation loss: {best_loss:.4f}" |
| print(best_msg) |
| log_to_file(log_file, best_msg) |
|
|
| |
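| # Periodic checkpointing on a fixed schedule, independent of the validation frequency |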
| if rank == 0 and (epoch + 1) % config['logging']['save_freq'] == 0: |
| checkpoint_state = { |
| 'epoch': epoch + 1, |
| 'model_state_dict': model_without_ddp.state_dict(), |
| 'optimizer_state_dict': optimizer.state_dict(), |
| 'scheduler_state_dict': scheduler.state_dict(), |
| 'best_metric': best_metric, |
| 'best_loss': best_loss, |
| 'config': config, |
| } |
|
|
| if scaler is not None: |
| checkpoint_state['scaler_state_dict'] = scaler.state_dict() |
|
|
| save_checkpoint( |
| checkpoint_state, |
| False, |
| checkpoint_dir, |
| filename=f'checkpoint_epoch_{epoch}.pth' |
| ) |
|
|
|
|
| |
| cleanup_distributed() |
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|