| """
|
| PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
|
| This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
|
| entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
|
| pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.
|
|
|
| Usage
|
| Single GPU:
|
| python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
|
| Example:
|
| python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
|
| python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint
|
| Multi-GPU (single node):
|
| torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
|
| Example:
|
| torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
|
| torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
|
| Multi-Node Training:
|
| torchrun \
|
| --nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \
|
| --master_addr=<master_ip> --master_port=<port> \
|
| scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
|
|
|
| """
|
|
|
| import dataclasses
|
| import gc
|
| import logging
|
| import os
|
| import platform
|
| import shutil
|
| import time
|
|
|
| import jax
|
| import numpy as np
|
| import safetensors.torch
|
| import torch
|
| import torch.distributed as dist
|
| import torch.nn.parallel
|
| import tqdm
|
| import wandb
|
|
|
| import openpi.models.pi0_config
|
| from openpi.models_pytorch import pi0_pytorch, pi0_align_pytorch, projectors
|
| import openpi.shared.normalize as _normalize
|
| import openpi.training.config as _config
|
| import openpi.training.data_loader as _data
|
|
|
| from vggt.models.vggt import VGGT
|
|
|
|
|
def init_logging():
    """Configure the root logger with a compact, single-letter-level format.

    The root level is set to INFO. When the root logger already has handlers,
    only the first handler's formatter is replaced; otherwise a new
    ``StreamHandler`` carrying the formatter is attached.
    """
    short_levels = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class ShortLevelFormatter(logging.Formatter):
        """Formatter that abbreviates the standard level names to one letter."""

        def format(self, record):
            # Level names missing from the table are passed through unchanged.
            record.levelname = short_levels.get(record.levelname, record.levelname)
            return super().format(record)

    compact = ShortLevelFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    root = logging.getLogger()
    root.setLevel(logging.INFO)

    if root.handlers:
        # Reuse whatever handler is already installed; just swap its formatter.
        root.handlers[0].setFormatter(compact)
        return

    stream = logging.StreamHandler()
    stream.setFormatter(compact)
    root.addHandler(stream)
|
|
|
|
|
def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
    """Initialize wandb logging.

    Args:
        config: Training config; ``checkpoint_dir``, ``exp_name`` and
            ``project_name`` are read from it.
        resuming: When true, re-attach to the run whose id is stored in
            ``<checkpoint_dir>/wandb_id.txt``; otherwise start a new run and
            write its id to that file.
        enabled: When false, wandb is initialised in ``"disabled"`` mode and
            nothing else happens.

    Raises:
        FileNotFoundError: If wandb is enabled and the checkpoint directory
            does not exist.
    """
    if not enabled:
        wandb.init(mode="disabled")
        return

    run_dir = config.checkpoint_dir
    if not run_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {run_dir} does not exist.")

    id_file = run_dir / "wandb_id.txt"

    if not resuming:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        # Persist the run id so a later --resume can continue the same run.
        id_file.write_text(wandb.run.id)
        return

    stored_id = id_file.read_text().strip()
    wandb.init(id=stored_id, resume="must", project=config.project_name)
|
|
|
|
|
def setup_ddp():
    """Initialise distributed training from the launcher's environment variables.

    Distributed mode is enabled when ``WORLD_SIZE`` is greater than 1. In that
    case the default process group is created if it does not exist yet (NCCL
    when CUDA is available, Gloo otherwise) and ``TORCH_DISTRIBUTED_DEBUG`` is
    defaulted to ``"INFO"`` when unset.

    Returns:
        A ``(use_ddp, local_rank, device)`` tuple. ``local_rank`` comes from
        ``LOCAL_RANK``, falling back to ``RANK`` and then ``0``. ``device`` is
        ``cuda:<local_rank>`` when CUDA is available (and is made the current
        CUDA device), else ``cpu``.
    """
    distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1

    if distributed and not torch.distributed.is_initialized():
        chosen_backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=chosen_backend, init_method="env://")
        # Keep any value the user exported; only fill in a default.
        os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "INFO")

    rank_text = os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0"))
    local_rank = int(rank_text)

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    return distributed, local_rank, device
|
|
|
|
|
def cleanup_ddp():
    """Tear down the default process group, if one was created.

    All ranks synchronise on a barrier before the group is destroyed. Does
    nothing when ``torch.distributed`` was never initialised.
    """
    if not torch.distributed.is_initialized():
        return
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
|
|
|
|
|
def set_seed(seed: int, local_rank: int):
    """Seed NumPy and PyTorch (CPU and, if available, all CUDA devices).

    The seed actually used is ``seed + local_rank``, so each local rank gets a
    different random stream.
    """
    rank_seed = seed + local_rank
    np.random.seed(rank_seed)
    torch.manual_seed(rank_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(rank_seed)
|
|
|
|
|
def build_datasets(config: _config.TrainConfig):
    """Create the shuffled PyTorch training loader.

    Returns:
        A ``(data_loader, data_config)`` pair, where ``data_config`` is what
        the loader's own ``data_config()`` method returns.
    """
    train_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
    loader_data_config = train_loader.data_config()
    return train_loader, loader_data_config
|
|
|
|
|
def get_model_state_dict(model):
    """Return the model's state dict, looking through a DDP wrapper if present."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        # Use the wrapped module so keys carry no "module." prefix.
        model = model.module
    return model.state_dict()
|
|
|
|
|
def get_model_parameters(model):
    """Return the model's parameter iterator, looking through a DDP wrapper if present."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    return model.parameters()
|
|
|
|
|
def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
    """Save a checkpoint with model state, optimizer state, and metadata.

    Nothing is written unless ``is_main`` is true and ``global_step`` is either
    a positive multiple of ``config.save_interval`` or has reached
    ``config.num_train_steps``.

    The checkpoint is staged in ``<checkpoint_dir>/tmp_<step>`` and then renamed
    to ``<checkpoint_dir>/<step>``, replacing any existing directory of that
    name, so an interrupted save does not leave a half-written directory under
    the final name.

    Args:
        model: Module to save; a ``DistributedDataParallel`` wrapper is unwrapped.
        optimizer: Optimizer whose ``state_dict()`` is saved to ``optimizer.pt``.
        global_step: Number of completed optimisation steps.
        config: Training config (``save_interval``, ``num_train_steps``,
            ``checkpoint_dir`` and ``wandb_enabled`` are read).
        is_main: Whether this process is the one that writes checkpoints.
        data_config: Source of ``norm_stats`` and ``asset_id``.
    """
    if not is_main:
        return

    on_interval = global_step > 0 and global_step % config.save_interval == 0
    # The training loop increments its step counter *before* calling this function, so
    # the last call of a run arrives with global_step == num_train_steps. Testing for
    # num_train_steps - 1 stored the penultimate weights and silently dropped the final
    # ones whenever num_train_steps was not a multiple of save_interval.
    is_final = global_step >= config.num_train_steps
    if not (on_interval or is_final):
        return

    final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
    tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"

    # Start from an empty staging directory (a previous save may have been interrupted).
    if tmp_ckpt_dir.exists():
        shutil.rmtree(tmp_ckpt_dir)
    tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): only `model` is written here. Any separately constructed module whose
    # parameters are held by `optimizer` is not part of this checkpoint -- confirm that
    # is intended by the caller.
    model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")

    torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")

    metadata = {
        "global_step": global_step,
        "config": dataclasses.asdict(config),
        "timestamp": time.time(),
    }
    torch.save(metadata, tmp_ckpt_dir / "metadata.pt")

    # Normalisation statistics are stored alongside the weights when both are available.
    norm_stats = data_config.norm_stats
    if norm_stats is not None and data_config.asset_id is not None:
        _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)

    # Publish: replace any previous checkpoint for this step with the staged one.
    if final_ckpt_dir.exists():
        shutil.rmtree(final_ckpt_dir)
    tmp_ckpt_dir.rename(final_ckpt_dir)

    logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")

    if config.wandb_enabled:
        wandb.log({"checkpoint_step": global_step}, step=global_step)
|
|
|
|
|
def load_checkpoint(model, optimizer, checkpoint_dir, device):
    """Load the latest checkpoint and return the global step.

    The latest checkpoint is the numerically largest all-digit subdirectory of
    ``checkpoint_dir`` (as reported by ``get_latest_checkpoint_step``). Model
    weights are read from ``model.safetensors``, optimizer state from
    ``optimizer.pt`` and the step counter from ``metadata.pt``.

    Args:
        model: Module to load into; a ``DistributedDataParallel`` wrapper is unwrapped.
        optimizer: Optimizer that receives the stored state dict.
        checkpoint_dir: Directory containing one subdirectory per saved step.
        device: Device that tensors are mapped to while loading.

    Returns:
        The ``"global_step"`` stored in the metadata, or the directory's step
        number if that key is absent.

    Raises:
        FileNotFoundError: If no checkpoint exists, or the model/optimizer file
            is missing from the latest one.
        RuntimeError: Re-raised from loading; out-of-memory errors are wrapped
            with a hint about ``PYTORCH_CUDA_ALLOC_CONF``.
    """
    # Share the discovery logic with the resume check so both agree on what "latest" means.
    latest_step = get_latest_checkpoint_step(checkpoint_dir)
    if latest_step is None:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    ckpt_dir = checkpoint_dir / f"{latest_step}"

    def _release_and_log(phase):
        # Drop cached allocator blocks and Python garbage between phases, then report usage.
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, phase)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
    log_memory_usage(device, latest_step, "before_loading_checkpoint")

    try:
        logging.info("Loading model state...")
        safetensors_path = ckpt_dir / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
        model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
        safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
        logging.info("Loaded model state from safetensors format")
        _release_and_log("after_loading_model")

        logging.info("Loading optimizer state...")
        optimizer_path = ckpt_dir / "optimizer.pt"
        if not optimizer_path.exists():
            raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
        # SECURITY: weights_only=False unpickles arbitrary objects. Only load checkpoints
        # produced by a trusted run of this script.
        optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
        logging.info("Loaded optimizer state from pt format")
        optimizer.load_state_dict(optimizer_state_dict)
        # Free the temporary copy before the next phase to keep peak memory down.
        del optimizer_state_dict
        _release_and_log("after_loading_optimizer")

        logging.info("Loading metadata...")
        # SECURITY: same caveat as above -- the metadata embeds the full config and is
        # therefore loaded with weights_only=False.
        metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
        global_step = metadata.get("global_step", latest_step)
        del metadata
        _release_and_log("after_loading_metadata")

        logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
        return global_step

    except RuntimeError as e:
        if "out of memory" in str(e):
            torch.cuda.empty_cache()
            gc.collect()
            logging.error(f"Out of memory error while loading checkpoint: {e!s}")
            log_memory_usage(device, latest_step, "after_oom_error")
            raise RuntimeError(
                "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
            ) from e
        raise
|
|
|
|
|
def get_latest_checkpoint_step(checkpoint_dir):
    """Get the latest checkpoint step number from a checkpoint directory.

    Only subdirectories whose names consist entirely of digits are considered,
    which also excludes ``tmp_<step>`` staging directories.

    Returns:
        The largest step number found, or ``None`` if there is none.
    """
    newest = None
    for entry in checkpoint_dir.iterdir():
        if not (entry.is_dir() and entry.name.isdigit()):
            continue
        step = int(entry.name)
        if newest is None or step > newest:
            newest = step
    return newest
|
|
|
|
|
def log_memory_usage(device, step, phase="unknown"):
    """Log CUDA memory statistics for ``device`` at INFO level.

    Reports current and peak allocated/reserved memory in GB (bytes / 1e9).
    The value labelled "free" is reserved minus allocated, i.e. memory held by
    the caching allocator but not currently in use. When a process group is
    initialised, the rank and world size are appended. Returns immediately
    when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return

    bytes_per_gb = 1e9

    allocated_gb = torch.cuda.memory_allocated(device) / bytes_per_gb
    reserved_gb = torch.cuda.memory_reserved(device) / bytes_per_gb
    unused_bytes = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)
    unused_gb = unused_bytes / bytes_per_gb

    stats = torch.cuda.memory_stats(device)
    peak_allocated_gb = stats.get("allocated_bytes.all.peak", 0) / bytes_per_gb
    peak_reserved_gb = stats.get("reserved_bytes.all.peak", 0) / bytes_per_gb

    if dist.is_initialized():
        rank_suffix = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"
    else:
        rank_suffix = ""

    logging.info(
        f"Step {step} ({phase}): GPU memory - allocated: {allocated_gb:.2f}GB, reserved: {reserved_gb:.2f}GB, free: {unused_gb:.2f}GB, peak_allocated: {peak_allocated_gb:.2f}GB, peak_reserved: {peak_reserved_gb:.2f}GB{rank_suffix}"
    )
|
|
|
|
|
def train_loop(config: _config.TrainConfig):
    """Run the complete training procedure described by ``config``.

    In order, this function:

    1. sets up (optional) DDP and seeds the RNGs;
    2. resolves the checkpoint directory (resume, overwrite, or create);
    3. initialises wandb on the main process and logs a few sample camera views;
    4. builds the PI0 model, the VGGT feature model and the alignment projector;
    5. optionally loads pretrained weights, builds the AdamW optimizer and,
       when resuming, restores the latest checkpoint;
    6. trains until ``config.num_train_steps`` optimisation steps are done,
       logging every ``config.log_interval`` steps and calling
       ``save_checkpoint`` after every step;
    7. closes the progress bar, finishes wandb and tears down DDP.

    Args:
        config: Training configuration.

    Raises:
        FileNotFoundError: When resuming and the checkpoint directory or a
            checkpoint inside it is missing, or when ``config.vggt_weight_path``
            is set but ``model.pt`` is not found there.
    """
    use_ddp, local_rank, device = setup_ddp()
    is_main = (not use_ddp) or (dist.get_rank() == 0)
    set_seed(config.seed, local_rank)

    # ------------------------------------------------------------------
    # Checkpoint directory: resume from it, or (re)create it.
    # ------------------------------------------------------------------
    resuming = False
    if config.resume:
        exp_checkpoint_dir = config.checkpoint_dir
        if not exp_checkpoint_dir.exists():
            raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
        latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
        if latest_step is None:
            raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
        resuming = True
        logging.info(f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}")

    if not resuming:
        exp_checkpoint_dir = config.checkpoint_dir
        # Only rank 0 touches the filesystem: with several processes per node, concurrent
        # rmtree/mkdir calls on the same directory race with each other.
        if is_main:
            if config.overwrite and exp_checkpoint_dir.exists():
                shutil.rmtree(exp_checkpoint_dir)
                logging.info(f"Overwriting checkpoint directory: {exp_checkpoint_dir}")
            exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
            logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
        if use_ddp:
            # Hold the other ranks until the directory is in its final state.
            dist.barrier()
    else:
        logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")

    if is_main:
        init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    # ------------------------------------------------------------------
    # Data.
    # ------------------------------------------------------------------
    world_size = torch.distributed.get_world_size() if use_ddp else 1
    # Computed for logging only; the full config is what gets passed to the loader.
    effective_batch_size = config.batch_size // world_size
    logging.info(
        f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
    )

    loader, data_config = build_datasets(config)

    # Log up to five sample items (all camera views side by side) at the start of a new run.
    if is_main and config.wandb_enabled and not resuming:
        # A separate, unshuffled loader is used so the training loader is not consumed.
        sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
        sample_batch = next(iter(sample_data_loader))

        observation, actions = sample_batch
        sample_batch = observation.to_dict()
        sample_batch["actions"] = actions

        images_to_log = []
        batch_size = next(iter(sample_batch["image"].values())).shape[0]
        for i in range(min(5, batch_size)):
            # permute(1, 2, 0) moves the leading axis of each per-item image last before
            # the views are concatenated along axis 1.
            img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
            img_concatenated = img_concatenated.cpu().numpy()
            images_to_log.append(wandb.Image(img_concatenated))

        wandb.log({"camera_views": images_to_log}, step=0)

        del sample_batch, observation, actions, images_to_log, img_concatenated
        del sample_data_loader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info("Cleared sample batch and data loader from memory")

    # ------------------------------------------------------------------
    # Models.
    # ------------------------------------------------------------------
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        model_cfg = openpi.models.pi0_config.Pi0Config(
            dtype=config.pytorch_training_precision,
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
        )
    else:
        model_cfg = config.model
        # object.__setattr__ is used so the dtype can be overridden even if the config
        # dataclass rejects normal attribute assignment.
        object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)

    model = openpi.models_pytorch.pi0_align_pytorch.PI0Pytorch(model_cfg, config).to(device)
    vggt_model = VGGT(
        enable_camera=False,
        enable_point=False,
        enable_depth=False,
        enable_track=False,
        feature_only=True,
    ).to(device)
    # VGGT is kept in eval mode and its parameters are never handed to the optimizer, so
    # there is no reason to compute or store gradients for them.
    vggt_model.requires_grad_(False)
    align_projector = projectors.AlignProjector(
        model.LLM_width,
        config.vggt_dim,
        config.use_vlm_norm,
    ).to(device)

    if hasattr(model, "gradient_checkpointing_enable"):
        enable_gradient_checkpointing = True
        model.gradient_checkpointing_enable()
        logging.info("Enabled gradient checkpointing for memory optimization")
    else:
        enable_gradient_checkpointing = False
        logging.info("Gradient checkpointing is not supported for this model")

    if is_main and torch.cuda.is_available():
        log_memory_usage(device, 0, "after_model_creation")

    if world_size >= 8:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        # NOTE(review): this is set after CUDA tensors have already been created above;
        # confirm the allocator still honours it at this point.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
        logging.info("Enabled memory optimizations for 8+ GPU training")

    if use_ddp:
        ddp_device_ids = [device.index] if device.type == "cuda" else None
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=ddp_device_ids,
            find_unused_parameters=True,
            gradient_as_bucket_view=True,
            static_graph=world_size >= 8,
        )
        align_projector = torch.nn.parallel.DistributedDataParallel(
            align_projector,
            device_ids=ddp_device_ids,
            find_unused_parameters=True,
            gradient_as_bucket_view=True,
            static_graph=world_size >= 8,
        )

    # ------------------------------------------------------------------
    # Pretrained weights.
    # ------------------------------------------------------------------
    if config.pytorch_weight_path is not None:
        logging.info(f"Loading weights from: {config.pytorch_weight_path}")
        model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
        safetensors.torch.load_model(
            (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model),
            model_path,
            strict=False,
        )
        logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")
    if config.vggt_weight_path is not None:
        vggt_path = os.path.join(config.vggt_weight_path, "model.pt")
        if not os.path.exists(vggt_path):
            raise FileNotFoundError(f"VGGT weight file not found at {vggt_path}")
        # Deserialize onto the CPU: without map_location every rank would restore the
        # tensors onto whichever device they were saved from. load_state_dict then copies
        # the values into the module's existing (already placed) parameters.
        vggt_model.load_state_dict(torch.load(vggt_path, map_location="cpu"), strict=False)
        logging.info(f"Loaded VGGT weights from {config.vggt_weight_path}")

    # ------------------------------------------------------------------
    # Optimizer and learning-rate schedule.
    # ------------------------------------------------------------------
    warmup_steps = config.lr_schedule.warmup_steps
    peak_lr = config.lr_schedule.peak_lr
    decay_steps = config.lr_schedule.decay_steps
    end_lr = config.lr_schedule.decay_lr

    optim = torch.optim.AdamW(
        list(model.parameters()) + list(align_projector.parameters()),
        lr=peak_lr,
        betas=(config.optimizer.b1, config.optimizer.b2),
        eps=config.optimizer.eps,
        weight_decay=config.optimizer.weight_decay,
    )

    global_step = 0
    if resuming:
        # NOTE(review): load_checkpoint restores `model` and the optimizer only; the
        # align projector's weights are not restored here -- confirm that is intended.
        global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
        logging.info(f"Resumed training from step {global_step}")

    def lr_schedule(step: int):
        """Linear warmup to ``peak_lr`` followed by cosine decay to ``end_lr``."""
        if step < warmup_steps:
            # Warmup starts at peak_lr / (warmup_steps + 1) rather than at zero.
            init_lr = peak_lr / (warmup_steps + 1)
            return init_lr + (peak_lr - init_lr) * step / warmup_steps

        progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
        cos = 0.5 * (1 + np.cos(np.pi * progress))
        return end_lr + (peak_lr - end_lr) * cos

    model.train()
    align_projector.train()
    vggt_model.eval()
    start_time = time.time()
    infos = []  # per-step stats collected since the last log line (main process only)
    if is_main:
        logging.info(
            f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
        )
        logging.info(
            f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
        )
        logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
        logging.info(
            f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
        )
        logging.info(
            f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
        )
        logging.info("EMA is not supported for PyTorch training")
        logging.info(f"Training precision: {model_cfg.dtype}")

    # Only the main process shows a progress bar.
    pbar = tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training") if is_main else None

    # ------------------------------------------------------------------
    # Training: keep cycling through the loader until enough steps are done.
    # ------------------------------------------------------------------
    while global_step < config.num_train_steps:
        if use_ddp and hasattr(loader, "set_epoch"):
            loader.set_epoch(global_step // len(loader))

        for observation, actions in loader:
            if global_step >= config.num_train_steps:
                break

            observation = jax.tree.map(lambda x: x.to(device), observation)
            actions = actions.to(torch.float32)
            actions = actions.to(device)

            # The schedule is applied manually; compute it once and share it across groups.
            current_lr = lr_schedule(global_step)
            for pg in optim.param_groups:
                pg["lr"] = current_lr

            action_losses, align_loss = model(observation, actions, vggt=vggt_model, align_proj=align_projector)
            loss = action_losses + config.align_loss_coeff * align_loss

            loss.backward()

            if global_step < 5 and is_main and torch.cuda.is_available():
                log_memory_usage(device, global_step, "after_backward")

            # NOTE(review): only `model`'s gradients are clipped; the align projector's
            # parameters are in the optimizer but not in this call -- confirm intended.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)

            optim.step()
            # set_to_none=True already releases every gradient the optimizer owns, so no
            # further manual clearing is needed.
            optim.zero_grad(set_to_none=True)

            if is_main:
                infos.append(
                    {
                        "action_loss": action_losses.item(),
                        "align_loss": align_loss.item(),
                        "learning_rate": optim.param_groups[0]["lr"],
                        "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
                    }
                )

            if is_main and (global_step % config.log_interval == 0):
                elapsed = time.time() - start_time
                num_collected = len(infos)  # >= 1: this step's stats were appended above

                avg_loss = sum(info["action_loss"] for info in infos) / num_collected
                avg_align_loss = sum(info["align_loss"] for info in infos) / num_collected
                avg_lr = sum(info["learning_rate"] for info in infos) / num_collected

                grad_vals = [info["grad_norm"] for info in infos if info["grad_norm"] is not None]
                avg_grad_norm = sum(grad_vals) / len(grad_vals) if grad_vals else None

                message = (
                    f"step={global_step} action_loss={avg_loss:.4f} align_loss={avg_align_loss:.4f} lr={avg_lr:.2e}"
                )
                if avg_grad_norm is not None:
                    message += f" grad_norm={avg_grad_norm:.2f}"
                message += f" time={elapsed:.1f}s"
                logging.info(message)

                if config.wandb_enabled:
                    log_payload = {
                        "action_loss": avg_loss,
                        "align_loss": avg_align_loss,
                        "learning_rate": avg_lr,
                        "step": global_step,
                        # Divide by the steps actually covered by `elapsed`; this differs
                        # from log_interval at step 0 and right after a resume.
                        "time_per_step": elapsed / num_collected,
                    }
                    if avg_grad_norm is not None:
                        log_payload["grad_norm"] = avg_grad_norm
                    wandb.log(log_payload, step=global_step)

                start_time = time.time()
                infos = []

            global_step += 1

            # Called every step; save_checkpoint itself decides whether anything is written.
            save_checkpoint(model, optim, global_step, config, is_main, data_config)

            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix(
                    {"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
                )

    if pbar is not None:
        pbar.close()

    if is_main and config.wandb_enabled:
        wandb.finish()

    cleanup_ddp()
|
|
|
|
|
def main():
    """Script entry point: configure logging, parse the CLI config, then train."""
    init_logging()
    parsed_config = _config.cli()
    train_loop(parsed_config)
|
|
|
|
|
# Start training only when this file is executed as a script (e.g. via `python` or
# `torchrun`), not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|