| """Run resumable LoRA SFT against the vulnops heuristic dataset.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import math | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, List | |
| import torch | |
| from torch.utils.data import Dataset | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| DataCollatorForSeq2Seq, | |
| Trainer, | |
| TrainerCallback, | |
| TrainingArguments, | |
| ) | |
| ROOT = Path(__file__).resolve().parents[1] | |
| if str(ROOT) not in sys.path: | |
| sys.path.insert(0, str(ROOT)) | |
| from training_utils import ( | |
| detect_device, | |
| latest_checkpoint, | |
| load_jsonl, | |
| preferred_torch_dtype, | |
| set_default_env, | |
| write_json, | |
| ) | |


class JsonlSFTDataset(Dataset):
    """Mask prompt tokens so only the completion contributes to the loss."""

    def __init__(self, records: List[Dict[str, object]], tokenizer, max_length: int):
        self.examples: List[Dict[str, List[int]]] = []
        for record in records:
            prompt = str(record["prompt"])
            completion = str(record["completion"])
            prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
            completion_ids = tokenizer(completion, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id]
            # Truncate prompt and completion jointly; prompt positions get the
            # ignore label (-100) so only completion tokens are supervised.
            input_ids = (prompt_ids + completion_ids)[:max_length]
            labels = ([-100] * len(prompt_ids) + completion_ids)[:max_length]
            attention_mask = [1] * len(input_ids)
            self.examples.append(
                {
                    "input_ids": input_ids,
                    "labels": labels,
                    "attention_mask": attention_mask,
                }
            )

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> Dict[str, List[int]]:
        return self.examples[index]
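

# Illustrative label layout (token ids invented for the example): with
# prompt_ids=[11, 12] and completion_ids=[21, 22, eos], an example stores
#   input_ids = [11, 12, 21, 22, eos]
#   labels    = [-100, -100, 21, 22, eos]
# so gradients flow only through the completion span. An example whose prompt
# alone reaches max_length keeps only -100 labels and supervises nothing; at
# batch size 1 that can surface as a non-finite loss, which the
# AbortOnInvalidLoss callback below treats as fatal.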


class JsonlMetricLogger(TrainerCallback):
    """Append metrics during training so partial runs are still inspectable."""

    def __init__(self, output_root: Path):
        self.output_root = output_root
        self.metrics_path = output_root / "metrics" / "train_metrics.jsonl"
        self.manifest_path = output_root / "run_manifest.json"

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        payload = {
            "global_step": int(state.global_step),
            "epoch": float(state.epoch or 0.0),
            **{key: float(value) if isinstance(value, (int, float)) else value for key, value in logs.items()},
        }
        self.metrics_path.parent.mkdir(parents=True, exist_ok=True)
        with self.metrics_path.open("a", encoding="utf-8") as handle:
            handle.write(json.dumps(payload, sort_keys=True) + "\n")
        # Rewrite the manifest on every log so an interrupted run still leaves
        # a current snapshot of progress on disk.
        write_json(
            self.manifest_path,
            {
                "status": "training",
                "global_step": int(state.global_step),
                "epoch": float(state.epoch or 0.0),
                "best_model_checkpoint": state.best_model_checkpoint,
                "log_history_entries": len(state.log_history),
            },
        )
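

# A typical appended metrics line looks roughly like this (values are
# illustrative; keys beyond global_step/epoch are whatever the Trainer put
# into `logs`):
#   {"epoch": 0.53, "global_step": 10, "learning_rate": 4.9e-05, "loss": 1.84}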


class AbortOnInvalidLoss(TrainerCallback):
    """Stop training early when the run becomes numerically invalid."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return control
        for key in ("loss", "eval_loss", "grad_norm"):
            value = logs.get(key)
            if isinstance(value, (int, float)) and not math.isfinite(float(value)):
                control.should_training_stop = True
                break
        return control


def build_training_args(args, output_root: Path, use_cpu: bool) -> TrainingArguments:
    warmup_steps = max(1, int(args.warmup_ratio * args.estimated_train_steps))
    return TrainingArguments(
        output_dir=str(output_root / "checkpoints"),
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=warmup_steps,
        optim="adamw_torch",
        weight_decay=args.weight_decay,
        logging_strategy="steps",
        logging_steps=args.logging_steps,
        logging_first_step=True,
        eval_strategy="no",
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=3,
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=0,
        dataloader_pin_memory=False,
        gradient_checkpointing=True,
        lr_scheduler_type="cosine",
        load_best_model_at_end=False,
        use_cpu=use_cpu,
        fp16=False,
        bf16=False,
        max_grad_norm=0.5,
        seed=args.seed,
    )
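

# Warmup sizing, worked through with illustrative numbers: at warmup_ratio=0.1
# and estimated_train_steps=90 (computed in main() below), warmup_steps =
# max(1, int(0.1 * 90)) = 9; the cosine schedule then decays the learning rate
# toward zero over the remaining steps.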


def main() -> None:
    parser = argparse.ArgumentParser()
    # Hub id matching the output-root naming (the original "Qwen/Qwen3.5-4B"
    # appears to be a typo for the Qwen3-4B checkpoint).
    parser.add_argument("--model", default="Qwen/Qwen3-4B")
    parser.add_argument("--output-root", default="artifacts/lora_qwen3_4b")
    parser.add_argument("--max-length", type=int, default=1536)
    parser.add_argument("--num-train-epochs", type=float, default=6.0)
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--per-device-eval-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-5)
    parser.add_argument("--warmup-ratio", type=float, default=0.1)
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--logging-steps", type=int, default=5)
    parser.add_argument("--save-steps", type=int, default=10)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--fresh-start", action="store_true")
    args = parser.parse_args()
    try:
        from peft import LoraConfig, TaskType, get_peft_model
    except ImportError as exc:
        raise RuntimeError("Install peft before running LoRA training.") from exc
    output_root = (ROOT / args.output_root).resolve()
    data_dir = output_root / "data"
    train_records = load_jsonl(data_dir / "train.jsonl")
    eval_records = load_jsonl(data_dir / "eval.jsonl")
    if not train_records or not eval_records:
        raise RuntimeError("Missing train/eval JSONL data. Run scripts/generate_sft_data.py first.")
    set_default_env(output_root)
    device = detect_device()
    use_cpu = device == "cpu"
    # Force full fp32 on Apple MPS; otherwise defer to the helper's preference.
    torch_dtype = torch.float32 if device == "mps" else preferred_torch_dtype(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    # use_cache conflicts with gradient checkpointing; enabling input grads
    # keeps gradients flowing through the frozen embeddings under checkpointing.
    model.config.use_cache = False
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    if device in {"cuda", "mps"}:
        model.to(device)
    train_dataset = JsonlSFTDataset(train_records, tokenizer, args.max_length)
    eval_dataset = JsonlSFTDataset(eval_records, tokenizer, args.max_length)
    # DataCollatorForSeq2Seq pads labels with -100, so padding stays masked.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)
    updates_per_epoch = max(
        1,
        math.ceil(len(train_dataset) / (args.per_device_train_batch_size * args.gradient_accumulation_steps)),
    )
    args.estimated_train_steps = max(1, math.ceil(args.num_train_epochs * updates_per_epoch))
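    # Worked example with illustrative numbers: 120 train examples, batch size
    # 1, gradient accumulation 8 -> ceil(120 / 8) = 15 optimizer updates per
    # epoch; over 6.0 epochs, estimated_train_steps = ceil(6.0 * 15) = 90,
    # which also sizes warmup inside build_training_args().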
    training_args = build_training_args(args, output_root, use_cpu)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        data_collator=data_collator,
        callbacks=[JsonlMetricLogger(output_root), AbortOnInvalidLoss()],
    )
    # Resume from the newest checkpoint unless --fresh-start was requested.
    checkpoint_dir = output_root / "checkpoints"
    resume_checkpoint = None if args.fresh_start else latest_checkpoint(checkpoint_dir)
    write_json(
        output_root / "run_manifest.json",
        {
            "status": "starting_training",
            "device": device,
            "model": args.model,
            "train_examples": len(train_dataset),
            "eval_examples": len(eval_dataset),
            "estimated_train_steps": args.estimated_train_steps,
            "resume_checkpoint": str(resume_checkpoint) if resume_checkpoint else None,
        },
    )
    train_result = trainer.train(resume_from_checkpoint=str(resume_checkpoint) if resume_checkpoint else None)
    trainer.save_model(str(output_root / "adapter"))
    tokenizer.save_pretrained(str(output_root / "adapter"))
    final_eval = trainer.evaluate(eval_dataset=eval_dataset)
    summary = {
        "status": "finished",
        "device": device,
        "train_loss": float(train_result.training_loss),
        "global_step": int(trainer.state.global_step),
        "eval_loss": float(final_eval["eval_loss"]) if math.isfinite(float(final_eval["eval_loss"])) else None,
        "adapter_dir": str(output_root / "adapter"),
    }
    write_json(output_root / "training_summary.json", summary)
    write_json(output_root / "run_manifest.json", summary)
    print(json.dumps(summary, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()
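

# Example invocation (file name hypothetical; flags shown are the argparse
# defaults, and scripts/generate_sft_data.py must first populate
# <output-root>/data/train.jsonl and eval.jsonl):
#   python train_lora_sft.py --model Qwen/Qwen3-4B --save-steps 10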