import argparse
import logging
import os
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.distributed as dist
import torch.optim as optim
import torchmetrics
import wandb
import yaml
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from src.data.containers import BatchTimeSeriesContainer
from src.data.loaders import SyntheticValidationDataset, create_synthetic_dataset
from src.gift_eval.aggregate_results import aggregate_results
from src.gift_eval.constants import ALL_DATASETS
from src.gift_eval.evaluate import evaluate_in_memory
from src.models.model import TimeSeriesModel
from src.optim.lr_scheduler import WarmupStableDecayScheduler, get_scheduler
from src.plotting.plot_multivariate_timeseries import plot_from_container
from src.utils.utils import (
    generate_descriptive_model_name,
    seed_everything,
)

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Silence noisy third-party loggers.
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
logging.getLogger("PIL").setLevel(logging.WARNING)
logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING)

# Reduce CUDA memory fragmentation during long runs with variable-sized batches.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


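# This script expects a distributed launcher that sets RANK, LOCAL_RANK, and
# WORLD_SIZE, e.g. (illustrative; file and config names are placeholders):
#   torchrun --nproc_per_node=4 train.py --config ./configs/train.yaml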
def setup_distributed():
    """Initialize the distributed process group and select this process's GPU."""
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank


def cleanup_distributed():
    """Clean up the distributed process group, tolerating partial failures."""
    if not (dist.is_available() and dist.is_initialized()):
        return
    try:
        dist.barrier()
    except Exception:
        pass
    try:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    except Exception:
        pass
    try:
        dist.destroy_process_group()
    except Exception as e:
        logger.warning(f"Error during destroy_process_group: {e}")


def release_cuda_memory():
    """Best-effort CUDA synchronize and cache release; never raises."""
    try:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    except Exception:
        pass
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass


def is_main_process():
    """True on global rank 0 (or when not running distributed)."""
    return not dist.is_initialized() or dist.get_rank() == 0


class TrainingPipeline:
    """End-to-end distributed training pipeline: data, model, optimization, validation."""

    def __init__(self, config: dict):
        self.config = config
        self.grad_accum_enabled = bool(self.config.get("gradient_accumulation_enabled", False))
        # With accumulation enabled, the optimizer steps once every `accumulation_steps` batches.
        self.accumulation_steps = (
            max(1, int(self.config.get("accumulation_steps", 1))) if self.grad_accum_enabled else 1
        )

        self.local_rank = setup_distributed()
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.device = torch.device(f"cuda:{self.local_rank}")

        self.initial_epoch = 0
        self.wandb_step_offset = 0
        self._setup()

        if is_main_process():
            logger.info("Loaded config:")
            for key, value in self.config.items():
                logger.info(f"{key}: {value}")

    def _setup(self) -> None:
        seed_everything(self.config["seed"])
        self.config["model_name"] = generate_descriptive_model_name(self.config)

        # Where checkpoints, configs, and evaluation artifacts are written.
        self.run_output_dir = (
            self.config.get("run_output_dir") or f"{self.config['model_path']}/{self.config['model_name']}"
        )
        self.config["resolved_run_output_dir"] = self.run_output_dir

        if is_main_process() and self.config.get("wandb"):
            init_kwargs = {
                "name": self.config["model_name"],
                "resume": "allow",
            }

            # Only pass the entity when it is configured explicitly.
            if self.config.get("wandb_entity"):
                init_kwargs["entity"] = self.config.get("wandb_entity")

            # When resuming, reattach to the previous wandb run if its ID is known.
            if self.config.get("continue_training"):
                if self.config.get("wandb_run_id"):
                    init_kwargs["id"] = self.config["wandb_run_id"]
                    logger.info(f"Attempting to resume wandb run with ID: {self.config['wandb_run_id']}")

            wandb.init(
                project=self.config.get("wandb_project_name", "TimeSeriesForecasting"),
                config=self.config,
                **init_kwargs,
            )

        self.num_training_iterations = self.config.get("num_training_iterations")

        self.model = TimeSeriesModel(**self.config["TimeSeriesModel"]).to(self.device)
        if is_main_process():
            logger.info("=" * 80)
            logger.info(
                f"Initializing model with {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M parameters"
            )
            logger.info("=" * 80)
            logger.info(f"Run output directory: {self.run_output_dir}")

        dist.barrier(device_ids=[self.local_rank])
        self._setup_optimizer()
        self._load_checkpoint()

        dist.barrier(device_ids=[self.local_rank])
        logger.info(
            f"Distributed training setup: rank {self.rank}, world size {self.world_size}, "
            f"local rank {self.local_rank}, device {self.device}"
        )
        self.model = DDP(self.model, device_ids=[self.local_rank], find_unused_parameters=True)
        logger.info(f"Distributed Data Parallel model initialized on rank {self.local_rank} with device {self.device}")
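
        # find_unused_parameters=True makes DDP scan for parameters that received
        # no gradient after every backward pass; this adds overhead and is only
        # needed when parts of the model can be skipped during a forward pass.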

        augmentations_config = self.config.get("data_augmentation", {})
        nan_stats_path = augmentations_config.get("nan_stats_path")
        nan_patterns_path = augmentations_config.get("nan_patterns_path")

        chosen_scaler_name = self.config.get("TimeSeriesModel", {}).get("scaler")

        # Seed with the global rank so every rank (across nodes) draws different data.
        self.train_dataset = create_synthetic_dataset(
            base_data_dir=self.config.get("train_data_path"),
            batch_size=self.config.get("batch_size", 128),
            num_batches_per_epoch=self.num_training_iterations,
            generator_proportions=self.config.get("generator_proportions"),
            augmentations=augmentations_config,
            augmentation_probabilities=self.config.get("augmentation_probabilities"),
            global_seed=self.config["seed"] + self.rank,
            nan_stats_path=nan_stats_path,
            nan_patterns_path=nan_patterns_path,
            chosen_scaler_name=chosen_scaler_name,
            rank=self.rank,
            world_size=self.world_size,
        )

        train_sampler = DistributedSampler(
            self.train_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=True,
        )

        # Each dataset item is already a fully formed batch, so the DataLoader uses
        # batch_size=1 and the collate function simply unwraps the singleton list.
        def collate_fn(batch):
            return batch[0]

        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=1,
            sampler=train_sampler,
            num_workers=self.config.get("num_workers", 1),
            pin_memory=True,
            collate_fn=collate_fn,
        )
        logger.info(
            f"Distributed DataLoader created with {len(self.train_loader)} batches "
            f"and num_workers={self.config.get('num_workers', 1)}"
        )
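
        # Each item is a full per-rank batch, so the effective global batch size
        # per optimizer step is batch_size * world_size * accumulation_steps.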

        val_dataset = SyntheticValidationDataset(
            base_data_dir=self.config.get("train_data_path"),
            batch_size=self.config.get("validation_batch_size", 64),
            num_batches=self.config.get("num_validation_batches", 1),
            future_length=512,
            generator_proportions=self.config.get("generator_proportions"),
            device=self.device,
            global_seed=self.config["seed"],
            augmentations=augmentations_config,
            augmentation_probabilities=self.config.get("augmentation_probabilities"),
            chosen_scaler_name=chosen_scaler_name,
            nan_stats_path=nan_stats_path,
            nan_patterns_path=nan_patterns_path,
            rank=self.rank,
            world_size=self.world_size,
        )
        val_sampler = DistributedSampler(val_dataset, shuffle=False)

        self.val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=1,
            sampler=val_sampler,
            collate_fn=collate_fn,
            num_workers=0,
        )

        self._setup_metrics()

    def _setup_optimizer(self):
        """Set up the AdamW optimizer and the learning-rate scheduler (WSD or cosine variants)."""
        optimizer_config = {
            "lr": float(self.config["peak_lr"]),
            "weight_decay": float(self.config.get("weight_decay", 0.01)),
            "betas": (
                float(self.config.get("beta1", 0.9)),
                float(self.config.get("beta2", 0.98)),
            ),
            "eps": float(self.config.get("optimizer_eps", 1e-6)),
        }
        self.optimizer = optim.AdamW(self.model.parameters(), **optimizer_config)

        # Optimizer steps per rank: iterations are split across ranks by the
        # DistributedSampler, and the optimizer steps once per accumulation window.
        effective_accum_steps = self.accumulation_steps
        total_steps = int(self.num_training_iterations // effective_accum_steps // self.world_size)
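
        # Worked example (illustrative numbers): 100_000 iterations on 4 ranks
        # with accumulation_steps=2 -> 100_000 // 2 // 4 = 12_500 scheduler steps.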

        scheduler_type = self.config.get("lr_scheduler", "warmup_stable_decay")

        if scheduler_type == "warmup_stable_decay":
            # Warmup-Stable-Decay: linear warmup, a long constant phase, then decay.
            warmup_ratio = float(self.config.get("warmup_ratio", 0.01))
            stable_ratio = float(self.config.get("stable_ratio", 0.85))

            num_warmup_steps = int(total_steps * warmup_ratio)
            num_stable_steps = int(total_steps * stable_ratio)

            self.scheduler = WarmupStableDecayScheduler(
                optimizer=self.optimizer,
                num_warmup_steps=num_warmup_steps,
                num_stable_steps=num_stable_steps,
                total_steps=total_steps,
                min_lr_ratio=self.config.get("min_lr_ratio", 0.01),
                decay_type=self.config.get("decay_type", "cosine"),
                verbose=is_main_process(),
            )

            if is_main_process():
                logger.info("WSD Scheduler configured:")
                logger.info(f"  Total steps: {total_steps}")
                logger.info(f"  Warmup steps: {num_warmup_steps} ({warmup_ratio * 100:.1f}%)")
                logger.info(f"  Stable steps: {num_stable_steps} ({stable_ratio * 100:.1f}%)")
                logger.info(f"  Decay steps: {total_steps - num_warmup_steps - num_stable_steps}")
                logger.info(f"  Peak LR: {self.config['peak_lr']}")
                logger.info(f"  Min LR: {float(self.config['peak_lr']) * float(self.config.get('min_lr_ratio', 0.01))}")

        elif scheduler_type == "cosine_with_warmup":
            num_warmup_steps = int(total_steps * self.config.get("warmup_ratio", 0.01))

            self.scheduler = get_scheduler(
                scheduler_type="cosine_with_warmup",
                optimizer=self.optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_steps,
                scheduler_kwargs={
                    "min_lr_ratio": float(self.config.get("min_lr_ratio", 0.01)),
                    "num_cycles": float(self.config.get("num_cycles", 0.5)),
                },
            )

        elif scheduler_type == "cosine_with_restarts":
            num_warmup_steps = int(total_steps * self.config.get("warmup_ratio", 0.01))

            self.scheduler = get_scheduler(
                scheduler_type="cosine_with_restarts",
                optimizer=self.optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_steps,
                scheduler_kwargs={
                    "min_lr_ratio": float(self.config.get("min_lr_ratio", 0.01)),
                    "num_cycles": int(self.config.get("num_restart_cycles", 4)),
                },
            )

        elif scheduler_type == "cosine":
            self.scheduler = CosineAnnealingLR(
                self.optimizer,
                T_max=total_steps,
                eta_min=float(self.config["peak_lr"]) * float(self.config.get("min_lr_ratio", 0.01)),
            )

        else:
            raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

        if is_main_process():
            logger.info(f"Optimizer configured with {scheduler_type} scheduler")

    def _setup_metrics(self):
        def make_metrics():
            # sync_on_compute=True aggregates metric state across ranks on compute().
            kwargs = {"dist_sync_on_step": False, "compute_on_cpu": False, "sync_on_compute": True}
            return {
                "mape": torchmetrics.MeanAbsolutePercentageError(**kwargs).to(self.device),
                "mse": torchmetrics.MeanSquaredError(**kwargs).to(self.device),
                "smape": torchmetrics.SymmetricMeanAbsolutePercentageError(**kwargs).to(self.device),
            }

        self.train_metrics = make_metrics()
        self.val_metrics = make_metrics()
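
    # Checkpoint layout, as written by _save_checkpoint below: epoch,
    # model_state_dict, optimizer_state_dict, wandb_run_id, and
    # scheduler_state_dict when the scheduler provides one.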

    def _load_checkpoint(self):
        # Resume from a checkpoint only when continue_training is enabled.
        if not self.config.get("continue_training"):
            return

        checkpoint_path_value = self.config.get("checkpoint_path")
        if not checkpoint_path_value:
            if is_main_process():
                logger.info("continue_training=True but no checkpoint_path provided; starting from scratch.")
            return

        checkpoint_path = Path(checkpoint_path_value)
        if not checkpoint_path.exists():
            if is_main_process():
                logger.warning(f"Checkpoint path does not exist at {checkpoint_path}. Starting from scratch.")
            return

        if is_main_process():
            logger.info(f"Loading checkpoint from: {checkpoint_path}")

        ckpt = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(ckpt["model_state_dict"])
        # Restore optimizer and scheduler state when present, so resumed runs
        # continue along the same learning-rate trajectory.
        if "optimizer_state_dict" in ckpt:
            self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        if "scheduler_state_dict" in ckpt and hasattr(self.scheduler, "load_state_dict"):
            self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])

    def _save_checkpoint(self, epoch: int):
        dist.barrier()
        if is_main_process():
            model_dir = self.run_output_dir
            os.makedirs(model_dir, exist_ok=True)

            unwrapped_model = self.model.module
            checkpoint = {
                "epoch": epoch,
                "model_state_dict": unwrapped_model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "wandb_run_id": self.config.get("wandb_run_id"),
            }

            if hasattr(self.scheduler, "state_dict"):
                checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()

            checkpoint_path = f"{model_dir}/checkpoint.pth"
            torch.save(checkpoint, checkpoint_path)
            logger.info(f"Checkpoint saved for step {epoch} to {checkpoint_path}")

            config_path = f"{model_dir}/config.yaml"
            with open(config_path, "w") as config_file:
                yaml.dump(self.config, config_file)

    def _inverse_scale(self, model, output: dict) -> torch.Tensor:
        # Map model outputs back to the original data scale using the stored statistics.
        return model.module.scaler.inverse_scale(output["result"], output["scale_statistics"])

    def _train_epoch(self, epoch: int) -> float:
        self.model.train()
        self.train_loader.sampler.set_epoch(epoch)

        train_loss, total_loss_sum, total_samples = 0.0, 0.0, 0.0

        pbar = tqdm(
            self.train_loader,
            desc=f"Training (start_step={epoch})",
            disable=not is_main_process(),
        )

        self.optimizer.zero_grad()

        for i, batch in enumerate(pbar):
            batch_size = batch.history_values.size(0)
            batch.to(self.device)

            with torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
                output = self.model(batch)
                loss = self.model.module.compute_loss(batch.future_values, output)

            # Scale the loss so gradients average over the accumulation window.
            if self.accumulation_steps > 1:
                loss = loss / self.accumulation_steps

            loss.backward()

            total_loss_sum += loss.item() * batch_size
            total_samples += batch_size

            # Step the optimizer at the end of each accumulation window (and on
            # the final batch, so no gradients are dropped).
            if ((i + 1) % self.accumulation_steps == 0) or ((i + 1) == len(self.train_loader)):
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.get("gradient_clip_val", 1.0))
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            if (i + 1) % self.config.get("log_interval", 10) == 0:
                dist.barrier()
                self._validate_epoch(i)
                # _validate_epoch switches to eval mode; switch back for training.
                self.model.train()

                # Aggregate the running loss across ranks for logging.
                total_loss_tensor = torch.tensor([total_loss_sum, total_samples], device=self.device)
                dist.all_reduce(total_loss_tensor, op=dist.ReduceOp.SUM)
                global_loss_sum, global_samples = total_loss_tensor.tolist()

                train_loss = global_loss_sum / global_samples if global_samples > 0 else 0.0
                # Undo the accumulation scaling so the logged loss is comparable.
                if self.accumulation_steps > 1:
                    train_loss *= self.accumulation_steps

                if is_main_process():
                    current_lr = self.optimizer.param_groups[0]["lr"]
                    step_metrics = {
                        "train/step_loss": train_loss,
                        "train/learning_rate": current_lr,
                        "train/lr_schedule_step": i,
                    }

                    if hasattr(self.scheduler, "get_phase"):
                        step_metrics["train/lr_phase"] = self.scheduler.get_phase()
                        step_metrics["train/lr_factor"] = self.scheduler.get_lr_factor(self.scheduler.current_step - 1)

                    if self.config.get("wandb"):
                        wandb.log(step_metrics, step=i + self.wandb_step_offset)

                    logger.info(f"Step {i} | Training Loss: {train_loss:.4f} | LR: {current_lr:.2e}")

                total_loss_sum, total_samples = 0.0, 0

            if (i + 1) % self.config.get("save_every", 10) == 0:
                self._save_checkpoint(i)

        return train_loss

    def _validate_epoch(self, epoch: int) -> float:
        self.model.eval()

        for metric in self.val_metrics.values():
            metric.reset()

        first_batch_for_plotting = None

        total_loss_sum, total_samples = 0.0, 0
        with torch.no_grad():
            self.val_loader.sampler.set_epoch(epoch)
            for batch_idx, batch in enumerate(self.val_loader):
                # Keep a CPU copy of the first batch on rank 0 for plotting later.
                if is_main_process() and batch_idx == 0:
                    first_batch_for_plotting = batch.to(torch.device("cpu"))

                batch = batch.to(self.device)
                batch_size = batch.history_values.size(0)

                with torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
                    # Call the unwrapped module directly: no gradient sync is needed here.
                    output = self.model.module(batch)
                    loss = self.model.module.compute_loss(batch.future_values, output)

                inv_scaled_output = self._inverse_scale(self.model, output)
                total_loss_sum += loss.item() * batch_size
                total_samples += batch_size

                self._update_metrics(
                    self.val_metrics,
                    inv_scaled_output,
                    batch.future_values,
                    distributed=False,
                )

        total_stats = torch.tensor([total_loss_sum, total_samples], device=self.device)
        dist.all_reduce(total_stats, op=dist.ReduceOp.SUM)
        global_loss_sum, global_samples = total_stats.tolist()
        avg_val_loss = global_loss_sum / global_samples if global_samples > 0 else 0.0

        # compute() synchronizes metric state across ranks, so every rank must call it.
        val_computed_metrics = {name: metric.compute() for name, metric in self.val_metrics.items()}

        if is_main_process():
            log_metrics = {"val/loss": avg_val_loss}
            log_metrics.update({f"val/{name}": value.item() for name, value in val_computed_metrics.items()})

            if self.config.get("wandb"):
                wandb.log(log_metrics, step=epoch + self.wandb_step_offset)

            logger.info(
                f"Step {epoch} | Validation Loss: {avg_val_loss:.4f} | "
                f"Validation MAPE: {val_computed_metrics['mape'].item():.4f}"
            )

            if first_batch_for_plotting is not None:
                self._plot_validation_examples(epoch, first_batch_for_plotting, plot_all=True)

        # Keep ranks in sync before resuming training.
        dist.barrier()
        return avg_val_loss

    def _update_metrics(
        self,
        metrics: dict,
        predictions: torch.Tensor,
        targets: torch.Tensor,
        distributed: bool = True,
    ):
        """Gather tensors across ranks if requested and update the metric objects."""
        if distributed and dist.is_initialized():
            # all_gather requires identically shaped tensors on every rank.
            world_size = dist.get_world_size()
            predictions_list = [torch.zeros_like(predictions) for _ in range(world_size)]
            targets_list = [torch.zeros_like(targets) for _ in range(world_size)]

            dist.all_gather(predictions_list, predictions)
            dist.all_gather(targets_list, targets)

            predictions_gathered = torch.cat(predictions_list, dim=0)
            targets_gathered = torch.cat(targets_list, dim=0)
        else:
            predictions_gathered = predictions
            targets_gathered = targets

        unwrapped_model = self.model.module

        # For quantile models, score only the median prediction.
        if unwrapped_model.loss_type == "quantile":
            try:
                median_idx = unwrapped_model.quantiles.index(0.5)
                predictions_gathered = predictions_gathered[..., median_idx]
            except (ValueError, AttributeError):
                if is_main_process():
                    logger.warning("Median (0.5) quantile not found for metric calculation. Skipping.")
                return

        # Flatten (batch, prediction_length, channels) to (batch * channels, prediction_length).
        if predictions_gathered.dim() == 3:
            b, p, c = predictions_gathered.shape
            predictions_flat = predictions_gathered.permute(0, 2, 1).reshape(b * c, p)
            targets_flat = targets_gathered.permute(0, 2, 1).reshape(b * c, p)

            for metric in metrics.values():
                metric.update(predictions_flat, targets_flat)

    def _plot_validation_examples(
        self,
        epoch: int,
        plot_batch: BatchTimeSeriesContainer,
        plot_indices: list[int] | None = None,
        plot_all: bool = False,
    ) -> None:
        """Plot validation examples from a batch and log them to wandb.

        Should only be called from the main process.
        """
        if (not self.config.get("wandb")) or (not self.config.get("wandb_plots", False)):
            return

        if plot_indices is None:
            plot_indices = [0, 1, 2, 3, 4]

        model = self.model.module

        with torch.inference_mode():
            plot_batch.to(self.device)

            with torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
                output = model(plot_batch)

            inv_scaled_output = self._inverse_scale(self.model, output)
            pred_future = inv_scaled_output.cpu().numpy()

        batch_size = plot_batch.history_values.size(0)
        if plot_all:
            indices_to_plot = list(range(batch_size))
        else:
            indices_to_plot = [i for i in plot_indices if i < batch_size]

        for i in indices_to_plot:
            fig = plot_from_container(
                batch=plot_batch,
                sample_idx=i,
                predicted_values=pred_future,
                model_quantiles=model.quantiles if model.loss_type == "quantile" else None,
                title=f"Step {epoch} - Val Sample {i}",
                output_file=None,
                show=False,
            )

            wandb.log(
                {f"val_plots/sample_{i}": wandb.Image(fig)},
                step=epoch + self.wandb_step_offset,
            )
            # Close figures promptly to avoid matplotlib memory growth.
            plt.close(fig)

    def train(self) -> None:
        if is_main_process():
            per_rank_iterations = len(self.train_loader)
            optimizer_steps_per_rank = (per_rank_iterations + self.accumulation_steps - 1) // self.accumulation_steps
            logger.info(
                f"Starting training: configured_iterations={self.num_training_iterations}, "
                f"world_size={self.world_size}, per_rank_iterations={per_rank_iterations}, "
                f"accumulation_steps={self.accumulation_steps}, "
                f"optimizer_steps_per_rank={optimizer_steps_per_rank}"
            )

        self._train_epoch(self.initial_epoch)

        dist.barrier()

        # Non-main ranks are done: release GPU memory and leave the process group.
        if not is_main_process():
            release_cuda_memory()
            cleanup_distributed()
            return

        cleanup_distributed()

        # Optionally run the GIFT-Eval benchmark with the trained model (rank 0 only).
        gift_eval_config = self.config.get("gift_eval") or {}
        if gift_eval_config.get("evaluate_on_gift_eval"):
            output_dir = f"{self.run_output_dir}/gift_eval_results"

            evaluate_in_memory(
                model=self.model.module if isinstance(self.model, DDP) else self.model,
                config=self.config,
                datasets=ALL_DATASETS,
                terms=["short", "medium", "long"],
                dataset_storage_path=gift_eval_config.get("dataset_storage_path"),
                batch_size=self.config.get("batch_size"),
                max_context_length=gift_eval_config.get("max_context_length"),
                output_dir=output_dir,
                create_plots=gift_eval_config.get("create_plots"),
                max_plots=gift_eval_config.get("max_plots"),
            )

            aggregate_results(
                result_root_dir=output_dir,
            )

        logger.info("TRAINING COMPLETED SUCCESSFULLY!")
        if self.config.get("wandb"):
            wandb.finish()

        release_cuda_memory()
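

# The config file is expected to provide (values illustrative): seed, model_path,
# peak_lr, batch_size, num_training_iterations, lr_scheduler, TimeSeriesModel,
# and optional wandb / gift_eval sections. See the config.get(...) calls above.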


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", default="./configs/train.yaml", help="Path to config file")
    parser.add_argument(
        "--run_output_dir",
        default=None,
        help=(
            "Optional output directory to store checkpoints and artifacts. "
            "If provided, overrides model_path/model_name for saving."
        ),
    )
    args = parser.parse_args()
    with open(args.config) as config_file:
        config = yaml.safe_load(config_file)

    # The CLI override takes precedence over the config file.
    if args.run_output_dir:
        config["run_output_dir"] = args.run_output_dir

    try:
        pipeline = TrainingPipeline(config)
        pipeline.train()
    finally:
        # Always release GPU memory, even if training raised.
        release_cuda_memory()