PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
tmf921-intent-training / scripts /train_qlora.py
nraptisss's picture
Harden Trackio Space validation to avoid startup crash
5a23de5 verified
#!/usr/bin/env python3
"""QLoRA SFT training for TMF921 intent-to-config research dataset.
Designed for a single RTX 6000 Ada 48/50GB server. Uses TRL SFTTrainer with PEFT QLoRA.
"""
import argparse
import math
import os
import re
from pathlib import Path
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, BitsAndBytesConfig, TrainerCallback, set_seed
from trl import SFTConfig, SFTTrainer
from tmf921_train.utils import load_config, write_json
try:
import trackio
except Exception: # pragma: no cover
trackio = None
class TrackioAlertCallback(TrainerCallback):
    """Emit alerts when training metrics look unhealthy.

    Checks (performed only on the main process):
      * ``on_log``: non-finite training ``loss`` (ERROR), and ``grad_norm`` that
        is non-finite or above ``grad_norm_threshold`` (WARN).
      * ``on_evaluate``: ``eval_loss`` that is non-finite or above
        ``eval_loss_threshold`` (WARN).

    Alerts are best-effort. Every alert is printed to stdout. It is then
    forwarded to ``trackio.alert`` when Trackio is importable and exposes that
    function. Any exception raised by ``trackio.alert`` is reported as a
    warning instead of aborting the training run.

    Args:
        grad_norm_threshold: gradient-norm value above which a spike alert fires.
        eval_loss_threshold: validation-loss value above which an alert fires.
    """

    def __init__(self, grad_norm_threshold=10.0, eval_loss_threshold=1.0):
        super().__init__()
        self.grad_norm_threshold = float(grad_norm_threshold)
        self.eval_loss_threshold = float(eval_loss_threshold)

    @staticmethod
    def _as_float(value):
        """Return ``value`` as a float, or None if it is missing or non-numeric."""
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _alert(title, text, level):
        """Print the alert and forward it to Trackio when possible; never raise."""
        print(f"[{level}] {title}: {text}")
        alert_fn = getattr(trackio, "alert", None) if trackio is not None else None
        if alert_fn is None:
            return
        try:
            alert_fn(title=title, text=text, level=level)
        except Exception as exc:  # alerting must never kill a long training run
            # NOTE(review): presumably raised when no Trackio run is active
            # (e.g. reporting disabled); confirm against the Trackio version in use.
            print(f"WARNING: trackio.alert failed ({exc!r}); continuing without alert.")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Check the training loss and gradient norm reported at each logging step."""
        if not state.is_world_process_zero or not logs:
            return
        loss = self._as_float(logs.get("loss"))
        grad_norm = self._as_float(logs.get("grad_norm"))
        if loss is not None and not math.isfinite(loss):
            self._alert(
                title="NaN/Inf training loss",
                text=f"step={state.global_step} loss={loss} — stop run and reduce learning_rate by 10x.",
                level="ERROR",
            )
        # NaN compares False against any threshold, so test finiteness explicitly.
        if grad_norm is not None and (not math.isfinite(grad_norm) or grad_norm > self.grad_norm_threshold):
            self._alert(
                title="Gradient norm spike",
                text=f"step={state.global_step} grad_norm={grad_norm:.3f} — consider lower lr or max_grad_norm.",
                level="WARN",
            )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Check the validation loss after each evaluation."""
        if not state.is_world_process_zero or not metrics:
            return
        loss = self._as_float(metrics.get("eval_loss"))
        if loss is not None and (not math.isfinite(loss) or loss > self.eval_loss_threshold):
            self._alert(
                title="High validation loss",
                text=f"step={state.global_step} eval_loss={loss:.4f} — check convergence and rare-class oversampling.",
                level="WARN",
            )
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Apart from ``--config`` and ``--seed``, every option defaults to None (or
    False for switches). This lets the caller tell "not given" apart and fall
    back to the YAML config value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/rtx6000ada_qwen3_8b_qlora.yaml")

    # Plain string overrides for the config keys of the same name.
    string_overrides = (
        "--model_name_or_path",
        "--dataset_name",
        "--train_split",
        "--eval_split",
        "--output_dir",
        "--hub_model_id",
    )
    for flag in string_overrides:
        parser.add_argument(flag)

    parser.add_argument("--max_steps", type=int, default=None, help="Debug/short run override")

    # Boolean switches (store_true); help=None is argparse's own default.
    switches = (
        ("--no_push", None),
        ("--packing", "Override config and enable packing. Requires compatible attention setup."),
        ("--flash_attn", "Use flash_attention_2 in model_init_kwargs. Install flash-attn first."),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    parser.add_argument(
        "--resume_from_checkpoint",
        default=None,
        help="Path to checkpoint dir, or 'true' to auto-resume latest checkpoint in output_dir",
    )
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()
def require_cuda():
    """Print the CUDA environment and refuse to continue without a visible GPU.

    Raises:
        RuntimeError: if ``torch.cuda.is_available()`` returns False; there is
            deliberately no CPU fallback.
    """
    print("=== CUDA CHECK ===")
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    print(f"torch={torch.__version__} torch.version.cuda={torch.version.cuda} CUDA_VISIBLE_DEVICES={visible}")

    if torch.cuda.is_available():
        count = torch.cuda.device_count()
        first_gpu = torch.cuda.get_device_name(0)
        print(f"cuda device_count={count} gpu0={first_gpu}")
        return

    message = (
        "CUDA is not available to PyTorch. Refusing to train on CPU. "
        "Run `bash scripts/install_rtx6000ada.sh`, verify `nvidia-smi`, and set CUDA_VISIBLE_DEVICES=0."
    )
    raise RuntimeError(message)
def valid_hf_repo_id(repo_id):
    """Return True if ``repo_id`` is a well-formed ``namespace/name`` Hub repo ID.

    The rules are intended to match (or be stricter than) ``huggingface_hub``'s
    repo-ID validation. An ID accepted here should then not make Trackio fail
    at startup. The rules are:
      * exactly one ``/`` separating two non-empty segments;
      * each segment is 1-96 characters from ``[A-Za-z0-9._-]``, starts with
        an alphanumeric character and ends with an alphanumeric or ``_``;
      * no ``--`` or ``..`` anywhere, and no ``.git`` suffix.

    Non-string or empty inputs return False.
    """
    if not repo_id or not isinstance(repo_id, str):
        return False
    if "--" in repo_id or ".." in repo_id or repo_id.endswith(".git"):
        return False
    segment = r"[A-Za-z0-9](?:[A-Za-z0-9._-]{0,94}[A-Za-z0-9_])?"
    # fullmatch rather than match + "$": "$" would also accept a trailing newline.
    return re.fullmatch(f"{segment}/{segment}", repo_id) is not None
def sanitize_trackio_config(cfg):
    """Resolve a usable Trackio Space ID and honour ``DISABLE_TRACKIO``.

    Invalid Space IDs (e.g. ``"nraptisss/"``) crash Trackio before training
    starts. Candidates are therefore checked with ``valid_hf_repo_id``, and
    invalid ones are reported and skipped. Candidates are tried in priority
    order:

      1. the ``TRACKIO_SPACE_ID`` environment variable;
      2. ``cfg["trackio_space_id"]``.

    The first valid candidate wins. If none is valid, ``trackio_space_id`` is
    set to None and training continues without a Space. An environment value
    that is not the chosen ID is removed from ``os.environ``, so it cannot leak
    into code that reads the variable directly.

    Setting ``DISABLE_TRACKIO`` to ``1``, ``true`` or ``yes`` (case-insensitive)
    bypasses Trackio completely: ``project`` and ``trackio_space_id`` become
    None and ``TRACKIO_SPACE_ID`` is removed from the environment.

    Args:
        cfg: mutable config mapping; updated in place.

    Returns:
        The same ``cfg`` object.
    """
    disable_flag = os.environ.get("DISABLE_TRACKIO", "0").strip()
    if disable_flag.lower() in {"1", "true", "yes"}:
        print(f"Trackio disabled via DISABLE_TRACKIO={disable_flag}")
        cfg["project"] = None
        cfg["trackio_space_id"] = None
        os.environ.pop("TRACKIO_SPACE_ID", None)
        return cfg

    env_space = os.environ.get("TRACKIO_SPACE_ID", "").strip()
    cfg_space = str(cfg.get("trackio_space_id") or "").strip()

    chosen = None
    for source, candidate in (("TRACKIO_SPACE_ID", env_space), ("config", cfg_space)):
        if not candidate:
            continue
        if valid_hf_repo_id(candidate):
            chosen = candidate
            break
        # Keep falling through: an invalid env value must not mask a valid config value.
        print(
            f"WARNING: ignoring invalid Trackio Space ID from {source}: {candidate!r}. "
            "Expected format: namespace/space-name"
        )

    if env_space != chosen:
        os.environ.pop("TRACKIO_SPACE_ID", None)

    cfg["trackio_space_id"] = chosen
    if chosen:
        print(f"Trackio Space: {chosen}")
    return cfg
def _apply_cli_overrides(cfg, args):
    """Overlay explicitly passed CLI options onto ``cfg`` in place and return it."""
    for key in ("model_name_or_path", "dataset_name", "train_split", "eval_split", "output_dir", "hub_model_id"):
        value = getattr(args, key)
        if value is not None:
            cfg[key] = value
    if args.max_steps is not None:
        cfg["max_steps"] = args.max_steps
        # "epochs" is the key SFTConfig.num_train_epochs is built from. With
        # max_steps > 0 the Trainer ignores the epoch count anyway.
        cfg["epochs"] = 1
    if args.no_push:
        cfg["push_to_hub"] = False
    if args.packing:
        cfg["packing"] = True
    # Record the effective seed so resolved_config.json describes the real run.
    cfg["seed"] = args.seed
    return cfg


def _load_sft_datasets(cfg):
    """Load the train/eval splits and keep only the ChatML ``messages`` column."""
    print("Loading dataset", cfg["dataset_name"])
    ds = load_dataset(cfg["dataset_name"])
    train_split = cfg.get("train_split", "train_sota")
    eval_split = cfg.get("eval_split", "validation")
    for split in (train_split, eval_split):
        if split not in ds:
            raise KeyError(
                f"Split {split!r} not found in dataset {cfg['dataset_name']!r}; available: {list(ds.keys())}"
            )
    train_dataset = ds[train_split]
    eval_dataset = ds[eval_split]
    print(train_dataset)
    print(eval_dataset)
    # TRL infers the dataset type from column names. This research dataset includes
    # both `messages` and convenience `prompt`/`completion` columns. Passing all of
    # them can make TRL classify it as prompt-completion instead of conversational,
    # and then reject assistant_only_loss=True. For SFT we train from ChatML
    # `messages` only.
    train_dataset = train_dataset.select_columns(["messages"])
    eval_dataset = eval_dataset.select_columns(["messages"])
    print("SFT train columns:", train_dataset.column_names)
    print("SFT eval columns:", eval_dataset.column_names)
    return train_dataset, eval_dataset


def _build_model_init_kwargs(cfg, args):
    """Return ``from_pretrained`` kwargs: dtype, GPU-0 placement, optional 4-bit quantization and flash-attn."""
    dtype = torch.bfloat16 if cfg.get("bf16", True) else torch.float16
    kwargs = {
        "trust_remote_code": True,
        "device_map": {"": 0},  # place the whole model on GPU 0
        "dtype": dtype,
    }
    if cfg.get("load_in_4bit", True):
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=cfg.get("bnb_4bit_quant_type", "nf4"),
            bnb_4bit_use_double_quant=bool(cfg.get("bnb_4bit_use_double_quant", True)),
            # Keep the 4-bit compute dtype consistent with the model dtype
            # (bf16 unless bf16 is disabled in the config).
            bnb_4bit_compute_dtype=dtype,
        )
    if args.flash_attn:
        kwargs["attn_implementation"] = "flash_attention_2"
    return kwargs


def _build_peft_config(cfg):
    """Return the LoRA config (defaults: r=64, alpha=16, dropout=0.05, all-linear targets)."""
    return LoraConfig(
        r=int(cfg.get("lora_r", 64)),
        lora_alpha=int(cfg.get("lora_alpha", 16)),
        lora_dropout=float(cfg.get("lora_dropout", 0.05)),
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=cfg.get("lora_target_modules", "all-linear"),
    )


def _parse_resume_arg(value):
    """Map the ``--resume_from_checkpoint`` string to a ``Trainer.train`` argument.

    "true" (any case) -> True, meaning auto-resume the latest checkpoint in
    output_dir. "false"/"none"/"" -> None, meaning a fresh start. Any other
    value is passed through as a checkpoint path.
    """
    if value is None:
        return None
    lowered = str(value).strip().lower()
    if lowered == "true":
        return True
    if lowered in {"false", "none", ""}:
        return None
    return value


def main():
    """Run QLoRA SFT end to end.

    Steps:
      1. Parse CLI args and require a CUDA GPU.
      2. Load the YAML config, sanitize the Trackio settings and apply CLI
         overrides.
      3. Seed, then write ``resolved_config.json`` into ``output_dir``.
      4. Load the dataset splits (``messages`` column only) and the tokenizer.
      5. Build the model init kwargs, the LoRA config and the ``SFTConfig``,
         then train, optionally resuming from a checkpoint.
      6. Evaluate, save the adapters and tokenizer, and push to the Hub when
         ``push_to_hub`` is enabled.
    """
    args = parse_args()
    require_cuda()
    cfg = load_config(args.config)
    cfg = sanitize_trackio_config(cfg)
    _apply_cli_overrides(cfg, args)

    set_seed(args.seed)
    output_dir = Path(cfg["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    write_json(output_dir / "resolved_config.json", cfg)

    train_dataset, eval_dataset = _load_sft_datasets(cfg)

    tokenizer = AutoTokenizer.from_pretrained(cfg["model_name_or_path"], trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    report_to = "trackio" if cfg.get("project") else "none"
    sft_args = SFTConfig(
        output_dir=cfg["output_dir"],
        model_init_kwargs=_build_model_init_kwargs(cfg, args),
        max_length=int(cfg.get("max_length", 2048)),
        packing=bool(cfg.get("packing", False)),
        assistant_only_loss=bool(cfg.get("assistant_only_loss", True)),
        dataset_num_proc=int(cfg.get("dataset_num_proc", 8)),
        learning_rate=float(cfg.get("learning_rate", 2e-4)),
        lr_scheduler_type=cfg.get("lr_scheduler_type", "constant"),
        warmup_steps=int(cfg.get("warmup_steps", 0)),
        weight_decay=float(cfg.get("weight_decay", 0.0)),
        max_grad_norm=float(cfg.get("max_grad_norm", 0.3)),
        num_train_epochs=float(cfg.get("epochs", 2)),
        max_steps=int(cfg["max_steps"]) if cfg.get("max_steps") is not None else -1,
        per_device_train_batch_size=int(cfg.get("per_device_train_batch_size", 2)),
        gradient_accumulation_steps=int(cfg.get("gradient_accumulation_steps", 8)),
        per_device_eval_batch_size=int(cfg.get("per_device_eval_batch_size", 2)),
        bf16=bool(cfg.get("bf16", True)),
        gradient_checkpointing=bool(cfg.get("gradient_checkpointing", True)),
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim=cfg.get("optim", "paged_adamw_32bit"),
        # The Trainer re-seeds from args.seed on construction, so the CLI seed must
        # be passed here as well as to set_seed() above.
        seed=args.seed,
        eval_strategy="steps",
        eval_steps=int(cfg.get("eval_steps", 250)),
        save_strategy="steps",
        save_steps=int(cfg.get("save_steps", 250)),
        save_total_limit=int(cfg.get("save_total_limit", 3)),
        # load_best_model_at_end requires save_steps to be a multiple of eval_steps.
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        logging_strategy="steps",
        logging_steps=int(cfg.get("logging_steps", 10)),
        logging_first_step=True,
        disable_tqdm=True,
        report_to=report_to,
        run_name=cfg.get("run_name"),
        project=cfg.get("project"),
        trackio_space_id=cfg.get("trackio_space_id"),
        push_to_hub=bool(cfg.get("push_to_hub", True)),
        hub_model_id=cfg.get("hub_model_id"),
    )

    # Attach the Trackio alert callback only when the Trainer reports to Trackio.
    # With reporting disabled nothing in this script starts a Trackio run for
    # the alerts to go to.
    callbacks = [TrackioAlertCallback()] if report_to == "trackio" else []
    trainer = SFTTrainer(
        model=cfg["model_name_or_path"],
        args=sft_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=_build_peft_config(cfg),
        callbacks=callbacks,
    )

    trainer.train(resume_from_checkpoint=_parse_resume_arg(args.resume_from_checkpoint))
    metrics = trainer.evaluate()
    write_json(output_dir / "final_eval_metrics.json", metrics)
    trainer.save_model(cfg["output_dir"])
    tokenizer.save_pretrained(cfg["output_dir"])

    if bool(cfg.get("push_to_hub", True)):
        trainer.push_to_hub(
            commit_message="Qwen TMF921 QLoRA SFT",
            dataset_name=cfg["dataset_name"],
        )
        # hub_model_id may be unset in the config. Per the TrainingArguments docs
        # the repo is then derived from output_dir, so prefer the Trainer's
        # resolved ID when it exposes one.
        repo_id = getattr(trainer, "hub_model_id", None) or cfg.get("hub_model_id")
        print(f"Pushed model/adapters to https://huggingface.co/{repo_id}")
if __name__ == "__main__":
    # Script entry point: run the training pipeline. main() returns None, so the
    # process exits with status 0 on success, as before.
    raise SystemExit(main())