File size: 612 Bytes
785614a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""Staged freeze/thaw training helpers."""
import torch.nn as nn

def freeze(m):
    """Turn off gradient tracking for every parameter of *m*, in place.

    *m* only needs a ``parameters()`` method yielding tensors (e.g. an
    ``nn.Module``). Nothing is returned.
    """
    for tensor in m.parameters():
        tensor.requires_grad = False

def unfreeze(m):
    """Turn on gradient tracking for every parameter of *m*, in place.

    *m* only needs a ``parameters()`` method yielding tensors (e.g. an
    ``nn.Module``). Nothing is returned.
    """
    for tensor in m.parameters():
        tensor.requires_grad = True

def count_trainable(m):
    """Return the total element count of the parameters of *m* that require grad.

    Parameters with ``requires_grad == False`` are skipped; an object with
    no trainable parameters yields ``0``.
    """
    total = 0
    for tensor in m.parameters():
        if not tensor.requires_grad:
            continue
        total += tensor.numel()
    return total

def setup_stage1(model):
    """Configure *model* for stage 1, in which every parameter is trainable.

    Delegates to ``unfreeze`` (mutating *model* in place) and hands back the
    same model object so calls can be chained.
    """
    unfreeze(model)  # stage 1: nothing is held fixed
    return model

def setup_stage2(model, *, trainable_keys=("xattn", "tproj")):
    """Configure *model* for stage 2: train only parameters whose names match.

    Every parameter is frozen except those whose qualified name, as yielded
    by ``model.named_parameters()``, contains at least one of
    *trainable_keys* as a substring. The model is modified in place and
    returned so calls can be chained.

    Args:
        model: The module to configure.
        trainable_keys: Substrings selecting the parameters that stay
            trainable. Defaults to ``("xattn", "tproj")``. A single string
            is treated as one key rather than as a sequence of characters.

    Returns:
        The same ``model`` object.
    """
    if isinstance(trainable_keys, str):
        trainable_keys = (trainable_keys,)
    # Materialize once so a generator argument is not exhausted after the
    # first parameter.
    keys = tuple(trainable_keys)
    # A single pass sets each flag directly. This is equivalent to freezing
    # everything and then re-enabling matches, because parameters() and
    # named_parameters() visit the same (deduplicated) set of tensors.
    # NOTE(review): matching is plain substring containment, so a key can
    # also hit unrelated names (e.g. "tproj" is contained in "outproj").
    # If nothing matches, the whole model ends up frozen; confirm the keys
    # against the model's parameter names.
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in keys)
    return model

def setup_stage3(model):
    """Configure *model* for stage 3, thawing every parameter again.

    Delegates to ``unfreeze`` (mutating *model* in place) and hands back the
    same model object so calls can be chained.
    """
    unfreeze(model)  # stage 3: all parameters become trainable
    return model