import argparse import json import os import random from pathlib import Path from typing import Dict, Optional import numpy as np import torch from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, LinearLR, SequentialLR from tqdm import tqdm from loader import SoilFormerDataset, build_train_eval_dataloaders from soilformer import SoilFormer, loss_function from utils import get_dtype, load_json, save_json try: import wandb except ImportError: # pragma: no cover wandb = None def set_seed(seed: int, deterministic: bool = True) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) if deterministic: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def resolve_device(device_str: str) -> torch.device: device_str = device_str.lower() if device_str == "cuda": if not torch.cuda.is_available(): raise RuntimeError("config requests cuda, but CUDA is not available") return torch.device("cuda") if device_str == "mps": if not torch.backends.mps.is_available(): raise RuntimeError("config requests mps, but MPS is not available") return torch.device("mps") if device_str == "cpu": return torch.device("cpu") raise ValueError(f"Unsupported device: {device_str}") def move_batch_to_device(batch: Dict, device: torch.device, float_dtype: torch.dtype) -> Dict: out = {} for key, value in batch.items(): if isinstance(value, torch.Tensor): if value.dtype.is_floating_point: out[key] = value.to(device=device, dtype=float_dtype, non_blocking=True) else: out[key] = value.to(device=device, non_blocking=True) elif isinstance(value, dict): sub = {} for sub_key, sub_value in value.items(): if isinstance(sub_value, torch.Tensor): if sub_value.dtype.is_floating_point: sub[sub_key] = sub_value.to(device=device, dtype=float_dtype, non_blocking=True) else: sub[sub_key] = sub_value.to(device=device, non_blocking=True) else: sub[sub_key] = sub_value out[key] = sub else: out[key] = value return out def build_scheduler( optimizer: torch.optim.Optimizer, scheduler_cfg: Dict, ): scheduler_type = str(scheduler_cfg.get("type", "none")).lower() if scheduler_type == "none": return None warmup_epochs = int(scheduler_cfg.get("warmup_epochs", 0)) warmup_start_factor = float(scheduler_cfg.get("warmup_start_factor", 0.1)) if scheduler_type == "cosine": total_epochs = int(scheduler_cfg["total_epochs"]) eta_min = float(scheduler_cfg.get("eta_min", 1e-6)) if warmup_epochs > 0: t_max = int(scheduler_cfg.get("t_max", total_epochs - warmup_epochs)) if t_max <= 0: raise ValueError( f"Invalid cosine scheduler config: total_epochs={total_epochs}, " f"warmup_epochs={warmup_epochs}, resulting T_max={t_max}" ) else: t_max = int(scheduler_cfg.get("t_max", total_epochs)) main_scheduler = CosineAnnealingLR( optimizer, T_max=t_max, eta_min=eta_min, ) elif scheduler_type == "step": step_size = int(scheduler_cfg["step_size"]) gamma = float(scheduler_cfg.get("gamma", 0.1)) main_scheduler = StepLR( optimizer, step_size=step_size, gamma=gamma, ) else: raise ValueError(f"Unsupported scheduler type: {scheduler_type}") if warmup_epochs <= 0: return main_scheduler warmup_scheduler = LinearLR( optimizer, start_factor=warmup_start_factor, total_iters=warmup_epochs, ) scheduler = SequentialLR( optimizer, schedulers=[warmup_scheduler, main_scheduler], milestones=[warmup_epochs], ) return scheduler def get_checkpoint_model_state(model: SoilFormer) -> Dict[str, torch.Tensor]: if hasattr(model, "_checkpoint_state_dict"): return model._checkpoint_state_dict() # noqa return model.state_dict() def load_checkpoint_model_state(model: SoilFormer, state_dict: Dict[str, torch.Tensor]) -> None: if hasattr(model, "load_weights"): payload = {"model_state_dict": state_dict} tmp_path = None try: import tempfile with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: tmp_path = f.name torch.save(payload, tmp_path) model.load_weights(tmp_path, map_location="cpu", strict=True) finally: if tmp_path is not None and os.path.exists(tmp_path): os.remove(tmp_path) return model.load_state_dict(state_dict, strict=True) def save_checkpoint( checkpoint_path: Path, model: SoilFormer, optimizer: torch.optim.Optimizer, scheduler, epoch: int, global_step: int, config_train: Dict, config_model: Dict, config_data: Dict, ) -> None: checkpoint = { "epoch": epoch, "global_step": global_step, "model_state_dict": get_checkpoint_model_state(model), "optimizer_state_dict": optimizer.state_dict(), "scheduler_state_dict": None if scheduler is None else scheduler.state_dict(), "config_train": config_train, "config_model": config_model, "config_data": config_data, } checkpoint_path.parent.mkdir(parents=True, exist_ok=True) torch.save(checkpoint, checkpoint_path) def rotate_checkpoints(checkpoint_dir: Path, max_saved_checkpoints: int) -> None: checkpoint_paths = sorted(checkpoint_dir.glob("checkpoint_epoch_*.pt")) if max_saved_checkpoints is None or max_saved_checkpoints <= 0: return while len(checkpoint_paths) > max_saved_checkpoints: oldest = checkpoint_paths.pop(0) oldest.unlink(missing_ok=True) def compute_loss_from_batch( model: SoilFormer, batch: Dict, device: torch.device, dtype: torch.dtype, cat_s_bound: Optional[float] = None, num_s_bound: Optional[float] = None, ): batch = move_batch_to_device(batch, device=device, float_dtype=dtype) cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, _ = model( cat_local_ids=batch["masked_cat_local_ids"], numeric_values_by_nin=batch["masked_numeric_values_by_nin"], cat_valid_positions=batch["masked_cat_valid_positions"], numeric_valid_positions_by_nin=batch["masked_numeric_valid_positions_by_nin"], pixel_values=batch["pixel_values"], vision_valid_positions=batch["vision_valid_positions"], ) total_loss, stats = loss_function( x_cat=cat_logits_padded, s_cat=cat_s, y_cat=batch["original_cat_local_ids"], loss_mask_cat=batch["cat_loss_mask"], valid_class_mask=valid_class_mask, x_num=value_by_nin, s_num=s_by_nin, y_num=batch["original_numeric_values_by_nin"], loss_mask_num=batch["numeric_loss_mask_by_nin"], reduction="mean", cat_s_bound=cat_s_bound, num_s_bound=num_s_bound, ) return total_loss, stats @torch.no_grad() def evaluate( model: SoilFormer, dataset: SoilFormerDataset, eval_loader, device: torch.device, dtype: torch.dtype, cat_mask_ratio: float, num_mask_ratio: float, active_mask_seed: int, show_tqdm: bool, epoch: int, cat_s_bound: Optional[float] = None, num_s_bound: Optional[float] = None, ): model.eval() totals = { "total": 0.0, "cat_loss": 0.0, "num_loss": 0.0, "cat_base": 0.0, "num_base": 0.0, "cat_acc": 0.0, } num_batches = 0 iterator = eval_loader if show_tqdm: iterator = tqdm(eval_loader, desc=f"Eval {epoch}", leave=False) for batch_idx, raw_batch in enumerate(iterator): mask_seed = int(active_mask_seed + batch_idx) masked_batch = dataset.perform_active_mask( raw_batch, cat_ratio=cat_mask_ratio, num_ratio=num_mask_ratio, seed=mask_seed, ) _, stats = compute_loss_from_batch( model=model, batch=masked_batch, device=device, dtype=dtype, cat_s_bound=cat_s_bound, num_s_bound=num_s_bound, ) num_batches += 1 for key in totals: totals[key] += float(stats[key].item()) if num_batches == 0: raise RuntimeError("Eval dataloader is empty") return {f"eval/{k}": v / num_batches for k, v in totals.items()} def maybe_init_wandb(config_train: Dict): wandb_cfg = config_train["logging"]["wandb"] if not bool(wandb_cfg.get("enabled", False)): return None if wandb is None: raise ImportError("wandb is enabled in config but package is not installed") run = wandb.init( project=wandb_cfg["project"], entity=wandb_cfg.get("entity"), name=wandb_cfg.get("run_name"), dir=wandb_cfg.get("dir"), config=config_train, mode=wandb_cfg.get("mode", "online"), ) return run def print_parameter_stats(model): total = 0 trainable = 0 for p in model.parameters(): num = p.numel() total += num if p.requires_grad: trainable += num print("\nParameter statistics:") print(f"Total parameters: {total:,}") print(f"Trainable parameters: {trainable:,}") print(f"Frozen parameters: {total - trainable:,}\n") def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default="config/config_train.json") args = parser.parse_args() config_train = load_json(args.config) config_paths = config_train["paths"] config_data = load_json(config_paths["config_data_path"]) config_model = load_json(config_paths["config_model_path"]) seed_cfg = config_train["seed"] runtime_cfg = config_train["runtime"] optim_cfg = config_train["optimization"] checkpoint_cfg = config_train["checkpoint"] logging_cfg = config_train["logging"] loss_cfg = config_train["loss"] set_seed(int(seed_cfg["seed"]), deterministic=bool(seed_cfg.get("deterministic", True))) device = resolve_device(runtime_cfg["device"]) dtype = get_dtype(config_model.get("dtype", "bfloat16")) output_dir = Path(config_paths["output_dir"]) checkpoint_dir = output_dir / "checkpoints" output_dir.mkdir(parents=True, exist_ok=True) checkpoint_dir.mkdir(parents=True, exist_ok=True) save_json(config_train, str(output_dir / "config_train.snapshot.json")) save_json(config_data, str(output_dir / "config_data.snapshot.json")) save_json(config_model, str(output_dir / "config_model.snapshot.json")) dataset = SoilFormerDataset( csv_path=config_data["data_csv_path"], photo_map_path=config_data["photo_map_path"], cat_vocab_path=config_data["cat_vocab_path"], numeric_vocab_path=config_data["numeric_vocab_path"], numeric_stats_path=config_data["numeric_stats_path"], photo_root=config_data["photo_root"], image_size=int(config_data["image_size"]), ) train_loader, eval_loader, train_generator = build_train_eval_dataloaders( dataset=dataset, train_ratio=float(config_data["train_ratio"]), seed=int(config_data["train_eval_split_seed"]), batch_size=int(config_data["batch_size"]), ) print("\nSample statistics:") print("Train samples:", len(train_loader.dataset)) print("Eval samples:", len(eval_loader.dataset)) train_generator.manual_seed(int(seed_cfg["seed"])) model = SoilFormer(config=config_model, device=str(device)) resume_path = checkpoint_cfg.get("resume_checkpoint_path") if resume_path: checkpoint = torch.load(resume_path, map_location="cpu") load_checkpoint_model_state(model, checkpoint["model_state_dict"]) else: model.init_weights(std=float(runtime_cfg.get("init_weight_std", 0.02))) checkpoint = None print_parameter_stats(model) optimizer = AdamW( [p for p in model.parameters() if p.requires_grad], lr=float(optim_cfg["lr"]), betas=(float(optim_cfg["beta1"]), float(optim_cfg["beta2"])), eps=float(optim_cfg["eps"]), weight_decay=float(optim_cfg["weight_decay"]), ) scheduler = build_scheduler( optimizer=optimizer, scheduler_cfg=optim_cfg.get("scheduler", {"type": "none"}) ) start_epoch = 1 global_step = 0 if checkpoint is not None: optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) if scheduler is not None and checkpoint.get("scheduler_state_dict") is not None: scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) start_epoch = int(checkpoint["epoch"]) + 1 global_step = int(checkpoint.get("global_step", 0)) wandb_run = maybe_init_wandb(config_train) num_epochs = int(runtime_cfg["num_epochs"]) show_tqdm = bool(logging_cfg.get("tqdm", True)) cat_mask_ratio = float(config_data["cat_mask_ratio"]) num_mask_ratio = float(config_data["num_mask_ratio"]) active_mask_seed = int(config_data["active_mask_seed"]) max_grad_norm = optim_cfg.get("max_grad_norm") epochs_per_save = int(checkpoint_cfg["epochs_per_save"]) max_saved_checkpoints = int(checkpoint_cfg["max_saved_checkpoints"]) for epoch in range(start_epoch, num_epochs + 1): model.train() epoch_totals = { "total": 0.0, "cat_loss": 0.0, "num_loss": 0.0, "cat_base": 0.0, "num_base": 0.0, "cat_acc": 0.0, } num_batches = 0 iterator = train_loader if show_tqdm: iterator = tqdm(train_loader, desc=f"Train {epoch}", leave=True) for batch_idx, raw_batch in enumerate(iterator): global_step += 1 mask_seed = int(active_mask_seed + epoch * 1_000_000 + batch_idx) masked_batch = dataset.perform_active_mask( raw_batch, cat_ratio=cat_mask_ratio, num_ratio=num_mask_ratio, seed=mask_seed, ) optimizer.zero_grad(set_to_none=True) total_loss, stats = compute_loss_from_batch( model=model, batch=masked_batch, device=device, dtype=dtype, cat_s_bound=loss_cfg.get("cat_s_bound", None), num_s_bound=loss_cfg.get("num_s_bound", None), ) total_loss.backward() if max_grad_norm is not None: torch.nn.utils.clip_grad_norm_(model.parameters(), float(max_grad_norm)) optimizer.step() num_batches += 1 for key in epoch_totals: epoch_totals[key] += float(stats[key].item()) current_lr = float(optimizer.param_groups[0]["lr"]) train_step_log = { "train/step_total": float(stats["total"].item()), "train/step_cat_loss": float(stats["cat_loss"].item()), "train/step_num_loss": float(stats["num_loss"].item()), "train/step_cat_acc": float(stats["cat_acc"].item()), "train/lr": current_lr, "epoch": epoch, "global_step": global_step, } if wandb_run is not None: wandb.log(train_step_log, step=global_step) if show_tqdm: iterator.set_postfix( loss=f"{train_step_log['train/step_total']:.4f}", lr=f"{current_lr:.3e}", ) if num_batches == 0: raise RuntimeError("Train dataloader is empty") train_epoch_log = {f"train/{k}": v / num_batches for k, v in epoch_totals.items()} train_epoch_log["train/lr_epoch_end"] = float(optimizer.param_groups[0]["lr"]) train_epoch_log["epoch"] = epoch train_epoch_log["global_step"] = global_step eval_log = evaluate( model=model, dataset=dataset, eval_loader=eval_loader, device=device, dtype=dtype, cat_mask_ratio=cat_mask_ratio, num_mask_ratio=num_mask_ratio, active_mask_seed=active_mask_seed, show_tqdm=show_tqdm, epoch=epoch, cat_s_bound=loss_cfg.get("cat_s_bound", None), num_s_bound=loss_cfg.get("num_s_bound", None), ) eval_log["epoch"] = epoch eval_log["global_step"] = global_step merged_log = {} merged_log.update(train_epoch_log) merged_log.update(eval_log) print(json.dumps(merged_log, ensure_ascii=False)) if wandb_run is not None: wandb.log(merged_log, step=global_step) if scheduler is not None: scheduler.step() if epochs_per_save > 0 and epoch % epochs_per_save == 0: checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{epoch}.pt" save_checkpoint( checkpoint_path=checkpoint_path, model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, config_train=config_train, config_model=config_model, config_data=config_data, ) rotate_checkpoints(checkpoint_dir, max_saved_checkpoints) if wandb_run is not None: wandb.finish() if __name__ == "__main__": main()