# soilformer/modelling/train.py
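"""Training entry point for SoilFormer.

Runs masked-reconstruction training over categorical and numeric soil
attributes plus site photographs, with per-epoch evaluation, epoch-based
LR scheduling, checkpoint rotation, and optional Weights & Biases logging.
"""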
import argparse
import json
import os
import random
from pathlib import Path
from typing import Dict, Optional
import numpy as np
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, LinearLR, SequentialLR
from tqdm import tqdm
from loader import SoilFormerDataset, build_train_eval_dataloaders
from soilformer import SoilFormer, loss_function
from utils import get_dtype, load_json, save_json
try:
import wandb
except ImportError: # pragma: no cover
wandb = None
def set_seed(seed: int, deterministic: bool = True) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
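    # Optionally trade speed for reproducibility by pinning cudnn kernels.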
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def resolve_device(device_str: str) -> torch.device:
device_str = device_str.lower()
if device_str == "cuda":
if not torch.cuda.is_available():
raise RuntimeError("config requests cuda, but CUDA is not available")
return torch.device("cuda")
if device_str == "mps":
if not torch.backends.mps.is_available():
raise RuntimeError("config requests mps, but MPS is not available")
return torch.device("mps")
if device_str == "cpu":
return torch.device("cpu")
raise ValueError(f"Unsupported device: {device_str}")
def move_batch_to_device(batch: Dict, device: torch.device, float_dtype: torch.dtype) -> Dict:
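    # Batches mix flat tensors (e.g. "pixel_values") with one level of nested
    # dicts (the *_by_nin fields used below); floating tensors are also cast
    # to the training float dtype.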
out = {}
for key, value in batch.items():
if isinstance(value, torch.Tensor):
if value.dtype.is_floating_point:
out[key] = value.to(device=device, dtype=float_dtype, non_blocking=True)
else:
out[key] = value.to(device=device, non_blocking=True)
elif isinstance(value, dict):
sub = {}
for sub_key, sub_value in value.items():
if isinstance(sub_value, torch.Tensor):
if sub_value.dtype.is_floating_point:
sub[sub_key] = sub_value.to(device=device, dtype=float_dtype, non_blocking=True)
else:
sub[sub_key] = sub_value.to(device=device, non_blocking=True)
else:
sub[sub_key] = sub_value
out[key] = sub
else:
out[key] = value
return out
def build_scheduler(
optimizer: torch.optim.Optimizer,
scheduler_cfg: Dict,
):
scheduler_type = str(scheduler_cfg.get("type", "none")).lower()
if scheduler_type == "none":
return None
warmup_epochs = int(scheduler_cfg.get("warmup_epochs", 0))
warmup_start_factor = float(scheduler_cfg.get("warmup_start_factor", 0.1))
if scheduler_type == "cosine":
total_epochs = int(scheduler_cfg["total_epochs"])
eta_min = float(scheduler_cfg.get("eta_min", 1e-6))
if warmup_epochs > 0:
t_max = int(scheduler_cfg.get("t_max", total_epochs - warmup_epochs))
if t_max <= 0:
raise ValueError(
f"Invalid cosine scheduler config: total_epochs={total_epochs}, "
f"warmup_epochs={warmup_epochs}, resulting T_max={t_max}"
)
else:
t_max = int(scheduler_cfg.get("t_max", total_epochs))
main_scheduler = CosineAnnealingLR(
optimizer,
T_max=t_max,
eta_min=eta_min,
)
elif scheduler_type == "step":
step_size = int(scheduler_cfg["step_size"])
gamma = float(scheduler_cfg.get("gamma", 0.1))
main_scheduler = StepLR(
optimizer,
step_size=step_size,
gamma=gamma,
)
else:
raise ValueError(f"Unsupported scheduler type: {scheduler_type}")
if warmup_epochs <= 0:
return main_scheduler
warmup_scheduler = LinearLR(
optimizer,
start_factor=warmup_start_factor,
total_iters=warmup_epochs,
)
scheduler = SequentialLR(
optimizer,
schedulers=[warmup_scheduler, main_scheduler],
milestones=[warmup_epochs],
)
return scheduler
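# Example "scheduler" blocks in config_train.json (illustrative values):
#   {"type": "cosine", "total_epochs": 100, "warmup_epochs": 5,
#    "warmup_start_factor": 0.1, "eta_min": 1e-6}
#   {"type": "step", "step_size": 30, "gamma": 0.1}
# With warmup_epochs=5, LinearLR ramps the lr from 0.1x to 1x over the first
# 5 epochs, then CosineAnnealingLR anneals to eta_min over the remaining
# total_epochs - warmup_epochs epochs (scheduler.step() runs once per epoch
# in main()).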
def get_checkpoint_model_state(model: SoilFormer) -> Dict[str, torch.Tensor]:
    # Prefer the model's own checkpoint state when it defines one; fall back
    # to the full state_dict otherwise.
    if hasattr(model, "_checkpoint_state_dict"):
        return model._checkpoint_state_dict()  # noqa
    return model.state_dict()
def load_checkpoint_model_state(model: SoilFormer, state_dict: Dict[str, torch.Tensor]) -> None:
    if hasattr(model, "load_weights"):
        # model.load_weights expects a checkpoint file on disk, so round-trip
        # the state dict through a temporary file that is always cleaned up.
        payload = {"model_state_dict": state_dict}
        tmp_path = None
        try:
            import tempfile
            with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
                tmp_path = f.name
            torch.save(payload, tmp_path)
            model.load_weights(tmp_path, map_location="cpu", strict=True)
        finally:
            if tmp_path is not None and os.path.exists(tmp_path):
                os.remove(tmp_path)
        return
    model.load_state_dict(state_dict, strict=True)
def save_checkpoint(
checkpoint_path: Path,
model: SoilFormer,
optimizer: torch.optim.Optimizer,
scheduler,
epoch: int,
global_step: int,
config_train: Dict,
config_model: Dict,
config_data: Dict,
) -> None:
checkpoint = {
"epoch": epoch,
"global_step": global_step,
"model_state_dict": get_checkpoint_model_state(model),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": None if scheduler is None else scheduler.state_dict(),
"config_train": config_train,
"config_model": config_model,
"config_data": config_data,
}
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(checkpoint, checkpoint_path)
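# To resume training, set checkpoint.resume_checkpoint_path in
# config_train.json to one of these files; main() below restores the model,
# optimizer, scheduler, epoch, and global_step from it.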
def rotate_checkpoints(checkpoint_dir: Path, max_saved_checkpoints: int) -> None:
    if max_saved_checkpoints is None or max_saved_checkpoints <= 0:
        return
    # Sort numerically by epoch; a plain lexicographic sort would treat
    # checkpoint_epoch_10.pt as older than checkpoint_epoch_2.pt and delete
    # the wrong file.
    checkpoint_paths = sorted(
        checkpoint_dir.glob("checkpoint_epoch_*.pt"),
        key=lambda p: int(p.stem.rsplit("_", 1)[-1]),
    )
    while len(checkpoint_paths) > max_saved_checkpoints:
        oldest = checkpoint_paths.pop(0)
        oldest.unlink(missing_ok=True)
def compute_loss_from_batch(
model: SoilFormer,
batch: Dict,
device: torch.device,
dtype: torch.dtype,
cat_s_bound: Optional[float] = None,
num_s_bound: Optional[float] = None,
):
batch = move_batch_to_device(batch, device=device, float_dtype=dtype)
cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, _ = model(
cat_local_ids=batch["masked_cat_local_ids"],
numeric_values_by_nin=batch["masked_numeric_values_by_nin"],
cat_valid_positions=batch["masked_cat_valid_positions"],
numeric_valid_positions_by_nin=batch["masked_numeric_valid_positions_by_nin"],
pixel_values=batch["pixel_values"],
vision_valid_positions=batch["vision_valid_positions"],
)
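    # cat_s / s_by_nin are per-prediction scale outputs; the optional
    # *_s_bound arguments below are forwarded to loss_function, which defines
    # their exact semantics.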
total_loss, stats = loss_function(
x_cat=cat_logits_padded,
s_cat=cat_s,
y_cat=batch["original_cat_local_ids"],
loss_mask_cat=batch["cat_loss_mask"],
valid_class_mask=valid_class_mask,
x_num=value_by_nin,
s_num=s_by_nin,
y_num=batch["original_numeric_values_by_nin"],
loss_mask_num=batch["numeric_loss_mask_by_nin"],
reduction="mean",
cat_s_bound=cat_s_bound,
num_s_bound=num_s_bound,
)
return total_loss, stats
@torch.no_grad()
def evaluate(
model: SoilFormer,
dataset: SoilFormerDataset,
eval_loader,
device: torch.device,
dtype: torch.dtype,
cat_mask_ratio: float,
num_mask_ratio: float,
active_mask_seed: int,
show_tqdm: bool,
epoch: int,
cat_s_bound: Optional[float] = None,
num_s_bound: Optional[float] = None,
):
model.eval()
totals = {
"total": 0.0,
"cat_loss": 0.0,
"num_loss": 0.0,
"cat_base": 0.0,
"num_base": 0.0,
"cat_acc": 0.0,
}
num_batches = 0
iterator = eval_loader
if show_tqdm:
iterator = tqdm(eval_loader, desc=f"Eval {epoch}", leave=False)
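    # Masks are re-derived per batch from a fixed base seed, so the same
    # positions are hidden at every evaluation and metrics stay comparable
    # across epochs.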
for batch_idx, raw_batch in enumerate(iterator):
mask_seed = int(active_mask_seed + batch_idx)
masked_batch = dataset.perform_active_mask(
raw_batch,
cat_ratio=cat_mask_ratio,
num_ratio=num_mask_ratio,
seed=mask_seed,
)
_, stats = compute_loss_from_batch(
model=model,
batch=masked_batch,
device=device,
dtype=dtype,
cat_s_bound=cat_s_bound,
num_s_bound=num_s_bound,
)
num_batches += 1
for key in totals:
totals[key] += float(stats[key].item())
if num_batches == 0:
raise RuntimeError("Eval dataloader is empty")
return {f"eval/{k}": v / num_batches for k, v in totals.items()}
def maybe_init_wandb(config_train: Dict):
wandb_cfg = config_train["logging"]["wandb"]
if not bool(wandb_cfg.get("enabled", False)):
return None
if wandb is None:
raise ImportError("wandb is enabled in config but package is not installed")
run = wandb.init(
project=wandb_cfg["project"],
entity=wandb_cfg.get("entity"),
name=wandb_cfg.get("run_name"),
dir=wandb_cfg.get("dir"),
config=config_train,
mode=wandb_cfg.get("mode", "online"),
)
return run
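# Example "logging" block in config_train.json (illustrative values; keys
# mirror those read above):
#   "logging": {
#     "tqdm": true,
#     "wandb": {"enabled": true, "project": "soilformer", "entity": null,
#               "run_name": "baseline", "dir": "outputs/wandb",
#               "mode": "offline"}
#   }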
def print_parameter_stats(model):
total = 0
trainable = 0
for p in model.parameters():
num = p.numel()
total += num
if p.requires_grad:
trainable += num
print("\nParameter statistics:")
print(f"Total parameters: {total:,}")
print(f"Trainable parameters: {trainable:,}")
print(f"Frozen parameters: {total - trainable:,}\n")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config/config_train.json")
args = parser.parse_args()
config_train = load_json(args.config)
config_paths = config_train["paths"]
config_data = load_json(config_paths["config_data_path"])
config_model = load_json(config_paths["config_model_path"])
seed_cfg = config_train["seed"]
runtime_cfg = config_train["runtime"]
optim_cfg = config_train["optimization"]
checkpoint_cfg = config_train["checkpoint"]
logging_cfg = config_train["logging"]
loss_cfg = config_train["loss"]
set_seed(int(seed_cfg["seed"]), deterministic=bool(seed_cfg.get("deterministic", True)))
device = resolve_device(runtime_cfg["device"])
dtype = get_dtype(config_model.get("dtype", "bfloat16"))
output_dir = Path(config_paths["output_dir"])
checkpoint_dir = output_dir / "checkpoints"
output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
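    # Snapshot all three configs next to the outputs so every run directory
    # is self-describing.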
save_json(config_train, str(output_dir / "config_train.snapshot.json"))
save_json(config_data, str(output_dir / "config_data.snapshot.json"))
save_json(config_model, str(output_dir / "config_model.snapshot.json"))
dataset = SoilFormerDataset(
csv_path=config_data["data_csv_path"],
photo_map_path=config_data["photo_map_path"],
cat_vocab_path=config_data["cat_vocab_path"],
numeric_vocab_path=config_data["numeric_vocab_path"],
numeric_stats_path=config_data["numeric_stats_path"],
photo_root=config_data["photo_root"],
image_size=int(config_data["image_size"]),
)
train_loader, eval_loader, train_generator = build_train_eval_dataloaders(
dataset=dataset,
train_ratio=float(config_data["train_ratio"]),
seed=int(config_data["train_eval_split_seed"]),
batch_size=int(config_data["batch_size"]),
)
print("\nSample statistics:")
print("Train samples:", len(train_loader.dataset))
print("Eval samples:", len(eval_loader.dataset))
    # Re-seed the shuffle generator so per-epoch batch order depends only on
    # the training seed, not on the train/eval split seed used above.
    train_generator.manual_seed(int(seed_cfg["seed"]))
model = SoilFormer(config=config_model, device=str(device))
resume_path = checkpoint_cfg.get("resume_checkpoint_path")
if resume_path:
checkpoint = torch.load(resume_path, map_location="cpu")
load_checkpoint_model_state(model, checkpoint["model_state_dict"])
else:
model.init_weights(std=float(runtime_cfg.get("init_weight_std", 0.02)))
checkpoint = None
print_parameter_stats(model)
optimizer = AdamW(
[p for p in model.parameters() if p.requires_grad],
lr=float(optim_cfg["lr"]),
betas=(float(optim_cfg["beta1"]), float(optim_cfg["beta2"])),
eps=float(optim_cfg["eps"]),
weight_decay=float(optim_cfg["weight_decay"]),
)
scheduler = build_scheduler(
optimizer=optimizer,
scheduler_cfg=optim_cfg.get("scheduler", {"type": "none"})
)
start_epoch = 1
global_step = 0
if checkpoint is not None:
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if scheduler is not None and checkpoint.get("scheduler_state_dict") is not None:
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
start_epoch = int(checkpoint["epoch"]) + 1
global_step = int(checkpoint.get("global_step", 0))
wandb_run = maybe_init_wandb(config_train)
num_epochs = int(runtime_cfg["num_epochs"])
show_tqdm = bool(logging_cfg.get("tqdm", True))
cat_mask_ratio = float(config_data["cat_mask_ratio"])
num_mask_ratio = float(config_data["num_mask_ratio"])
active_mask_seed = int(config_data["active_mask_seed"])
max_grad_norm = optim_cfg.get("max_grad_norm")
epochs_per_save = int(checkpoint_cfg["epochs_per_save"])
max_saved_checkpoints = int(checkpoint_cfg["max_saved_checkpoints"])
for epoch in range(start_epoch, num_epochs + 1):
model.train()
epoch_totals = {
"total": 0.0,
"cat_loss": 0.0,
"num_loss": 0.0,
"cat_base": 0.0,
"num_base": 0.0,
"cat_acc": 0.0,
}
num_batches = 0
iterator = train_loader
if show_tqdm:
iterator = tqdm(train_loader, desc=f"Train {epoch}", leave=True)
for batch_idx, raw_batch in enumerate(iterator):
global_step += 1
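            # Derive a fresh, reproducible mask per (epoch, batch); the epoch
            # offset keeps training masks different across epochs.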
mask_seed = int(active_mask_seed + epoch * 1_000_000 + batch_idx)
masked_batch = dataset.perform_active_mask(
raw_batch,
cat_ratio=cat_mask_ratio,
num_ratio=num_mask_ratio,
seed=mask_seed,
)
optimizer.zero_grad(set_to_none=True)
total_loss, stats = compute_loss_from_batch(
model=model,
batch=masked_batch,
device=device,
dtype=dtype,
cat_s_bound=loss_cfg.get("cat_s_bound", None),
num_s_bound=loss_cfg.get("num_s_bound", None),
)
total_loss.backward()
if max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), float(max_grad_norm))
optimizer.step()
num_batches += 1
for key in epoch_totals:
epoch_totals[key] += float(stats[key].item())
current_lr = float(optimizer.param_groups[0]["lr"])
train_step_log = {
"train/step_total": float(stats["total"].item()),
"train/step_cat_loss": float(stats["cat_loss"].item()),
"train/step_num_loss": float(stats["num_loss"].item()),
"train/step_cat_acc": float(stats["cat_acc"].item()),
"train/lr": current_lr,
"epoch": epoch,
"global_step": global_step,
}
if wandb_run is not None:
wandb.log(train_step_log, step=global_step)
if show_tqdm:
iterator.set_postfix(
loss=f"{train_step_log['train/step_total']:.4f}",
lr=f"{current_lr:.3e}",
)
if num_batches == 0:
raise RuntimeError("Train dataloader is empty")
train_epoch_log = {f"train/{k}": v / num_batches for k, v in epoch_totals.items()}
train_epoch_log["train/lr_epoch_end"] = float(optimizer.param_groups[0]["lr"])
train_epoch_log["epoch"] = epoch
train_epoch_log["global_step"] = global_step
eval_log = evaluate(
model=model,
dataset=dataset,
eval_loader=eval_loader,
device=device,
dtype=dtype,
cat_mask_ratio=cat_mask_ratio,
num_mask_ratio=num_mask_ratio,
active_mask_seed=active_mask_seed,
show_tqdm=show_tqdm,
epoch=epoch,
cat_s_bound=loss_cfg.get("cat_s_bound", None),
num_s_bound=loss_cfg.get("num_s_bound", None),
)
eval_log["epoch"] = epoch
eval_log["global_step"] = global_step
merged_log = {}
merged_log.update(train_epoch_log)
merged_log.update(eval_log)
print(json.dumps(merged_log, ensure_ascii=False))
if wandb_run is not None:
wandb.log(merged_log, step=global_step)
if scheduler is not None:
scheduler.step()
if epochs_per_save > 0 and epoch % epochs_per_save == 0:
checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"
save_checkpoint(
checkpoint_path=checkpoint_path,
model=model,
optimizer=optimizer,
scheduler=scheduler,
epoch=epoch,
global_step=global_step,
config_train=config_train,
config_model=config_model,
config_data=config_data,
)
rotate_checkpoints(checkpoint_dir, max_saved_checkpoints)
if wandb_run is not None:
wandb.finish()
if __name__ == "__main__":
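    # Example: python train.py --config config/config_train.json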
main()