| """Staged freeze/thaw training helpers.""" |
| import torch.nn as nn |
|
|
def freeze(m):
    """Turn off gradient tracking for every parameter of ``m``.

    Each tensor yielded by ``m.parameters()`` is updated in place; nothing is
    returned.
    """
    params = m.parameters()
    for param in params:
        param.requires_grad_(False)
|
|
def unfreeze(m):
    """Turn on gradient tracking for every parameter of ``m``.

    Each tensor yielded by ``m.parameters()`` is updated in place; nothing is
    returned.
    """
    params = m.parameters()
    for param in params:
        param.requires_grad_(True)
|
|
def count_trainable(m):
    """Return the total element count of the parameters of ``m`` that require grad.

    Parameters with ``requires_grad == False`` are skipped; a module with no
    trainable parameters yields 0.
    """
    total = 0
    for param in m.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
|
|
def setup_stage1(model):
    """Stage 1: make every parameter of ``model`` trainable.

    Parameters are updated in place; ``model`` itself is returned so the call
    can be chained.
    """
    for tensor in model.parameters():
        tensor.requires_grad = True
    return model
|
|
def setup_stage2(model, *, patterns=("xattn", "tproj")):
    """Stage 2: train only the parameters whose name contains one of ``patterns``.

    Every parameter yielded by ``model.named_parameters()`` is frozen unless its
    dotted name contains one of ``patterns`` as a substring, in which case it is
    made trainable. Parameters are updated in place.

    Matching is a plain substring test, so e.g. ``"xattn"`` also matches a name
    such as ``"enc.xattn_norm.weight"``. If nothing matches, every parameter
    ends up frozen.

    Args:
        model: object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
        patterns: substrings selecting the parameters to keep trainable.
            Defaults to ``("xattn", "tproj")``. A single ``str`` is treated as
            a one-element collection.

    Returns:
        ``model``, so the call can be chained.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(patterns, str):
        patterns = (patterns,)
    # Materialize once so a generator is not exhausted after the first parameter.
    patterns = tuple(patterns)
    # A single pass is equivalent to "freeze all, then re-enable matches":
    # named_parameters() yields the same de-duplicated tensors as parameters().
    for name, p in model.named_parameters():
        p.requires_grad = any(pat in name for pat in patterns)
    return model
|
|
def setup_stage3(model):
    """Stage 3: make every parameter of ``model`` trainable again.

    Parameters are updated in place and ``model`` is returned for chaining.

    NOTE(review): this does exactly what ``setup_stage1`` does (both just
    unfreeze everything) — confirm the two stages are meant to be identical.
    """
    trainable = list(model.parameters())
    for weight in trainable:
        weight.requires_grad_(True)
    return model
|
|