from config import *
from dataset import *
from dist_utils import *  # assumed to provide get_world_size()

import copy
import json
import os

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch import nn
from torch.nn import functional as F


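# Iterates several dataloaders in lockstep, yielding a tuple with one batch
# from each; used below for the few-shot positive/negative dataset pair.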
class ZipLoader:
    def __init__(self, loaders):
        self.loaders = loaders

    def __len__(self):
        return len(self.loaders[0])

    def __iter__(self):
        for each in zip(*self.loaders):
            yield each


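# Trains a lightweight attribute classifier on top of a frozen, pretrained
# diffusion autoencoder: images are encoded with the EMA encoder and a linear
# head predicts the attribute labels, which is later used for latent-space
# manipulation.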
class ClsModel(pl.LightningModule):
    def __init__(self, conf: TrainConfig):
        super().__init__()
        assert conf.train_mode.is_manipulate()
        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        self.save_hyperparameters(conf.as_dict_jsonable())
        self.conf = conf

        if conf.train_mode == TrainMode.manipulate:
            # manipulating latents requires the base autoencoder; it is kept
            # frozen and only used to infer the latent codes
            self.model = conf.make_model_conf().make_model()
            self.ema_model = copy.deepcopy(self.model)
            self.model.requires_grad_(False)
            self.ema_model.requires_grad_(False)
            self.ema_model.eval()

            if conf.pretrain is not None:
                print(f'loading pretrain ... {conf.pretrain.name}')
                state = torch.load(conf.pretrain.path, map_location='cpu')
                print('step:', state['global_step'])
                self.load_state_dict(state['state_dict'], strict=False)

        if conf.manipulate_znormalize:
            print('loading latent stats ...')
            state = torch.load(conf.latent_infer_path)
            self.conds = state['conds']
            self.register_buffer('conds_mean', state['conds_mean'][None, :])
            self.register_buffer('conds_std', state['conds_std'][None, :])
        else:
            self.conds_mean = None
            self.conds_std = None

        if conf.manipulate_mode in [ManipulateMode.celebahq_all]:
            num_cls = len(CelebAttrDataset.id_to_cls)
        elif conf.manipulate_mode.is_single_class():
            num_cls = 1
        else:
            raise NotImplementedError()

        if conf.train_mode == TrainMode.manipulate:
            # manipulating the latent only needs a linear classifier
            self.classifier = nn.Linear(conf.style_ch, num_cls)
        else:
            raise NotImplementedError()

        self.ema_classifier = copy.deepcopy(self.classifier)

    def state_dict(self, *args, **kwargs):
        # exclude the frozen base model (and its EMA copy) from checkpoints;
        # only the classifier weights and registered buffers are saved
        out = {}
        for k, v in super().state_dict(*args, **kwargs).items():
            if k.startswith('model.'):
                pass
            elif k.startswith('ema_model.'):
                pass
            else:
                out[k] = v
        return out

    def load_state_dict(self, state_dict, strict: bool = None):
        if self.conf.train_mode == TrainMode.manipulate:
            # checkpoints do not contain the excluded base model weights
            if strict is None:
                strict = False
        else:
            if strict is None:
                strict = True
        return super().load_state_dict(state_dict, strict=strict)

    def normalize(self, cond):
        cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
            self.device)
        return cond

    def denormalize(self, cond):
        cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
            self.device)
        return cond

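    # In the few-shot D2C "all-negative" setting this returns a pair of
    # datasets (positive-only and negative-only) that are later zipped
    # together; the other modes return a single dataset.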
    def load_dataset(self):
        if self.conf.manipulate_mode == ManipulateMode.d2c_fewshot:
            return CelebD2CAttrFewshotDataset(
                cls_name=self.conf.manipulate_cls,
                K=self.conf.manipulate_shots,
                img_folder=data_paths['celeba'],
                img_size=self.conf.img_size,
                seed=self.conf.manipulate_seed,
                all_neg=False,
                do_augment=True,
            )
        elif self.conf.manipulate_mode == ManipulateMode.d2c_fewshot_allneg:
            img_folder = data_paths['celeba']

            return [
                CelebD2CAttrFewshotDataset(
                    cls_name=self.conf.manipulate_cls,
                    K=self.conf.manipulate_shots,
                    img_folder=img_folder,
                    img_size=self.conf.img_size,
                    only_cls_name=self.conf.manipulate_cls,
                    only_cls_value=1,
                    seed=self.conf.manipulate_seed,
                    all_neg=True,
                    do_augment=True),
                CelebD2CAttrFewshotDataset(
                    cls_name=self.conf.manipulate_cls,
                    K=self.conf.manipulate_shots,
                    img_folder=img_folder,
                    img_size=self.conf.img_size,
                    only_cls_name=self.conf.manipulate_cls,
                    only_cls_value=-1,
                    seed=self.conf.manipulate_seed,
                    all_neg=True,
                    do_augment=True),
            ]
        elif self.conf.manipulate_mode == ManipulateMode.celebahq_all:
            return CelebHQAttrDataset(data_paths['celebahq'],
                                      self.conf.img_size,
                                      data_paths['celebahq_anno'],
                                      do_augment=True)
        else:
            raise NotImplementedError()

    def setup(self, stage=None) -> None:
        # give each process (rank) its own seed so the data order differs
        # across workers while staying reproducible
        if self.conf.seed is not None:
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            print('local seed:', seed)

        self.train_data = self.load_dataset()
        if self.conf.manipulate_mode.is_fewshot():
            # the few-shot datasets are tiny; repeat them so an "epoch" is
            # long enough to train on
            if isinstance(self.train_data, list):
                # repeat the positive and negative splits to the same length
                a, b = self.train_data
                self.train_data = [
                    Repeat(a, max(len(a), len(b))),
                    Repeat(b, max(len(a), len(b))),
                ]
            else:
                self.train_data = Repeat(self.train_data, 100_000)

    def train_dataloader(self):
        # use the per-process (local) batch size; conf.batch_size is global
        conf = self.conf.clone()
        conf.batch_size = self.batch_size
        if isinstance(self.train_data, list):
            dataloader = []
            for each in self.train_data:
                dataloader.append(
                    conf.make_loader(each, shuffle=True, drop_last=True))
            dataloader = ZipLoader(dataloader)
        else:
            dataloader = conf.make_loader(self.train_data,
                                          shuffle=True,
                                          drop_last=True)
        return dataloader

    @property
    def batch_size(self):
        # local (per-process) batch size
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws

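    # One training step: infer the latent code with the frozen EMA encoder,
    # classify it, and compute the loss against the binarized attribute
    # labels; the EMA classifier's loss is computed only for logging.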
    def training_step(self, batch, batch_idx):
        self.ema_model: BeatGANsAutoencModel
        if isinstance(batch, tuple):
            # batch from ZipLoader: a (positive, negative) pair of batches
            a, b = batch
            imgs = torch.cat([a['img'], b['img']])
            labels = torch.cat([a['labels'], b['labels']])
        else:
            imgs = batch['img']
            labels = batch['labels']

        if self.conf.train_mode == TrainMode.manipulate:
            self.ema_model.eval()
            with torch.no_grad():
                # infer the latent code with the frozen EMA encoder
                cond = self.ema_model.encoder(imgs)

            if self.conf.manipulate_znormalize:
                cond = self.normalize(cond)

            # classify the latent with both the live and the EMA classifier
            pred = self.classifier.forward(cond)
            pred_ema = self.ema_classifier.forward(cond)
        elif self.conf.train_mode == TrainMode.manipulate_img:
            # classify the image directly
            pred = self.classifier.forward(imgs)
            pred_ema = None
        elif self.conf.train_mode == TrainMode.manipulate_imgt:
            # classify a noised image at a sampled diffusion timestep
            t, weight = self.T_sampler.sample(len(imgs), imgs.device)
            imgs_t = self.sampler.q_sample(imgs, t)
            pred = self.classifier.forward(imgs_t, t=t)
            pred_ema = None
            print('pred:', pred.shape)
        else:
            raise NotImplementedError()

        if self.conf.manipulate_mode.is_celeba_attr():
            # map positive attribute labels to 1 and the rest to 0
            gt = torch.where(labels > 0,
                             torch.ones_like(labels).float(),
                             torch.zeros_like(labels).float())
        elif self.conf.manipulate_mode == ManipulateMode.relighting:
            gt = labels
        else:
            raise NotImplementedError()

        loss_ema = None
        if self.conf.manipulate_loss == ManipulateLossType.bce:
            loss = F.binary_cross_entropy_with_logits(pred, gt)
            if pred_ema is not None:
                loss_ema = F.binary_cross_entropy_with_logits(pred_ema, gt)
        elif self.conf.manipulate_loss == ManipulateLossType.mse:
            loss = F.mse_loss(pred, gt)
            if pred_ema is not None:
                loss_ema = F.mse_loss(pred_ema, gt)
        else:
            raise NotImplementedError()

        self.log('loss', loss)
        if loss_ema is not None:
            # only available in latent (manipulate) mode
            self.log('loss_ema', loss_ema)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        # update the EMA copy of the classifier after every training batch
        ema(self.classifier, self.ema_classifier, self.conf.ema_decay)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.classifier.parameters(),
                                 lr=self.conf.lr,
                                 weight_decay=self.conf.weight_decay)
        return optim


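# Standard exponential moving average update:
#   target = decay * target + (1 - decay) * source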
def ema(source, target, decay):
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key in source_dict.keys():
        target_dict[key].data.copy_(target_dict[key].data * decay +
                                    source_dict[key].data * (1 - decay))


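# Builds the PyTorch Lightning trainer and fits the classifier: resumes from
# `last.ckpt` in the log dir if it exists (otherwise from `conf.continue_from`),
# and switches to DDP when more than one GPU is given.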
def train_cls(conf: TrainConfig, gpus):
    print('conf:', conf.name)
    model = ClsModel(conf)

    if not os.path.exists(conf.logdir):
        os.makedirs(conf.logdir)
    checkpoint = ModelCheckpoint(
        dirpath=f'{conf.logdir}',
        save_last=True,
        save_top_k=1,
    )
    checkpoint_path = f'{conf.logdir}/last.ckpt'
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
    else:
        if conf.continue_from is not None:
            # no local checkpoint yet; continue from the given checkpoint
            resume = conf.continue_from.path
        else:
            resume = None

    tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
                                             name=None,
                                             version='')

    plugins = []
    if len(gpus) == 1:
        accelerator = None
    else:
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin

        plugins.append(DDPPlugin(find_unused_parameters=False))

    trainer = pl.Trainer(
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
        ],
        replace_sampler_ddp=True,
        logger=tb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )
    trainer.fit(model)

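
# Example usage (sketch): the config factory below is a placeholder for
# whatever classifier TrainConfig your project defines (e.g. in a
# templates_cls.py module):
#
#   from templates_cls import ffhq256_autoenc_cls  # hypothetical factory
#
#   if __name__ == '__main__':
#       conf = ffhq256_autoenc_cls()
#       train_cls(conf, gpus=[0])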