File size: 273 Bytes
cf52a55
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
from pretrain_custom_lm import TrainLMConfig, train_lm, save_lm

# Short pretraining run: start from the default TrainLMConfig, override a
# handful of fields, train once, save the resulting model, and report.
# NOTE(review): this runs at import time (no `if __name__ == "__main__":`
# guard), so importing this module triggers a full training run.

# Field overrides applied on top of TrainLMConfig() defaults, in this order.
_CONFIG_OVERRIDES = {
    'num_epochs': 1,
    'batch_size': 8,
    'max_seq_len': 128,
    'save_dir': 'checkpoints_lm',
}

cfg = TrainLMConfig()
for _field, _value in _CONFIG_OVERRIDES.items():
    setattr(cfg, _field, _value)

model = train_lm(cfg)

# NOTE(review): this path is handed to save_lm as-is and is not joined with
# cfg.save_dir; confirm whether the final model is meant to live there.
save_lm(model, 'epistemic_lm.pt')

print('Done pretraining')