"""Staged freeze/thaw training helpers.""" import torch.nn as nn def freeze(m): for p in m.parameters(): p.requires_grad = False def unfreeze(m): for p in m.parameters(): p.requires_grad = True def count_trainable(m): return sum(p.numel() for p in m.parameters() if p.requires_grad) def setup_stage1(model): unfreeze(model) return model def setup_stage2(model): freeze(model) for name, p in model.named_parameters(): if "xattn" in name or "tproj" in name: p.requires_grad = True return model def setup_stage3(model): unfreeze(model) return model