| import os |
| import gc |
| import sys |
| import torch |
| import wandb |
| import torch.nn as nn |
| import lightning.pytorch as pl |
|
|
| from omegaconf import OmegaConf |
| from lightning.pytorch.strategies import DDPStrategy |
| from lightning.pytorch.loggers import WandbLogger |
| from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor |
|
|
| from src.lm.memdlm.diffusion_module import MembraneDiffusion |
| from src.lm.memdlm.dataloader import MembraneDataModule, get_datasets |
| from src.utils.model_utils import apply_rdm_freezing |
|
|
| wandb.login(key='2b76a2fa2c1cdfddc5f443602c17b011fefb0a8f') |
|
|
|
|
| |
| config = OmegaConf.load("/scratch/pranamlab/sgoel/MeMDLM_v2/src/configs/lm.yaml") |
|
|
| |
# Build the train/val/test splits and hand them to the Lightning data module.
datasets = get_datasets(config)
data_module = MembraneDataModule(
    config=config,
    **{f"{split}_dataset": datasets[split] for split in ("train", "val", "test")},
)
|
|
| |
# Start the W&B run explicitly (project/name from the config), then create the
# Lightning logger from the same config section so both agree on run identity.
# NOTE(review): WandbLogger(**config.wandb) forwards *every* key of the
# `wandb` config section as a kwarg — confirm all keys there are valid
# WandbLogger parameters.
wandb.init(project=config.wandb.project, name=config.wandb.name)
wandb_logger = WandbLogger(**config.wandb)
|
|
| |
# Callbacks: log the learning rate every optimizer step, and keep only the
# single best checkpoint (lowest validation loss), checked every N train steps.
lr_monitor = LearningRateMonitor(logging_interval="step")
checkpoint_kwargs = {
    "monitor": "val/loss",
    "mode": "min",
    "save_top_k": 1,
    "dirpath": config.checkpointing.save_dir,
    "filename": "best_model",
    "every_n_train_steps": config.checkpointing.save_every_n_steps,
}
checkpoint_callback = ModelCheckpoint(**checkpoint_kwargs)
|
|
| |
# Trainer: step-based run (max_epochs=None so max_steps governs termination),
# DDP across the configured devices when training, a single device otherwise.
cuda_available = torch.cuda.is_available()
training_run = config.training.mode == "train"
trainer = pl.Trainer(
    max_steps=config.training.max_steps,
    max_epochs=None,
    accelerator="cuda" if cuda_available else "cpu",
    devices=config.training.devices if training_run else [0],
    strategy=DDPStrategy(find_unused_parameters=True),
    callbacks=[checkpoint_callback, lr_monitor],
    logger=wandb_logger,
    log_every_n_steps=config.training.log_every_n_steps,
)
|
|
|
|
| |
# Ensure the checkpoint directory exists. `exist_ok=True` replaces the old
# `exist_ok=False` + `except FileExistsError: pass` pattern: same outcome for
# an existing directory, but it no longer silently swallows the error when the
# path exists as a regular *file* blocking directory creation.
ckpt_path = config.checkpointing.save_dir
os.makedirs(ckpt_path, exist_ok=True)
|
|
| |
# Instantiate the diffusion LightningModule and fail fast on a bad config
# before any training/eval work starts.
diffusion = MembraneDiffusion(config)
diffusion.validate_config()
|
|
| |
# Dispatch on the configured run mode. `model_type` selects the
# backbone-specific freezing scheme inside apply_rdm_freezing.
model_type = "evoflow"

if config.training.mode == "train":
    # Freeze all but the last `n_layers` of the backbone, then fine-tune.
    apply_rdm_freezing(diffusion.model, config.training.n_layers, model_type)
    trainer.fit(diffusion, datamodule=data_module)

elif config.training.mode == "test":
    # Load the best checkpoint's weights, then run evaluation from it.
    state_dict = diffusion.get_state_dict(config.checkpointing.best_ckpt_path)
    diffusion.load_state_dict(state_dict)
    trainer.test(diffusion, datamodule=data_module, ckpt_path=config.checkpointing.best_ckpt_path)

elif config.training.mode == "resume_from_checkpoint":
    state_dict = diffusion.get_state_dict(config.training.resume_ckpt_path)
    diffusion.load_state_dict(state_dict)
    apply_rdm_freezing(diffusion.model, config.training.n_layers, model_type)
    # BUG FIX: previously this passed `ckpt_path` — the checkpoint *directory*
    # (config.checkpointing.save_dir) — to trainer.fit, but Lightning expects a
    # checkpoint file path. Resume from the same checkpoint the weights were
    # loaded from so optimizer/scheduler/step state is restored consistently.
    trainer.fit(diffusion, datamodule=data_module, ckpt_path=config.training.resume_ckpt_path)

else:
    # Previously an unknown mode fell through silently and the script exited
    # without doing anything; fail loudly instead.
    raise ValueError(f"Unknown training mode: {config.training.mode!r}")

wandb.finish()
|
|
|
|