File size: 612 Bytes
785614a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | """Staged freeze/thaw training helpers."""
import torch.nn as nn
def freeze(m):
    """Turn off gradient tracking for every parameter yielded by ``m.parameters()``.

    Mutates the parameters in place and returns ``None``.
    """
    params = m.parameters()
    for tensor in params:
        tensor.requires_grad = False
def unfreeze(m):
    """Turn on gradient tracking for every parameter yielded by ``m.parameters()``.

    Mutates the parameters in place and returns ``None``.
    """
    params = m.parameters()
    for tensor in params:
        tensor.requires_grad = True
def count_trainable(m):
    """Return the total element count of parameters of ``m`` that require grad.

    Parameters with ``requires_grad == False`` are skipped; a module with no
    trainable parameters yields ``0``.
    """
    total = 0
    for tensor in m.parameters():
        if not tensor.requires_grad:
            continue
        total += tensor.numel()
    return total
def setup_stage1(model):
    """Prepare ``model`` for stage 1 by making all of its parameters trainable.

    The model is modified in place; the same object is handed back so the
    call can be chained.
    """
    unfreeze(model)
    prepared = model
    return prepared
def setup_stage2(model, *, trainable_keys=("xattn", "tproj")):
    """Freeze ``model`` except for parameters selected by name.

    A parameter stays trainable when any entry of ``trainable_keys`` occurs
    as a substring of its qualified name from ``model.named_parameters()``.
    Every other parameter gets ``requires_grad = False``. The default keys
    reproduce the original hard-coded selection.

    Args:
        model: module whose parameters are updated in place.
        trainable_keys: substring(s) that select parameters to keep trainable.
            A single string is treated as one key, not as a sequence of
            characters. An empty collection freezes every parameter.

    Returns:
        The same ``model`` object.
    """
    if isinstance(trainable_keys, str):
        # Without this, `any(k in name for k in "xattn")` would test single
        # characters and unfreeze almost everything.
        trainable_keys = (trainable_keys,)
    # Materialize once so a one-shot iterable is not exhausted after the
    # first parameter.
    keys = tuple(trainable_keys)
    for name, param in model.named_parameters():
        # Single pass: equivalent to freezing everything and then re-enabling
        # the matches, since each parameter is visited exactly once.
        param.requires_grad = any(key in name for key in keys)
    return model
def setup_stage3(model):
    """Prepare ``model`` for stage 3 by making all of its parameters trainable.

    The model is modified in place; the same object is handed back so the
    call can be chained.
    """
    unfreeze(model)
    prepared = model
    return prepared
|