PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
File size: 10,758 Bytes
d9ba941
 
 
 
 
 
 
 
5a23de5
d9ba941
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a896ecd
d9ba941
 
 
 
91d636a
 
 
 
 
 
 
 
 
 
 
5a23de5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9ba941
 
91d636a
d9ba941
5a23de5
d9ba941
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608f732
 
 
 
 
 
 
 
 
d9ba941
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608f732
d9ba941
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a896ecd
 
 
 
d9ba941
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
#!/usr/bin/env python3
"""QLoRA SFT training for TMF921 intent-to-config research dataset.

Designed for a single RTX 6000 Ada 48/50GB server. Uses TRL SFTTrainer with PEFT QLoRA.
"""
import argparse
import math
import os
import re
from pathlib import Path

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, BitsAndBytesConfig, TrainerCallback, set_seed
from trl import SFTConfig, SFTTrainer

from tmf921_train.utils import load_config, write_json

try:
    import trackio
except Exception:  # pragma: no cover
    trackio = None


class TrackioAlertCallback(TrainerCallback):
    """Send Trackio alerts for divergent training signals.

    Alerts are emitted only from the main process (``state.is_world_process_zero``)
    and only when the optional ``trackio`` import succeeded. Three conditions are
    watched: a non-finite training loss, a non-finite or spiking gradient norm,
    and a high validation loss. The thresholds default to the values this script
    has always used.
    """

    def __init__(self, grad_norm_threshold=10.0, eval_loss_threshold=1.0):
        """
        Args:
            grad_norm_threshold: ``grad_norm`` values above this raise a WARN alert.
            eval_loss_threshold: ``eval_loss`` values above this raise a WARN alert.
        """
        super().__init__()
        self.grad_norm_threshold = float(grad_norm_threshold)
        self.eval_loss_threshold = float(eval_loss_threshold)

    @staticmethod
    def _alert(title, text, level):
        """Send one alert; a failure is printed but never aborts the training run."""
        try:
            trackio.alert(title=title, text=text, level=level)
        except Exception as exc:  # alerting is best-effort monitoring, not training logic
            print(f"WARNING: Trackio alert {title!r} could not be sent: {exc!r}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Check the training loss and gradient norm of each logging step."""
        if not state.is_world_process_zero or not logs or trackio is None:
            return
        loss = logs.get("loss")
        grad_norm = logs.get("grad_norm")
        if loss is not None and not math.isfinite(float(loss)):
            self._alert(
                title="NaN/Inf training loss",
                text=f"step={state.global_step} loss={loss} — stop run and reduce learning_rate by 10x.",
                level="ERROR",
            )
        if grad_norm is None:
            return
        grad_norm = float(grad_norm)
        # NaN compares False against any threshold, so non-finite values need their own check.
        if not math.isfinite(grad_norm):
            self._alert(
                title="NaN/Inf gradient norm",
                text=f"step={state.global_step} grad_norm={grad_norm} — stop run and reduce learning_rate by 10x.",
                level="ERROR",
            )
        elif grad_norm > self.grad_norm_threshold:
            self._alert(
                title="Gradient norm spike",
                text=f"step={state.global_step} grad_norm={grad_norm:.3f} — consider lower lr or max_grad_norm.",
                level="WARN",
            )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Warn when the evaluation loss exceeds ``eval_loss_threshold``."""
        if not state.is_world_process_zero or not metrics or trackio is None:
            return
        loss = metrics.get("eval_loss")
        if loss is not None and float(loss) > self.eval_loss_threshold:
            self._alert(
                title="High validation loss",
                text=f"step={state.global_step} eval_loss={float(loss):.4f} — check convergence and rare-class oversampling.",
                level="WARN",
            )


def parse_args():
    """Parse command-line overrides for the training run.

    Every override flag defaults to ``None`` (or ``False`` for switches). Callers
    can therefore tell "flag omitted" apart from an explicit value and leave the
    YAML config untouched in the former case.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/rtx6000ada_qwen3_8b_qlora.yaml")
    # Plain string overrides: argparse stores None when the flag is absent.
    for flag in (
        "--model_name_or_path",
        "--dataset_name",
        "--train_split",
        "--eval_split",
        "--output_dir",
        "--hub_model_id",
    ):
        parser.add_argument(flag)
    parser.add_argument("--max_steps", type=int, default=None, help="Debug/short run override")
    parser.add_argument("--no_push", action="store_true")
    parser.add_argument(
        "--packing",
        action="store_true",
        help="Override config and enable packing. Requires compatible attention setup.",
    )
    parser.add_argument(
        "--flash_attn",
        action="store_true",
        help="Use flash_attention_2 in model_init_kwargs. Install flash-attn first.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        default=None,
        help="Path to checkpoint dir, or 'true' to auto-resume latest checkpoint in output_dir",
    )
    parser.add_argument("--seed", type=int, default=42)
    namespace = parser.parse_args()
    return namespace


def require_cuda():
    """Print CUDA diagnostics and refuse to run unless PyTorch can see a GPU.

    Raises:
        RuntimeError: if ``torch.cuda.is_available()`` returns False.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    print("=== CUDA CHECK ===")
    print(f"torch={torch.__version__} torch.version.cuda={torch.version.cuda} CUDA_VISIBLE_DEVICES={visible_devices}")
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        first_gpu = torch.cuda.get_device_name(0)
        print(f"cuda device_count={device_count} gpu0={first_gpu}")
        return
    message = (
        "CUDA is not available to PyTorch. Refusing to train on CPU. "
        "Run `bash scripts/install_rtx6000ada.sh`, verify `nvidia-smi`, and set CUDA_VISIBLE_DEVICES=0."
    )
    raise RuntimeError(message)


def valid_hf_repo_id(repo_id):
    """Return True if ``repo_id`` is a well-formed ``namespace/name`` Hub repo id.

    A namespace is required. Each part must start with an ASCII letter or digit and
    be at most 96 characters drawn from ``[A-Za-z0-9._-]``. The Hub's extra
    restrictions are also applied: no ``--`` or ``..``, no part ending in ``-`` or
    ``.``, and no ``.git`` suffix. Non-strings and empty strings are rejected.
    """
    if not isinstance(repo_id, str) or not repo_id:
        return False
    # fullmatch, not match + "$": "$" also matches before a trailing "\n", which
    # let values such as "ns/space\n" through. The pattern's single "/" between two
    # non-empty parts also rules out leading, trailing or doubled slashes.
    pattern = r"[A-Za-z0-9][A-Za-z0-9._-]{0,95}/[A-Za-z0-9][A-Za-z0-9._-]{0,95}"
    if re.fullmatch(pattern, repo_id) is None:
        return False
    if "--" in repo_id or ".." in repo_id or repo_id.endswith(".git"):
        return False
    namespace, name = repo_id.split("/")
    return not (namespace.endswith(("-", ".")) or name.endswith(("-", ".")))


def sanitize_trackio_config(cfg):
    """Resolve and validate the Trackio Space ID; mutate and return ``cfg``.

    Candidates are tried in order: the ``TRACKIO_SPACE_ID`` environment variable,
    then ``cfg["trackio_space_id"]``. The first one accepted by
    ``valid_hf_repo_id`` is stored in ``cfg["trackio_space_id"]``. Invalid
    candidates are reported and skipped, so an invalid environment value no longer
    hides a valid config value. If nothing is valid, the key becomes ``None`` and
    training continues without a Space. ``DISABLE_TRACKIO=1`` clears both
    ``project`` and ``trackio_space_id``.
    """
    # Invalid values like "nraptisss/" crash Trackio before training starts, so they
    # are skipped instead of being passed through.
    env_space = os.environ.get("TRACKIO_SPACE_ID", "").strip()
    cfg_space = str(cfg.get("trackio_space_id") or "").strip()

    chosen = None
    for candidate in (env_space, cfg_space):
        if not candidate:
            continue
        if valid_hf_repo_id(candidate):
            chosen = candidate
            break
        print(f"WARNING: ignoring invalid Trackio Space ID: {candidate!r}. Expected format: namespace/space-name")

    if chosen is not None and chosen == env_space:
        # Write back the stripped value so the environment holds exactly what was validated.
        os.environ["TRACKIO_SPACE_ID"] = chosen
    else:
        # The env var is empty or invalid: drop it so an unvalidated value cannot
        # leak through the environment.
        os.environ.pop("TRACKIO_SPACE_ID", None)

    cfg["trackio_space_id"] = chosen
    if chosen:
        print(f"Trackio Space: {chosen}")

    # Set DISABLE_TRACKIO=1 to bypass Trackio completely if desired.
    if os.environ.get("DISABLE_TRACKIO", "0") == "1":
        print("Trackio disabled via DISABLE_TRACKIO=1")
        cfg["project"] = None
        cfg["trackio_space_id"] = None
    return cfg


def _apply_cli_overrides(cfg, args):
    """Copy explicitly-passed CLI flags from ``args`` into ``cfg`` (in place) and return it."""
    for key in ("model_name_or_path", "dataset_name", "train_split", "eval_split", "output_dir", "hub_model_id"):
        value = getattr(args, key)
        if value is not None:
            cfg[key] = value
    if args.max_steps is not None:
        cfg["max_steps"] = args.max_steps
        # _build_sft_config reads the epoch count from "epochs"; writing only
        # "num_train_epochs" (as before) never reached SFTConfig. The old key is
        # still written so resolved_config.json keeps its previous contents.
        cfg["epochs"] = 1
        cfg["num_train_epochs"] = 1
    if args.no_push:
        cfg["push_to_hub"] = False
    if args.packing:
        cfg["packing"] = True
    return cfg


def _load_sft_datasets(cfg):
    """Load the train/eval splits and keep only their ChatML ``messages`` column."""
    print("Loading dataset", cfg["dataset_name"])
    ds = load_dataset(cfg["dataset_name"])
    train_dataset = ds[cfg.get("train_split", "train_sota")]
    eval_dataset = ds[cfg.get("eval_split", "validation")]
    print(train_dataset)
    print(eval_dataset)

    # TRL infers dataset type from column names. This research dataset includes both
    # `messages` and convenience `prompt`/`completion` columns; passing all columns can
    # make TRL classify it as prompt-completion instead of conversational and reject
    # assistant_only_loss=True. For SFT we intentionally train from ChatML `messages`.
    train_dataset = train_dataset.select_columns(["messages"])
    eval_dataset = eval_dataset.select_columns(["messages"])
    print("SFT train columns:", train_dataset.column_names)
    print("SFT eval columns:", eval_dataset.column_names)
    return train_dataset, eval_dataset


def _build_model_init_kwargs(cfg, flash_attn):
    """Return the base-model loading kwargs that SFTTrainer receives via ``model_init_kwargs``."""
    model_dtype = torch.bfloat16 if cfg.get("bf16", True) else torch.float16
    model_init_kwargs = {
        "trust_remote_code": True,
        "device_map": {"": 0},  # place the whole model on GPU 0
        "dtype": model_dtype,
    }
    if cfg.get("load_in_4bit", True):
        model_init_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=cfg.get("bnb_4bit_quant_type", "nf4"),
            bnb_4bit_use_double_quant=bool(cfg.get("bnb_4bit_use_double_quant", True)),
            # Follow the model dtype: previously this was bf16 even when bf16=false
            # selected an fp16 model.
            bnb_4bit_compute_dtype=model_dtype,
        )
    if flash_attn:
        model_init_kwargs["attn_implementation"] = "flash_attention_2"
    return model_init_kwargs


def _build_peft_config(cfg):
    """Return the LoRA adapter configuration read from ``cfg``."""
    return LoraConfig(
        r=int(cfg.get("lora_r", 64)),
        lora_alpha=int(cfg.get("lora_alpha", 16)),
        lora_dropout=float(cfg.get("lora_dropout", 0.05)),
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=cfg.get("lora_target_modules", "all-linear"),
    )


def _build_sft_config(cfg, model_init_kwargs, report_to):
    """Translate the resolved config dict into TRL ``SFTConfig`` training arguments."""
    return SFTConfig(
        output_dir=cfg["output_dir"],
        model_init_kwargs=model_init_kwargs,
        max_length=int(cfg.get("max_length", 2048)),
        packing=bool(cfg.get("packing", False)),
        assistant_only_loss=bool(cfg.get("assistant_only_loss", True)),
        dataset_num_proc=int(cfg.get("dataset_num_proc", 8)),
        learning_rate=float(cfg.get("learning_rate", 2e-4)),
        lr_scheduler_type=cfg.get("lr_scheduler_type", "constant"),
        warmup_steps=int(cfg.get("warmup_steps", 0)),
        weight_decay=float(cfg.get("weight_decay", 0.0)),
        max_grad_norm=float(cfg.get("max_grad_norm", 0.3)),
        num_train_epochs=float(cfg.get("epochs", 2)),
        max_steps=int(cfg["max_steps"]) if cfg.get("max_steps") is not None else -1,
        per_device_train_batch_size=int(cfg.get("per_device_train_batch_size", 2)),
        gradient_accumulation_steps=int(cfg.get("gradient_accumulation_steps", 8)),
        per_device_eval_batch_size=int(cfg.get("per_device_eval_batch_size", 2)),
        bf16=bool(cfg.get("bf16", True)),
        gradient_checkpointing=bool(cfg.get("gradient_checkpointing", True)),
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim=cfg.get("optim", "paged_adamw_32bit"),
        eval_strategy="steps",
        eval_steps=int(cfg.get("eval_steps", 250)),
        save_strategy="steps",
        save_steps=int(cfg.get("save_steps", 250)),
        save_total_limit=int(cfg.get("save_total_limit", 3)),
        # Per the TrainingArguments docs, this needs save_steps to be a round
        # multiple of eval_steps.
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        logging_strategy="steps",
        logging_steps=int(cfg.get("logging_steps", 10)),
        logging_first_step=True,
        disable_tqdm=True,
        report_to=report_to,
        run_name=cfg.get("run_name"),
        project=cfg.get("project"),
        trackio_space_id=cfg.get("trackio_space_id"),
        push_to_hub=bool(cfg.get("push_to_hub", True)),
        hub_model_id=cfg.get("hub_model_id"),
    )


def main():
    """Run QLoRA SFT end to end, then save and optionally push the result.

    The run covers CLI/config resolution, dataset loading, training, a final
    evaluation, saving the adapters and tokenizer, and an optional Hub push.
    """
    args = parse_args()
    require_cuda()
    cfg = load_config(args.config)
    cfg = sanitize_trackio_config(cfg)
    cfg = _apply_cli_overrides(cfg, args)

    set_seed(args.seed)
    output_dir = Path(cfg["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    write_json(output_dir / "resolved_config.json", cfg)

    train_dataset, eval_dataset = _load_sft_datasets(cfg)

    tokenizer = AutoTokenizer.from_pretrained(cfg["model_name_or_path"], trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if cfg.get("packing") and not args.flash_attn:
        print("WARNING: packing is enabled without --flash_attn; packing requires a compatible attention setup.")

    model_init_kwargs = _build_model_init_kwargs(cfg, args.flash_attn)
    peft_config = _build_peft_config(cfg)
    report_to = "trackio" if cfg.get("project") else "none"
    sft_args = _build_sft_config(cfg, model_init_kwargs, report_to)

    trainer = SFTTrainer(
        model=cfg["model_name_or_path"],
        args=sft_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
        # Trackio alerts only make sense when Trackio is the active reporter.
        callbacks=[TrackioAlertCallback()] if report_to == "trackio" else [],
    )

    resume_arg = args.resume_from_checkpoint
    # "true" (any case) maps to True, which asks the Trainer to pick the latest checkpoint.
    if resume_arg is not None and str(resume_arg).lower() == "true":
        resume_arg = True
    trainer.train(resume_from_checkpoint=resume_arg)
    metrics = trainer.evaluate()
    write_json(output_dir / "final_eval_metrics.json", metrics)
    trainer.save_model(cfg["output_dir"])
    tokenizer.save_pretrained(cfg["output_dir"])

    if bool(cfg.get("push_to_hub", True)):
        trainer.push_to_hub(
            commit_message="Qwen TMF921 QLoRA SFT",
            dataset_name=cfg["dataset_name"],
        )
        # With hub_model_id unset, TrainingArguments documents a default derived from
        # output_dir; prefer the Trainer's resolved id when it exposes one.
        hub_id = getattr(trainer, "hub_model_id", None) or cfg.get("hub_model_id")
        print(f"Pushed model/adapters to https://huggingface.co/{hub_id}")


if __name__ == "__main__":
    # Script entry point; importing this module does not start training.
    main()