| |
| """Second-stage continuation training for an existing QLoRA adapter. |
| |
| Loads the base model in 4-bit, prepares it for k-bit training, loads an existing LoRA adapter with |
| is_trainable=True, and continues SFT on a local weak-layer dataset. |
| """ |
| import argparse |
| import math |
| import os |
| import re |
| from pathlib import Path |
|
|
| import torch |
| from datasets import load_dataset |
| from peft import PeftConfig, PeftModel, get_model_status, prepare_model_for_kbit_training |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainerCallback, set_seed |
| from trl import SFTConfig, SFTTrainer |
|
|
| from tmf921_train.utils import load_config, write_json |
|
|
| try: |
| import trackio |
| except Exception: |
| trackio = None |
|
|
|
|
class TrackioAlertCallback(TrainerCallback):
    """Emit a Trackio ERROR alert whenever a logged training loss is NaN or infinite.

    Only the world-zero process reports, and nothing happens when the optional
    ``trackio`` import failed (module-level ``trackio`` is None).
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Guard clause: non-main rank, empty log payload, or Trackio unavailable.
        if trackio is None or not logs or not state.is_world_process_zero:
            return
        loss = logs.get("loss")
        if loss is None:
            return
        # isfinite() is False exactly for NaN and +/-Inf.
        if math.isfinite(float(loss)):
            return
        trackio.alert(
            title="NaN/Inf stage2 loss",
            text=f"step={state.global_step} loss={loss} — lower LR",
            level="ERROR",
        )
|
|
|
|
def require_cuda():
    """Print the torch/CUDA environment and refuse to continue without a GPU.

    Raises:
        RuntimeError: when ``torch.cuda.is_available()`` is False, so the script
            never silently falls back to CPU training.
    """
    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    print("=== CUDA CHECK ===")
    print(f"torch={torch.__version__} torch.version.cuda={torch.version.cuda} CUDA_VISIBLE_DEVICES={visible_devices}")
    if torch.cuda.is_available():
        print(f"cuda device_count={torch.cuda.device_count()} gpu0={torch.cuda.get_device_name(0)}")
        return
    raise RuntimeError("CUDA unavailable. Refusing CPU training.")
|
|
|
|
def valid_hf_repo_id(repo_id):
    """Return True if *repo_id* is a well-formed ``namespace/name`` Hugging Face repo id.

    Each part must start with an ASCII letter or digit, contain only letters,
    digits, ``.``, ``_`` or ``-``, and be at most 96 characters long. Mirroring
    the Hub's own repo-id validation, ``--`` and ``..`` are rejected anywhere, a
    part may not end in ``-`` or ``.``, and the id may not end in ``.git``.
    Non-strings and empty values return False.
    """
    if not repo_id or not isinstance(repo_id, str):
        return False
    # fullmatch rather than match(...$): "$" also matches just before a trailing
    # newline, which let "user/repo\n" pass. The anchored pattern also rules out
    # leading/trailing "/" and "//" on its own.
    if re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9._-]{0,95}/[A-Za-z0-9][A-Za-z0-9._-]{0,95}", repo_id) is None:
        return False
    if "--" in repo_id or ".." in repo_id or repo_id.endswith(".git"):
        return False
    return not any(part.endswith(("-", ".")) for part in repo_id.split("/"))
|
|
|
|
def sanitize_trackio_config(cfg):
    """Resolve the Trackio Space ID and apply the DISABLE_TRACKIO kill switch.

    A non-empty ``TRACKIO_SPACE_ID`` environment variable takes precedence over
    ``cfg["trackio_space_id"]``. A valid candidate is stored back into *cfg*.
    Otherwise a warning is printed (if a candidate was given), the config value
    is cleared, and ``TRACKIO_SPACE_ID`` is removed from the environment. With
    ``DISABLE_TRACKIO=1``, both ``project`` and ``trackio_space_id`` are set to None.

    Mutates and returns *cfg*.
    """
    from_env = os.environ.get("TRACKIO_SPACE_ID", "").strip()
    from_cfg = str(cfg.get("trackio_space_id") or "").strip()
    candidate = from_env or from_cfg

    if not candidate or not valid_hf_repo_id(candidate):
        if candidate:
            print(f"WARNING: ignoring invalid Trackio Space ID: {candidate!r}")
        cfg["trackio_space_id"] = None
        # Drop the env var too so downstream code cannot pick up the bad value.
        os.environ.pop("TRACKIO_SPACE_ID", None)
    else:
        cfg["trackio_space_id"] = candidate

    if os.environ.get("DISABLE_TRACKIO", "0") == "1":
        print("Trackio disabled via DISABLE_TRACKIO=1")
        cfg["project"] = None
        cfg["trackio_space_id"] = None
    return cfg
|
|
|
|
def parse_args():
    """Define the stage-2 command line and parse ``sys.argv``.

    Every option except ``--config``, ``--seed`` and ``--no_push`` defaults to
    None, so the caller can tell "flag not given" apart from a config value.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--config", default="configs/stage2_weak_layer_qwen3_8b.yaml")
    add("--adapter_path", help="Existing adapter path/repo to continue training")
    add("--dataset_dir", help="Local dir containing train.parquet and validation.parquet")
    add("--output_dir")
    add("--hub_model_id")
    add("--max_steps", type=int, default=None)
    add("--no_push", action="store_true")
    add("--seed", type=int, default=43)
    add("--resume_from_checkpoint", default=None)
    return parser.parse_args()
|
|
|
|
def _apply_cli_overrides(cfg, args):
    """Overlay every CLI flag that was actually given onto *cfg* (in place) and return it."""
    for key in ("adapter_path", "dataset_dir", "output_dir", "hub_model_id"):
        value = getattr(args, key)
        if value is not None:
            cfg[key] = value
    if args.max_steps is not None:
        cfg["max_steps"] = args.max_steps
    if args.no_push:
        cfg["push_to_hub"] = False
    return cfg


def _require_keys(cfg, keys):
    """Raise ValueError listing every key in *keys* that is missing or empty in *cfg*."""
    missing = [key for key in keys if not cfg.get(key)]
    if missing:
        raise ValueError(
            f"Missing required stage2 setting(s): {', '.join(missing)} "
            "(set them in the config file or pass the matching --flag)."
        )


def _resolve_resume_arg(value):
    """Map the --resume_from_checkpoint string onto what ``Trainer.train`` accepts.

    ``"true"`` (any case) becomes True (resume from the latest checkpoint).
    ``"false"``, ``"none"`` or an empty string become None (fresh start).
    Anything else is passed through unchanged as a checkpoint path.
    """
    if value is None:
        return None
    lowered = str(value).strip().lower()
    if lowered == "true":
        return True
    if lowered in ("false", "none", ""):
        return None
    return value


def _load_stage2_datasets(dataset_dir):
    """Load ``train.parquet``/``validation.parquet`` from *dataset_dir*, keeping only ``messages``."""
    root = Path(dataset_dir)
    data_files = {
        "train": str(root / "train.parquet"),
        "validation": str(root / "validation.parquet"),
    }
    ds = load_dataset("parquet", data_files=data_files)
    return ds["train"].select_columns(["messages"]), ds["validation"].select_columns(["messages"])


def _load_trainable_model(cfg):
    """Load the 4-bit base model and attach the existing adapter with trainable weights.

    Returns ``(model, tokenizer)``. Raises RuntimeError if the adapter exposes no
    trainable parameters.
    """
    adapter_path = cfg["adapter_path"]
    peft_cfg = PeftConfig.from_pretrained(adapter_path)
    # An explicit config value wins; otherwise use the base model recorded in the adapter.
    base_model_id = cfg.get("model_name_or_path") or peft_cfg.base_model_name_or_path or "Qwen/Qwen3-8B"
    print("Base model:", base_model_id)
    print("Adapter:", adapter_path)

    # The tokenizer (including its chat template) is taken from the adapter dir.
    tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=cfg.get("bnb_4bit_quant_type", "nf4"),
        bnb_4bit_use_double_quant=bool(cfg.get("bnb_4bit_use_double_quant", True)),
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=bnb_config,
        device_map={"": 0},
        dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    base_model.config.use_cache = False
    base_model = prepare_model_for_kbit_training(
        base_model,
        use_gradient_checkpointing=bool(cfg.get("gradient_checkpointing", True)),
        # Use the same non-reentrant checkpointing as the SFTConfig, instead of
        # first enabling it with the reentrant default.
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )
    model = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=True)
    model.print_trainable_parameters()
    status = get_model_status(model)
    print(status)
    if status.trainable_params <= 0:
        raise RuntimeError("No trainable adapter parameters found; refusing to run stage2.")
    return model, tokenizer


def _build_sft_config(cfg, report_to, push_to_hub):
    """Translate *cfg* into an SFTConfig: step-wise eval/save, keep the best eval_loss checkpoint."""
    return SFTConfig(
        output_dir=cfg["output_dir"],
        max_length=int(cfg.get("max_length", 2048)),
        packing=bool(cfg.get("packing", False)),
        assistant_only_loss=bool(cfg.get("assistant_only_loss", True)),
        dataset_num_proc=int(cfg.get("dataset_num_proc", 8)),
        learning_rate=float(cfg.get("learning_rate", 5e-5)),
        lr_scheduler_type=cfg.get("lr_scheduler_type", "constant"),
        warmup_steps=int(cfg.get("warmup_steps", 0)),
        weight_decay=float(cfg.get("weight_decay", 0.0)),
        max_grad_norm=float(cfg.get("max_grad_norm", 0.3)),
        num_train_epochs=float(cfg.get("epochs", 1)),
        # -1 lets num_train_epochs decide the run length.
        max_steps=int(cfg["max_steps"]) if cfg.get("max_steps") is not None else -1,
        per_device_train_batch_size=int(cfg.get("per_device_train_batch_size", 1)),
        gradient_accumulation_steps=int(cfg.get("gradient_accumulation_steps", 16)),
        per_device_eval_batch_size=int(cfg.get("per_device_eval_batch_size", 1)),
        bf16=True,
        gradient_checkpointing=bool(cfg.get("gradient_checkpointing", True)),
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim=cfg.get("optim", "paged_adamw_32bit"),
        # NOTE: with load_best_model_at_end=True, transformers expects save_steps
        # to be a multiple of eval_steps.
        eval_strategy="steps",
        eval_steps=int(cfg.get("eval_steps", 100)),
        save_strategy="steps",
        save_steps=int(cfg.get("save_steps", 100)),
        save_total_limit=int(cfg.get("save_total_limit", 3)),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        logging_strategy="steps",
        logging_steps=int(cfg.get("logging_steps", 10)),
        logging_first_step=True,
        disable_tqdm=True,
        report_to=report_to,
        run_name=cfg.get("run_name"),
        project=cfg.get("project"),
        trackio_space_id=cfg.get("trackio_space_id"),
        push_to_hub=push_to_hub,
        hub_model_id=cfg.get("hub_model_id"),
    )


def main():
    """Continue SFT of an existing QLoRA adapter on the local weak-layer dataset.

    The run proceeds as follows:
      1. Parse CLI args and require CUDA.
      2. Load and sanitize the config, apply CLI overrides, and write
         ``resolved_config.json``.
      3. Load the parquet splits and the 4-bit base model with the adapter in
         trainable mode.
      4. Train, evaluate, and save the adapter plus tokenizer to ``output_dir``.
      5. Optionally push to the Hub.

    Raises:
        RuntimeError: CUDA is unavailable or the adapter has no trainable parameters.
        ValueError: adapter_path, dataset_dir or output_dir is not configured.
    """
    args = parse_args()
    require_cuda()
    cfg = _apply_cli_overrides(sanitize_trackio_config(load_config(args.config)), args)
    _require_keys(cfg, ("adapter_path", "dataset_dir", "output_dir"))

    set_seed(args.seed)
    output_dir = Path(cfg["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    write_json(output_dir / "resolved_config.json", cfg)

    print("Loading local stage2 dataset", cfg["dataset_dir"])
    train_dataset, eval_dataset = _load_stage2_datasets(cfg["dataset_dir"])
    print(train_dataset)
    print(eval_dataset)

    model, tokenizer = _load_trainable_model(cfg)

    report_to = "trackio" if cfg.get("project") else "none"
    push_to_hub = bool(cfg.get("push_to_hub", True))
    sft_args = _build_sft_config(cfg, report_to=report_to, push_to_hub=push_to_hub)

    trainer = SFTTrainer(
        model=model,
        args=sft_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        # Attach the NaN/Inf alert callback only when Trackio is the active reporter.
        # With Trackio disabled there is no run for trackio.alert to target.
        callbacks=[TrackioAlertCallback()] if report_to == "trackio" else None,
    )
    trainer.train(resume_from_checkpoint=_resolve_resume_arg(args.resume_from_checkpoint))
    # With load_best_model_at_end=True this evaluates the best checkpoint.
    metrics = trainer.evaluate()
    write_json(output_dir / "final_eval_metrics.json", metrics)
    trainer.save_model(cfg["output_dir"])
    tokenizer.save_pretrained(cfg["output_dir"])

    if push_to_hub:
        trainer.push_to_hub(
            commit_message="Stage2 weak-layer QLoRA continuation",
            dataset_name=cfg.get("hub_dataset_name", "nraptisss/TMF921-intent-to-config-research-sota"),
        )
        # Prefer the repo id the Trainer actually resolved. It is set even when
        # hub_model_id was not configured (then derived from output_dir).
        repo_id = getattr(trainer, "hub_model_id", None) or cfg.get("hub_model_id")
        print(f"Pushed stage2 adapter to https://huggingface.co/{repo_id}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|