| import os |
| import tqdm |
| import utils |
| import hydra |
| import torch |
| import einops |
| import datasets |
| import numpy as np |
| import torch.distributed |
| from pathlib import Path |
| from datetime import timedelta |
| from omegaconf import OmegaConf |
| from accelerate import Accelerator |
| from collections import OrderedDict |
| from workspaces.base import Workspace |
| from torch.utils.data import DataLoader |
| from accelerate.logging import get_logger |
| from accelerate import InitProcessGroupKwargs, DistributedDataParallelKwargs |
|
|
# Select wandb's "thread" start method for this process (the same setting is
# also passed explicitly to wandb when the tracker is initialised).
os.environ.update(WANDB_START_METHOD="thread")

# Accelerate-aware module logger.
logger = get_logger(__name__)
|
|
|
|
class Trainer:
    """Self-supervised pretraining driver for an encoder/projector pair.

    Builds the dataset, models and optimizers from the Hydra config, runs the
    epoch loop under an ``accelerate.Accelerator`` (logging to wandb),
    evaluates on the held-out split every epoch and periodically writes
    snapshots to the working directory.
    """

    # Attributes written to ``snapshot.pt`` by ``save_snapshot`` and restored
    # by ``load_snapshot``. Defined on the class (rather than assigned inside
    # ``save_snapshot``) so that resuming works before this process has saved
    # anything.
    _keys_to_save = (
        "encoder",
        "projector",
        "encoder_optim",
        "projector_optim",
        "ssl",
        "epoch",
    )

    def __init__(self, cfg):
        """Set up the accelerator, working directory, data, models and workspace.

        Args:
            cfg: Hydra/OmegaConf config for the run.
        """
        process_group_kwargs = InitProcessGroupKwargs(
            timeout=timedelta(seconds=cfg.timeout_seconds)
        )
        dist_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.cfg = cfg
        self.effective_batch_size = self.cfg.batch_size
        self.accelerator = Accelerator(
            log_with="wandb", kwargs_handlers=[process_group_kwargs, dist_kwargs]
        )
        logger.info(f"Mixed precision: {self.accelerator.mixed_precision}")
        utils.set_seed_everywhere(cfg.seed)

        # Broadcast rank 0's work dir so that every process uses the same one.
        self.job_num, self.work_dir = utils.get_hydra_jobnum_workdir()
        if torch.distributed.is_initialized():
            objs = [str(self.work_dir)]
            torch.distributed.broadcast_object_list(objs, 0)
            self.work_dir = Path(objs[0])
        self.accelerator.wait_for_everyone()
        logger.info("Saving to {}".format(self.work_dir))
        os.chdir(self.work_dir)
        self.work_dir = Path(os.getcwd())

        self.dataset = hydra.utils.instantiate(cfg.env.dataset)
        self.train_set, self.test_set = self._split_and_slice_dataset(self.dataset)
        self._setup_loaders(batch_size=self.cfg.batch_size)
        # Must precede _init_encoder, which uses self.wandb_run on the main process.
        self._init_tracker(cfg)

        self.encoder = None
        self.projector = None
        self.ssl = None
        self._init_encoder()
        self._init_projector()
        self._init_ssl()

        self.workspace: Workspace = hydra.utils.instantiate(
            self.cfg.env.workspace,
            cfg=self.cfg,
            work_dir=self.work_dir,
            _recursive_=False,
        )
        self.workspace.set_dataset(self.dataset)

        self.log_components = OrderedDict()
        self.epoch = 0
        # Overwritten by eval(). Initialised here so that run() can still
        # return a value when no epoch is executed.
        self.eval_loss = 0.0

    def _init_tracker(self, cfg):
        """Start the wandb tracker with the fully resolved config.

        Only the main process gets ``self.wandb_run``; other ranks do not
        have this attribute.
        """
        wandb_cfg = OmegaConf.to_container(cfg, resolve=True)
        wandb_cfg["effective_batch_size"] = self.effective_batch_size
        wandb_cfg["save_path"] = str(self.work_dir)
        self.accelerator.init_trackers(
            project_name=cfg.project,
            config=wandb_cfg,
            init_kwargs={
                "wandb": {
                    "reinit": False,
                    "settings": {"start_method": "thread"},
                },
            },
        )
        if self.accelerator.is_main_process:
            self.wandb_run = self.accelerator.get_tracker("wandb", unwrap=True)
            logger.info("wandb run url: %s", self.wandb_run.get_url())

    def _init_encoder(self):
        """Build the encoder (unless already set) and its AdamW optimizer.

        Optionally converts BatchNorm layers to SyncBatchNorm. The model and
        optimizer are passed through ``accelerator.prepare``. On the main
        process, the encoder is registered with ``wandb_run.watch``.
        """
        if self.encoder is None:
            self.encoder = hydra.utils.instantiate(self.cfg.encoder)
            if self.cfg.sync_bn:
                self.encoder = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.encoder
                )
        self.encoder_optim = torch.optim.AdamW(
            params=self.encoder.parameters(),
            lr=self.cfg.ssl_lr,
            weight_decay=self.cfg.ssl_weight_decay,
            betas=tuple(self.cfg.betas),
        )
        self.encoder, self.encoder_optim = self.accelerator.prepare(
            self.encoder, self.encoder_optim
        )
        if self.accelerator.is_main_process:
            self.wandb_run.watch(self.encoder)

    def _init_projector(self):
        """Build the projector (unless already set) and the optimizer it configures.

        Both are passed through ``accelerator.prepare``.
        """
        if self.projector is None:
            self.projector = hydra.utils.instantiate(
                self.cfg.projector, _recursive_=False
            )
        self.projector_optim: torch.optim.Optimizer = (
            self.projector.configure_optimizers(
                lr=self.cfg.ssl_lr,
                weight_decay=self.cfg.ssl_weight_decay,
                betas=tuple(self.cfg.betas),
            )
        )
        self.projector, self.projector_optim = self.accelerator.prepare(
            self.projector, self.projector_optim
        )

    def _init_ssl(self):
        """Instantiate the SSL module around the prepared encoder and projector."""
        if self.ssl is None:
            self.ssl = hydra.utils.instantiate(
                self.cfg.ssl,
                encoder=self.encoder,
                projector=self.projector,
            )

    def _split_and_slice_dataset(self, dataset):
        """Split ``dataset`` into (train, val) sliced datasets using config settings."""
        kwargs = {
            "train_fraction": self.cfg.train_fraction,
            "random_seed": self.cfg.seed,
            "window_size": self.cfg.window_size,
            "future_conditional": (self.cfg.goal_conditional == "future"),
            "min_future_sep": self.cfg.min_future_sep,
            "future_seq_len": self.cfg.goal_seq_len,
            "num_extra_predicted_actions": self.cfg.num_extra_predicted_actions,
        }
        return datasets.core.get_train_val_sliced(dataset, **kwargs)

    def _setup_loaders(self, batch_size=None, pin_memory=True, num_workers=None):
        """Create and prepare the train/test DataLoaders.

        ``batch_size`` (default: ``cfg.batch_size``) is the global batch size.
        It is divided evenly across processes.

        Raises:
            ValueError: if the global batch size is not divisible by the
                number of processes.
        """
        if num_workers is None:
            num_workers = self.cfg.num_workers
        global_batch_size = batch_size or self.cfg.batch_size
        num_processes = self.accelerator.num_processes
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if global_batch_size % num_processes != 0:
            raise ValueError(
                "Batch size must be divisible by the number of processes. "
                f"Got {global_batch_size} and {num_processes}."
            )
        kwargs = {
            "batch_size": global_batch_size // num_processes,
            "num_workers": num_workers,
            "pin_memory": pin_memory,
        }
        self.train_loader = DataLoader(self.train_set, shuffle=True, **kwargs)
        self.test_loader = DataLoader(self.test_set, shuffle=False, **kwargs)

        self.train_loader = self.accelerator.prepare(self.train_loader)
        self.test_loader = self.accelerator.prepare(self.test_loader)

    def train(self):
        """Run one training epoch over ``self.train_loader``.

        Optionally applies the LR schedule and updates the SSL module's beta.
        Then, for each batch, it runs the SSL forward pass under autocast,
        backpropagates, optionally clips gradients, and steps the optimizers
        and ``self.ssl``.
        """
        if self.cfg.use_lr_scheduling:
            lr = self.adjust_lr()
            self.log_append("metrics", 1, {"lr": lr})
        self.ssl.adjust_beta(self.epoch, self.cfg.num_epochs)
        pbar = tqdm.tqdm(
            self.train_loader,
            desc=f"Training epoch {self.epoch}",
            disable=not self.accelerator.is_main_process,
            ncols=80,
        )
        for data in pbar:
            obs, _, _ = data
            with self.accelerator.autocast():
                _, _, ssl_loss, ssl_loss_components = self.ssl.forward(obs)
            self.log_append("ssl_train", len(obs), ssl_loss_components)
            # NOTE(review): retain_graph=True is kept from the original; this
            # file does not show why the graph must outlive backward().
            self.accelerator.backward(ssl_loss, retain_graph=True)

            if self.cfg.clip_grad_norm:
                self.accelerator.clip_grad_norm_(
                    self.encoder.parameters(), self.cfg.clip_grad_norm
                )
                self.accelerator.clip_grad_norm_(
                    self.projector.parameters(), self.cfg.clip_grad_norm
                )
                self.accelerator.clip_grad_norm_(
                    self.ssl.parameters(), self.cfg.clip_grad_norm
                )

            self.encoder_optim.step()
            self.projector_optim.step()
            self.ssl.step()

            self.encoder_optim.zero_grad(set_to_none=True)
            self.projector_optim.zero_grad(set_to_none=True)

    def _log_embedding_stats(self, prefix, embeddings):
        """Log the mean per-dimension std and mean L2 norm of gathered embeddings.

        ``embeddings`` must be 4-D (required by the rearrange pattern). The
        three leading axes are flattened, so the statistics are computed over
        individual vectors.
        """
        flat = self.accelerator.gather_for_metrics(embeddings)
        flat = einops.rearrange(flat, "N T V E -> (N T V) E")
        self.log_append(
            "metrics",
            len(flat),
            {
                f"{prefix}_mean_std": flat.std(dim=0).mean(),
                f"{prefix}_mean_norm": flat.norm(dim=-1).mean(),
            },
        )

    def eval(self):
        """Evaluate on ``self.test_loader`` and record metrics.

        If ``cfg.eval_offline`` is set, runs the workspace's offline
        evaluation first. Sets ``self.eval_loss`` to the mean of the
        per-batch SSL losses (each gathered across processes), weighted by
        the local batch size. ``run`` returns this value.
        """
        if self.cfg.eval_offline:
            self.workspace.set_models(
                encoder=self.encoder,
                projector=self.projector,
            )
            offline_eval_results = self.workspace.run_offline_eval()
            if self.accelerator.is_main_process:
                self.log_append("env_offline_eval", 1, offline_eval_results)

        total_loss = 0.0
        total_count = 0
        with utils.inference.eval_mode(
            self.encoder,
            self.projector,
            no_grad=True,
        ):
            for data in self.test_loader:
                obs, _, _ = data
                obs_enc, obs_proj, ssl_loss, ssl_loss_components = self.ssl.forward(
                    obs
                )
                ssl_loss = self.accelerator.gather_for_metrics(ssl_loss).mean()
                ssl_loss_components = utils.reduce_dict(
                    torch.mean,
                    self.accelerator.gather_for_metrics(ssl_loss_components),
                )
                self.log_append("ssl_eval", len(obs), ssl_loss_components)
                # Accumulate so that eval_loss reflects the whole test set
                # (it was previously left at 0).
                total_loss += float(ssl_loss) * len(obs)
                total_count += len(obs)

                self._log_embedding_stats("obs_enc", obs_enc)
                self._log_embedding_stats("obs_proj", obs_proj)
        self.eval_loss = total_loss / total_count if total_count else 0.0

    def run(self):
        """Train and evaluate for the remaining epochs, then end training.

        Resumes from ``snapshot.pt`` in the work dir if it exists. Logs are
        flushed every epoch, and a snapshot is saved every
        ``cfg.save_every_epochs`` epochs.

        Returns:
            The most recent ``self.eval_loss`` as a float.
        """
        snapshot = Path(self.work_dir) / "snapshot.pt"
        start_epoch = self.epoch
        if snapshot.exists():
            print(f"Resuming: {snapshot}")
            self.load_snapshot()
            # Snapshots are written after their epoch has been trained and
            # evaluated, so continue with the next epoch instead of repeating it.
            start_epoch = self.epoch + 1

        self.train_iterator = tqdm.trange(
            start_epoch,
            self.cfg.num_epochs,
            disable=not self.accelerator.is_main_process,
            ncols=80,
        )
        self.train_iterator.set_description("Training")

        self.log_components = OrderedDict()
        for epoch in self.train_iterator:
            self.epoch = epoch
            self.train()
            self.eval()
            self.flush_log(step=self.epoch, iterator=self.train_iterator)
            if (self.epoch + 1) % self.cfg.save_every_epochs == 0:
                self.save_snapshot()

        self.accelerator.wait_for_everyone()
        self.accelerator.end_training()

        return float(self.eval_loss)

    def save_snapshot(self):
        """Write snapshot files to the work dir (main process only).

        Writes ``snapshot.pt``, ``encoder.pt``, ``snapshot_{epoch}.pt`` and
        ``encoder_{epoch}.pt``. All processes synchronise before the main
        process writes.
        """
        self.accelerator.wait_for_everyone()
        if not self.accelerator.is_main_process:
            return
        payload = {}
        for k in self._keys_to_save:
            value = self.__dict__[k]
            # Objects exposing `.module` are unwrapped via accelerate first.
            if hasattr(value, "module"):
                value = self.accelerator.unwrap_model(value)
            payload[k] = value
        outputs = (
            ("snapshot.pt", payload),
            ("encoder.pt", payload["encoder"]),
            (f"snapshot_{self.epoch}.pt", payload),
            (f"encoder_{self.epoch}.pt", payload["encoder"]),
        )
        for filename, obj in outputs:
            with (self.work_dir / filename).open("wb") as f:
                torch.save(obj, f)

    def load_snapshot(self):
        """Restore the attributes stored by ``save_snapshot`` from ``snapshot.pt``.

        NOTE(review): the restored models are the unwrapped objects that were
        saved. They are not passed through ``accelerator.prepare`` again, so
        confirm that this is acceptable for multi-process resumes.
        NOTE(review): the snapshot pickles whole objects. On torch versions
        where ``torch.load`` defaults to ``weights_only=True``, this call
        needs ``weights_only=False``; confirm against the torch version in use.
        """
        with (self.work_dir / "snapshot.pt").open("rb") as f:
            payload = torch.load(f)
        for k, v in payload.items():
            self.__dict__[k] = v
        not_in_payload = set(self._keys_to_save) - set(payload.keys())
        if not_in_payload:
            logger.warning("Keys not found in snapshot: %s", not_in_payload)

    def log_append(self, log_key, length, loss_components):
        """Accumulate length-weighted values under ``"{log_key}/{name}"`` keys.

        Tensor values are converted to Python scalars with ``.item()``. Each
        entry stores ``(total_count, weighted_sum)`` until ``flush_log``.
        """
        for key, value in loss_components.items():
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu().item()
            key_name = f"{log_key}/{key}"
            count, total = self.log_components.get(key_name, (0, 0.0))
            self.log_components[key_name] = (
                count + length,
                total + (length * value),
            )

    def flush_log(self, step, iterator=None):
        """Log the weighted means of the accumulated values and reset the accumulator.

        If ``iterator`` is given, a compact summary is also set as its
        postfix. The label is the upper-cased first letter of the group plus
        the first letter of the metric name.
        """
        log_components = OrderedDict()
        iterator_log_component = OrderedDict()
        for key, (count, total) in self.log_components.items():
            to_log = total / count
            log_components[key] = to_log

            # maxsplit=1: metric names (e.g. offline-eval results) may contain "/".
            log_key, name_key = key.split("/", 1)
            iterator_log_name = f"{log_key[:1]}{name_key[:1]}".upper()
            iterator_log_component[iterator_log_name] = to_log
        postfix = ",".join(
            "{}:{:.2e}".format(name, value)
            for name, value in iterator_log_component.items()
        )
        if iterator is not None:
            iterator.set_postfix_str(postfix)
        self.accelerator.log(log_components, step=step)
        logger.info(f"[{self.job_num}] Epoch {self.epoch}: {log_components}")
        self.log_components = OrderedDict()

    def adjust_lr(self):
        """Decays the learning rate with half-cycle cosine after warmup.

        Uses a linear warmup from 0 over ``cfg.warmup_epochs``, then cosine
        decay towards 0 at ``cfg.num_epochs``. The new LR is written into
        every param group of the encoder and projector optimizers.

        Returns:
            The learning rate applied for ``self.epoch``.
        """
        warmup = self.cfg.warmup_epochs
        if self.epoch < warmup:
            lr = self.cfg.ssl_lr * self.epoch / warmup
        else:
            lr = (
                self.cfg.ssl_lr
                * 0.5
                * (
                    1.0
                    + np.cos(
                        np.pi * (self.epoch - warmup) / (self.cfg.num_epochs - warmup)
                    )
                )
            )

        for optim in (self.encoder_optim, self.projector_optim):
            for param_group in optim.param_groups:
                param_group["lr"] = lr
        return lr
|
|
|
|
@hydra.main(version_base="1.2", config_path="configs", config_name="train")
def main(cfg):
    """Hydra entry point: build a Trainer from ``cfg``, run it and return its final eval loss."""
    return Trainer(cfg).run()


if __name__ == "__main__":
    # Hydra's decorator composes the config and supplies ``cfg``.
    main()
|
|