# File size: 273 Bytes  (extraction metadata, kept as a comment)
"""Driver script: configure, run, and save a pretraining pass of the custom LM."""
from pretrain_custom_lm import TrainLMConfig, train_lm, save_lm


def main() -> None:
    """Train the LM for one epoch with a small configuration and save the result."""
    cfg = TrainLMConfig()
    cfg.num_epochs = 1    # single pass — presumably a quick/smoke run; bump for real training
    cfg.batch_size = 8
    cfg.max_seq_len = 128
    cfg.save_dir = 'checkpoints_lm'    # NOTE(review): train_lm likely also checkpoints here — confirm
    model = train_lm(cfg)
    # Persist the final model separately from any intermediate checkpoints.
    save_lm(model, 'epistemic_lm.pt')
    print('Done pretraining')


# Guard so importing this module does not kick off training as a side effect;
# running the file directly behaves exactly as before.
if __name__ == '__main__':
    main()