File size: 9,770 Bytes
d63a1ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""Run resumable LoRA SFT against the vulnops heuristic dataset."""

from __future__ import annotations

import argparse
import json
import math
import sys
from pathlib import Path
from typing import Dict, List

import torch
from torch.utils.data import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

# Repository root: the directory one level above the folder holding this script.
ROOT = Path(__file__).resolve().parent.parent
# Put the repository root on the import path so sibling top-level modules
# (such as training_utils, imported below) resolve when run as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from training_utils import (
    detect_device,
    latest_checkpoint,
    load_jsonl,
    preferred_torch_dtype,
    set_default_env,
    write_json,
)


class JsonlSFTDataset(Dataset):
    """Tokenized prompt/completion pairs whose prompt tokens are masked out of the loss.

    Each record must provide ``"prompt"`` and ``"completion"`` keys (values are
    coerced with ``str``). The completion is followed by the tokenizer's EOS
    token when one is defined. The concatenated sequence is right-truncated to
    ``max_length``, and prompt positions get label ``-100`` so only completion
    tokens are scored.

    Records whose completion is truncated away entirely (the prompt alone fills
    ``max_length``) are skipped. They carry no supervised token and can make a
    mean-reduced loss NaN (0 / 0). ``skipped_examples`` counts how many were
    dropped.
    """

    def __init__(self, records: List[Dict[str, object]], tokenizer, max_length: int):
        self.examples: List[Dict[str, List[int]]] = []
        # Records dropped because truncation left no supervised (non -100) label.
        self.skipped_examples = 0
        # Avoid injecting a literal None token when the tokenizer has no EOS id.
        eos_ids = [] if tokenizer.eos_token_id is None else [tokenizer.eos_token_id]
        for record in records:
            prompt_ids = tokenizer(str(record["prompt"]), add_special_tokens=False)["input_ids"]
            completion_ids = tokenizer(str(record["completion"]), add_special_tokens=False)["input_ids"] + eos_ids

            input_ids = (prompt_ids + completion_ids)[:max_length]
            labels = ([-100] * len(prompt_ids) + completion_ids)[:max_length]
            if all(label == -100 for label in labels):
                self.skipped_examples += 1
                continue
            self.examples.append(
                {
                    "input_ids": input_ids,
                    "labels": labels,
                    "attention_mask": [1] * len(input_ids),
                }
            )

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> Dict[str, List[int]]:
        return self.examples[index]


class JsonlMetricLogger(TrainerCallback):
    """Append every Trainer log event to a JSONL file and refresh the run manifest.

    Writing on each ``on_log`` keeps partial or crashed runs inspectable.
    Metrics are appended to ``<output_root>/metrics/train_metrics.jsonl``.
    ``<output_root>/run_manifest.json`` is updated in place, so fields already
    present (e.g. those written before training starts) are preserved rather
    than erased on the first log.
    """

    def __init__(self, output_root: Path):
        self.output_root = output_root
        self.metrics_path = output_root / "metrics" / "train_metrics.jsonl"
        self.manifest_path = output_root / "run_manifest.json"

    @staticmethod
    def _json_safe(value):
        """Coerce numbers to float, mapping NaN/inf to None; pass other values through.

        ``json.dumps`` would otherwise emit bare ``NaN``/``Infinity`` tokens,
        which are not valid JSON and break strict readers of the metrics file.
        """
        if isinstance(value, (int, float)):
            number = float(value)
            return number if math.isfinite(number) else None
        return value

    def _read_manifest(self) -> Dict[str, object]:
        """Return the current manifest as a dict, or {} if it is absent or unreadable."""
        try:
            existing = json.loads(self.manifest_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            return {}
        return existing if isinstance(existing, dict) else {}

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append ``logs`` (plus step/epoch) as one JSON line and update the manifest."""
        if not logs:
            return
        step = int(state.global_step)
        epoch = float(state.epoch or 0.0)
        payload = {
            "global_step": step,
            "epoch": epoch,
            **{key: self._json_safe(value) for key, value in logs.items()},
        }
        self.metrics_path.parent.mkdir(parents=True, exist_ok=True)
        with self.metrics_path.open("a", encoding="utf-8") as handle:
            handle.write(json.dumps(payload, sort_keys=True) + "\n")

        # Merge progress into the existing manifest instead of replacing it.
        manifest = self._read_manifest()
        manifest.update(
            {
                "status": "training",
                "global_step": step,
                "epoch": epoch,
                "best_model_checkpoint": state.best_model_checkpoint,
                "log_history_entries": len(state.log_history),
            }
        )
        write_json(self.manifest_path, manifest)


class AbortOnInvalidLoss(TrainerCallback):
    """Request an early stop as soon as a logged loss or gradient norm is NaN/inf.

    On the first non-finite value, ``invalid_metric`` is set to
    ``(metric_name, value, global_step)`` and a warning goes to stderr. This
    keeps the stop visible instead of looking like a normal end of training.
    """

    # Log keys inspected on every on_log call.
    _WATCHED_KEYS = ("loss", "eval_loss", "grad_norm")

    def __init__(self):
        super().__init__()
        # First non-finite metric observed, or None while the run is healthy.
        self.invalid_metric = None

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Set ``control.should_training_stop`` when a watched metric is non-finite."""
        if not logs:
            return control
        for key in self._WATCHED_KEYS:
            value = logs.get(key)
            if isinstance(value, (int, float)) and not math.isfinite(float(value)):
                control.should_training_stop = True
                if self.invalid_metric is None:
                    self.invalid_metric = (key, float(value), int(state.global_step))
                    print(
                        f"[AbortOnInvalidLoss] non-finite {key}={value} at step "
                        f"{state.global_step}; stopping training.",
                        file=sys.stderr,
                    )
                break
        return control


def compute_warmup_steps(warmup_ratio: float, total_steps: int) -> int:
    """Convert a warmup fraction into an absolute number of optimizer steps.

    A non-positive ``warmup_ratio`` disables warmup (returns 0). Any positive
    ratio yields at least one step, so very short runs still get a ramp.
    """
    if warmup_ratio <= 0:
        return 0
    return max(1, int(warmup_ratio * total_steps))


def build_training_args(args, output_root: Path, use_cpu: bool) -> TrainingArguments:
    """Build the ``TrainingArguments`` for this LoRA SFT run.

    Args:
        args: Parsed CLI namespace. Besides the CLI flags, it must carry
            ``estimated_train_steps`` (set by ``main`` before calling this).
            That value turns ``warmup_ratio`` into ``warmup_steps``.
        output_root: Run directory; checkpoints go to ``<output_root>/checkpoints``.
        use_cpu: Forwarded to ``TrainingArguments.use_cpu``.

    Returns:
        A ``TrainingArguments`` configured with:
        - step-based logging and saving, keeping at most 3 checkpoints;
        - no in-loop evaluation;
        - a cosine LR schedule;
        - gradient checkpointing;
        - fp16/bf16 mixed precision off;
        - gradient clipping at 0.5.
    """
    warmup_steps = compute_warmup_steps(args.warmup_ratio, args.estimated_train_steps)
    return TrainingArguments(
        output_dir=str(output_root / "checkpoints"),
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=warmup_steps,
        optim="adamw_torch",
        weight_decay=args.weight_decay,
        logging_strategy="steps",
        logging_steps=args.logging_steps,
        logging_first_step=True,
        eval_strategy="no",
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=3,
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=0,
        dataloader_pin_memory=False,
        gradient_checkpointing=True,
        lr_scheduler_type="cosine",
        load_best_model_at_end=False,
        use_cpu=use_cpu,
        fp16=False,
        bf16=False,
        max_grad_norm=0.5,
        seed=args.seed,
    )


def _finite_or_none(value) -> float | None:
    """Return ``value`` as a float, or None if it is NaN/inf (not representable in strict JSON)."""
    number = float(value)
    return number if math.isfinite(number) else None


def main() -> None:
    """Parse CLI flags, fine-tune a LoRA adapter on the generated SFT data, and record the outcome.

    Reads ``<output-root>/data/train.jsonl`` and ``eval.jsonl``. Unless
    ``--fresh-start`` is given, it resumes from the checkpoint that
    ``latest_checkpoint`` finds under ``<output-root>/checkpoints``. The
    adapter and tokenizer are saved to ``<output-root>/adapter``. The final
    summary is written to both ``training_summary.json`` and
    ``run_manifest.json``. If training raises, the manifest is marked
    ``"failed"`` before the exception propagates.

    Raises:
        RuntimeError: if ``peft`` is not installed, or if either JSONL split
            yields no records.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen3.5-4B")
    parser.add_argument("--output-root", default="artifacts/lora_qwen3_4b")
    parser.add_argument("--max-length", type=int, default=1536)
    parser.add_argument("--num-train-epochs", type=float, default=6.0)
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--per-device-eval-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-5)
    parser.add_argument("--warmup-ratio", type=float, default=0.1)
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--logging-steps", type=int, default=5)
    parser.add_argument("--save-steps", type=int, default=10)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--fresh-start", action="store_true")
    args = parser.parse_args()

    try:
        from peft import LoraConfig, TaskType, get_peft_model
    except ImportError as exc:
        raise RuntimeError("Install peft before running LoRA training.") from exc

    output_root = (ROOT / args.output_root).resolve()
    data_dir = output_root / "data"
    train_records = load_jsonl(data_dir / "train.jsonl")
    eval_records = load_jsonl(data_dir / "eval.jsonl")
    if not train_records or not eval_records:
        raise RuntimeError("Missing train/eval JSONL data. Run scripts/generate_sft_data.py first.")

    set_default_env(output_root)
    device = detect_device()
    use_cpu = device == "cpu"
    # MPS is forced to float32; other devices use the helper's preferred dtype.
    torch_dtype = torch.float32 if device == "mps" else preferred_torch_dtype(device)

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    # KV cache is useless during training and conflicts with gradient checkpointing.
    model.config.use_cache = False
    if hasattr(model, "enable_input_require_grads"):
        # Lets gradients reach the LoRA layers through checkpointed blocks of a frozen base.
        model.enable_input_require_grads()

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)

    if device in {"cuda", "mps"}:
        model.to(device)

    train_dataset = JsonlSFTDataset(train_records, tokenizer, args.max_length)
    eval_dataset = JsonlSFTDataset(eval_records, tokenizer, args.max_length)
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)
    updates_per_epoch = max(
        1,
        math.ceil(len(train_dataset) / (args.per_device_train_batch_size * args.gradient_accumulation_steps)),
    )
    # Consumed by build_training_args to convert warmup_ratio into warmup_steps.
    args.estimated_train_steps = max(1, math.ceil(args.num_train_epochs * updates_per_epoch))
    training_args = build_training_args(args, output_root, use_cpu)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        data_collator=data_collator,
        callbacks=[JsonlMetricLogger(output_root), AbortOnInvalidLoss()],
    )

    manifest_path = output_root / "run_manifest.json"
    checkpoint_dir = output_root / "checkpoints"
    resume_checkpoint = None if args.fresh_start else latest_checkpoint(checkpoint_dir)
    resume_arg = str(resume_checkpoint) if resume_checkpoint else None
    write_json(
        manifest_path,
        {
            "status": "starting_training",
            "device": device,
            "model": args.model,
            "train_examples": len(train_dataset),
            "eval_examples": len(eval_dataset),
            "estimated_train_steps": args.estimated_train_steps,
            "resume_checkpoint": resume_arg,
        },
    )

    try:
        train_result = trainer.train(resume_from_checkpoint=resume_arg)
    except Exception as exc:
        # Without this, the manifest keeps the last "training" status after a crash.
        write_json(
            manifest_path,
            {
                "status": "failed",
                "device": device,
                "model": args.model,
                "global_step": int(trainer.state.global_step),
                "resume_checkpoint": resume_arg,
                "error": f"{type(exc).__name__}: {exc}",
            },
        )
        raise

    adapter_dir = output_root / "adapter"
    trainer.save_model(str(adapter_dir))
    tokenizer.save_pretrained(str(adapter_dir))

    final_eval = trainer.evaluate(eval_dataset=eval_dataset)
    summary = {
        "status": "finished",
        "device": device,
        # Both losses can be non-finite when the run diverged (the case
        # AbortOnInvalidLoss stops on); store null rather than a bare NaN.
        "train_loss": _finite_or_none(train_result.training_loss),
        "global_step": int(trainer.state.global_step),
        "eval_loss": _finite_or_none(final_eval["eval_loss"]),
        "adapter_dir": str(adapter_dir),
    }
    write_json(output_root / "training_summary.json", summary)
    write_json(manifest_path, summary)
    print(json.dumps(summary, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()