"""Run a short language-model pretraining job and save the resulting model.

Builds a ``TrainLMConfig``, overrides a handful of its fields, passes it to
``train_lm``, and hands the returned model to ``save_lm``.
"""

from pretrain_custom_lm import TrainLMConfig, train_lm, save_lm

# Start from the config's defaults and override only the fields set below.
cfg = TrainLMConfig()
cfg.num_epochs = 1
cfg.batch_size = 8
cfg.max_seq_len = 128
# NOTE(review): presumably the directory train_lm writes checkpoints to --
# confirm against pretrain_custom_lm.
cfg.save_dir = 'checkpoints_lm'

# train_lm returns the object that is later passed to save_lm.
model = train_lm(cfg)

# NOTE(review): presumably serializes the model to this path (resolved
# relative to the current working directory) -- confirm against save_lm.
save_lm(model, 'epistemic_lm.pt')

print('Done pretraining')