| import os |
| import pytorch_lightning as pl |
| from torch.utils.data import DataLoader |
| from my_dataset import MyDataset |
| from cldm.logger import ImageLogger |
| from pytorch_lightning.callbacks import ModelCheckpoint |
| from cldm.model import create_model, load_state_dict |
| from cldm.hack import disable_verbosity, enable_sliced_attention |
| from utils.config import * |
|
|
| |
|
|
if __name__ == '__main__':
    # Restrict this process to physical GPU 1. CUDA renumbers the visible
    # devices starting from 0, so after this assignment that card is the
    # only visible device and is addressed as index 0.
    # NOTE(review): the variable only takes effect if CUDA has not been
    # initialised before this line (e.g. by an import side effect) — confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'

    # Set to True to enable sliced attention (lower memory use).
    save_memory = False

    disable_verbosity()
    if save_memory:
        enable_sliced_attention()

    # --- Training configuration ---------------------------------------------
    # `model_root` is provided by the star import from utils.config.
    # os.path.join works whether or not model_root ends with a separator.
    resume_path = os.path.join(model_root, "control_dresscode_ini.ckpt")
    batch_size = 4
    logger_freq = 4600
    learning_rate = 1.0e-05
    # The two flags below are plain attributes read by the cldm model
    # (their exact effect is defined outside this file).
    sd_locked = True
    only_mid_control = False

    # --- Model ----------------------------------------------------------------
    # Build the model on CPU and load the initial checkpoint with
    # location='cpu'; device placement is left to the Trainer.
    model = create_model('configs/cldm_v2.yaml').cpu()
    model.load_state_dict(load_state_dict(resume_path, location='cpu'))
    model.learning_rate = learning_rate
    model.sd_locked = sd_locked
    model.only_mid_control = only_mid_control

    # --- Data -----------------------------------------------------------------
    dataset = MyDataset()
    print("******************************************************")
    print(len(dataset))
    print("******************************************************")
    dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)

    # --- Callbacks ------------------------------------------------------------
    # presumably logs sample images every `logger_freq` batches — see
    # cldm.logger.ImageLogger to confirm.
    logger = ImageLogger(batch_frequency=logger_freq)
    # No metric is monitored: keep every checkpoint (save_top_k=-1) written
    # every 50k training steps, plus a rolling `last.ckpt` (save_last=True).
    checkpoint_callback = ModelCheckpoint(
        monitor=None,
        dirpath='./hiera_logs',
        filename='model_{epoch:02d}-{step:06d}',
        save_top_k=-1,
        save_last=True,
        save_weights_only=False,
        every_n_train_steps=50000,
    )
    callbacks = [logger, checkpoint_callback]

    # --- Train ----------------------------------------------------------------
    # Exactly one device is visible after the CUDA_VISIBLE_DEVICES assignment
    # above, so request one GPU. Asking for device index 1 would fail because
    # the remapped numbering only contains index 0.
    trainer = pl.Trainer(gpus=1, precision=32, callbacks=callbacks, max_epochs=100)
    trainer.fit(model, dataloader)
|
|