# vectrayx-paper-code / training /finetune_sft.py
# jsantillana's picture
# Upload folder using huggingface_hub
# 6848cb6 verified
"""SFT fine-tuning with assistant-only loss masking and an internal mini-curriculum.
Mini-curriculum (within SFT), matching ``epoch_plans`` in ``main``:
    Epoch 1: 100% conversational (OASST1 ES + sft_conv)
    Epoch 2: 70% conversational + 30% CVE Q&A
    Epoch 3 onward: 55% conversational + 30% CVE Q&A + 15% tool use
This avoids drowning the chat behavior in JSON tool-call patterns the way SFT v3 did.
Run example:
python -m training_v2.train.finetune_sft \
--config training_v2/configs/nano.json \
--tokenizer training_v2/tokenizer/out/vectrayx_bpe.model \
--resume training_v2/checkpoints/phase3/last.pt \
--out training_v2/checkpoints/sft_v4 \
--batch-size 16 --grad-accum 4 --epochs 3 --lr 2e-5
"""
import argparse
import json
import sys
import time
from pathlib import Path
import numpy as np
import sentencepiece as spm
import torch
from torch.utils.data import DataLoader, ConcatDataset
ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(ROOT))
from training_v2.data.sft_dataset import SFTDataset
from training_v2.model.transformer import VectraYXNano, ModelConfig
from training_v2.train.utils import (
cosine_with_warmup, make_optimizer, save_checkpoint, load_checkpoint, log_jsonl,
)
# Default SFT source files, grouped by curriculum bucket.
# Each entry is a path relative to the corpus root (``discover`` joins it onto
# ``--corpus-root``); ``load_sft_corpus_config`` can replace any bucket.
SFT_FILES = dict(
    conversational=[
        "corpus/sft_conversational.jsonl",
        "sft_v2_data/oasst1_es.jsonl",
    ],
    cve_qa=["corpus/sft_v2_dataset.jsonl"],
    tool_use=["corpus/tooluse_dataset.jsonl"],
)
def load_sft_corpus_config(path):
    """Override the module-level ``SFT_FILES`` buckets from a JSON config file.

    The JSON document is expected to be an object. The keys
    ``sft_conversational``, ``sft_cve_qa`` and ``sft_tool_use`` replace the
    ``conversational``, ``cve_qa`` and ``tool_use`` buckets respectively; a
    bucket whose key is absent keeps its current value.

    Args:
        path: Path to the JSON corpus config.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    global SFT_FILES
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    cfg = json.loads(Path(path).read_text(encoding="utf-8"))

    def _as_paths(value):
        # A bare string would otherwise be iterated character by character
        # downstream, silently "skipping" every character as a missing file.
        return [value] if isinstance(value, str) else value

    SFT_FILES = {
        "conversational": _as_paths(cfg.get("sft_conversational", SFT_FILES["conversational"])),
        "cve_qa": _as_paths(cfg.get("sft_cve_qa", SFT_FILES["cve_qa"])),
        "tool_use": _as_paths(cfg.get("sft_tool_use", SFT_FILES["tool_use"])),
    }
def discover(paths, root):
    """Resolve *paths* against *root* and keep only the ones that exist.

    Args:
        paths: Iterable of paths relative to *root*.
        root: Base directory the relative paths are joined onto.

    Returns:
        A list of ``Path`` objects (``root / rel``) for the entries that exist,
        in input order. Each missing entry is reported on stdout and dropped.
    """
    existing = []
    for candidate in (Path(root) / rel for rel in paths):
        if not candidate.exists():
            print(f" [skip missing] {candidate}")
            continue
        existing.append(candidate)
    return existing
def build_dataset(args, sp, include_tools):
    """Create one ``SFTDataset`` per curriculum bucket that has files on disk.

    Args:
        args: Parsed CLI namespace; reads ``config``, ``corpus_root`` and ``seed``.
        sp: Tokenizer exposing ``pad_id()``; it is also handed to ``SFTDataset``.
        include_tools: When falsy, the tool-use bucket is left out even if its
            files exist.

    Returns:
        ``(parts, pad_id)`` where ``parts`` is a list of ``(name, dataset)``
        pairs named ``"conv"``, ``"cve"`` and ``"tools"`` (in that order, with
        absent buckets omitted) and ``pad_id`` is the tokenizer pad id, or 0
        when the tokenizer reports a negative one.
    """
    seq_len = ModelConfig.from_json(args.config).max_seq_len
    tok_pad = sp.pad_id()
    pad_id = tok_pad if tok_pad >= 0 else 0

    # Resolve all three buckets up front so every missing file is reported,
    # including tool-use files when that bucket is not going to be used.
    conv_files = discover(SFT_FILES["conversational"], args.corpus_root)
    cve_files = discover(SFT_FILES["cve_qa"], args.corpus_root)
    tool_files = discover(SFT_FILES["tool_use"], args.corpus_root)

    # (part name, files, seed offset) -- each bucket gets its own seed.
    candidates = (
        ("conv", conv_files, 0),
        ("cve", cve_files, 1),
        ("tools", tool_files if include_tools else [], 2),
    )
    parts = [
        (name, SFTDataset(files, sp, seq_len, pad_id=pad_id, seed=args.seed + offset))
        for name, files, offset in candidates
        if files
    ]
    return parts, pad_id
def _collate_sft(batch):
    """Stack the first three fields of every sample into three batch tensors.

    Defined at module level (not as a closure) so that DataLoader workers can
    pickle it when the multiprocessing start method is ``spawn``.
    Assumes each sample is indexable and holds at least three equally-shaped
    tensors (presumably inputs, targets, loss mask -- confirm in SFTDataset).
    """
    return tuple(torch.stack([sample[i] for sample in batch], 0) for i in range(3))


def make_loader(parts, weights, batch_size, num_workers):
    """Build a DataLoader that samples across the named parts by weight.

    Each part's weight is spread evenly over its samples, so the weight is the
    (unnormalised) probability of drawing from that part regardless of its
    size. Sampling is with replacement and one pass draws as many samples as
    the concatenated dataset holds.

    Args:
        parts: List of ``(name, dataset)`` pairs.
        weights: Mapping from part name to sampling weight; names missing from
            the mapping default to 1.0.
        batch_size: Samples per batch.
        num_workers: DataLoader worker processes (workers are kept persistent
            when this is positive).

    Returns:
        A ``DataLoader`` yielding ``(xs, ys, ms)`` tuples of stacked tensors.

    Raises:
        ValueError: If the weights give zero total probability over the
            available samples (e.g. the only positively-weighted part is
            absent), which would otherwise surface later as an opaque
            multinomial error during iteration.
    """
    names = [name for name, _ in parts]
    datasets = [ds for _, ds in parts]
    sizes = [len(ds) for ds in datasets]

    big = ConcatDataset(datasets)
    offsets = np.cumsum([0] + sizes)
    weight_per_idx = np.zeros(offsets[-1], dtype=np.float64)
    for i, name in enumerate(names):
        # Divide by the part size so the part as a whole carries `weight`.
        weight_per_idx[offsets[i]:offsets[i + 1]] = weights.get(name, 1.0) / max(1, sizes[i])

    if len(weight_per_idx) > 0 and not weight_per_idx.sum() > 0:
        raise ValueError(
            f"sampling weights {dict(weights)} give zero total probability "
            f"over parts {names} (sizes {sizes})"
        )

    sampler = torch.utils.data.WeightedRandomSampler(
        weights=weight_per_idx,
        num_samples=int(sum(sizes)),
        replacement=True,
    )
    return DataLoader(
        big, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, collate_fn=_collate_sft, pin_memory=True,
        persistent_workers=num_workers > 0,
    )
def _parse_args():
    """Define the SFT command-line interface and parse the process arguments."""
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    p.add_argument("--tokenizer", required=True)
    p.add_argument("--resume", required=True, help="pre-training checkpoint to fine-tune")
    p.add_argument("--out", required=True)
    p.add_argument("--corpus-root", default=".")
    p.add_argument("--corpus-config", default=None)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--grad-accum", type=int, default=4)
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--weight-decay", type=float, default=0.0)
    p.add_argument("--grad-clip", type=float, default=1.0)
    p.add_argument("--warmup-frac", type=float, default=0.03)
    p.add_argument("--num-workers", type=int, default=2)
    p.add_argument("--log-every", type=int, default=20)
    p.add_argument("--save-every", type=int, default=500)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"])
    return p.parse_args()


def main():
    """Run SFT fine-tuning from the command line.

    Steps: parse arguments, optionally override the corpus file lists, build
    the model and load the ``--resume`` checkpoint (passing ``optimizer=None``),
    build the datasets, then train for ``--epochs`` epochs with a per-epoch
    sampling mix taken from ``epoch_plans`` (the last plan is reused for any
    extra epochs). Writes ``train_log.jsonl``, a rolling ``last.pt``, one
    ``epoch{N}.pt`` per epoch and ``final.pt`` under ``--out``.

    Raises:
        RuntimeError: If none of the configured SFT files exist.
    """
    args = _parse_args()
    if args.corpus_config:
        load_sft_corpus_config(args.corpus_config)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    cfg = ModelConfig.from_json(args.config)
    model = VectraYXNano(cfg).to(args.device)
    print(f"[model] {model.num_params()/1e6:.2f}M params")
    load_checkpoint(args.resume, model, optimizer=None, map_location=args.device)
    print(f"[resume] {args.resume}")

    sp = spm.SentencePieceProcessor()
    sp.load(args.tokenizer)

    # All buckets are always built; the per-epoch weights decide what is sampled.
    parts, _ = build_dataset(args, sp, include_tools=True)
    if not parts:
        raise RuntimeError("no SFT files found")

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "train_log.jsonl"

    optimizer = make_optimizer(model, lr=args.lr, weight_decay=args.weight_decay)
    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
    # Compare the device *type* so indexed devices such as "cuda:0" still get
    # mixed precision (a plain string compare against "cuda" would miss them).
    use_amp = torch.device(args.device).type == "cuda" and dtype != torch.float32
    # fp16 autocast needs loss scaling to keep small gradients from flushing
    # to zero; for bf16/fp32 the scaler is disabled and acts as a pass-through.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp and dtype == torch.float16)

    epoch_plans = [
        {"conv": 1.00, "cve": 0.00, "tools": 0.0},   # epoch 1: conversational only
        {"conv": 0.70, "cve": 0.30, "tools": 0.00},  # epoch 2: + CVE Q&A
        {"conv": 0.55, "cve": 0.30, "tools": 0.15},  # epoch 3: + tool use
    ]

    # Planning pass: announce each epoch's mix and count optimizer steps so the
    # LR schedule knows the total horizon before training starts.
    total_steps = 0
    for ep in range(args.epochs):
        weights = epoch_plans[min(ep, len(epoch_plans) - 1)]
        print(f"\n=== epoch {ep+1}/{args.epochs} | mix={weights} ===")
        loader = make_loader(parts, weights, args.batch_size, args.num_workers)
        steps_per_epoch = max(1, len(loader) // args.grad_accum)
        total_steps += steps_per_epoch
    warmup = max(50, int(args.warmup_frac * total_steps))
    print(f"[sft] total_steps≈{total_steps} warmup={warmup}")

    model.train()
    t_start = time.time()
    step = 0
    running_loss = 0.0
    running_n = 0
    for ep in range(args.epochs):
        weights = epoch_plans[min(ep, len(epoch_plans) - 1)]
        loader = make_loader(parts, weights, args.batch_size, args.num_workers)
        data_iter = iter(loader)
        steps_per_epoch = max(1, len(loader) // args.grad_accum)
        for _ in range(steps_per_epoch):
            cur_lr = cosine_with_warmup(step, warmup, total_steps, args.lr)
            for g in optimizer.param_groups:
                g["lr"] = cur_lr
            optimizer.zero_grad(set_to_none=True)
            loss_accum = 0.0
            for _micro in range(args.grad_accum):
                try:
                    xs, ys, ms = next(data_iter)
                except StopIteration:
                    # Loader has fewer batches than one accumulation window
                    # needs: restart it rather than ending the step short.
                    data_iter = iter(loader)
                    xs, ys, ms = next(data_iter)
                xs = xs.to(args.device, non_blocking=True)
                ys = ys.to(args.device, non_blocking=True)
                ms = ms.to(args.device, non_blocking=True)
                with torch.amp.autocast("cuda", dtype=dtype, enabled=use_amp):
                    _, loss = model(xs, targets=ys, loss_mask=ms)
                    # Scale down so the accumulated gradient is a mean.
                    loss = loss / args.grad_accum
                scaler.scale(loss).backward()
                loss_accum += loss.item() * args.grad_accum
            # Gradients must be unscaled before clipping so the threshold
            # applies to their true magnitude.
            scaler.unscale_(optimizer)
            gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            step += 1
            running_loss += loss_accum / args.grad_accum
            running_n += 1
            if step % args.log_every == 0:
                elapsed = time.time() - t_start
                avg = running_loss / running_n
                print(f"[sft ep{ep+1} step {step:>5}/{total_steps}] loss={avg:.4f} "
                      f"lr={cur_lr:.2e} gnorm={gnorm:.2f} elapsed={elapsed/60:.1f}min")
                log_jsonl(log_path, {"epoch": ep + 1, "step": step, "loss": avg,
                                     "lr": cur_lr, "gnorm": float(gnorm)})
                running_loss = 0.0
                running_n = 0
            if step % args.save_every == 0:
                save_checkpoint(out_dir / "last.pt", model, optimizer,
                                {"step": step}, step,
                                extra={"epoch": ep + 1, "weights": weights})
        save_checkpoint(out_dir / f"epoch{ep+1}.pt", model, optimizer,
                        {"step": step}, step,
                        extra={"epoch": ep + 1, "weights": weights})
        print(f"[save] {out_dir}/epoch{ep+1}.pt")

    save_checkpoint(out_dir / "final.pt", model, optimizer, {"step": step}, step,
                    extra={"done": True})
    print(f"[done] SFT → {out_dir}/final.pt")


if __name__ == "__main__":
    main()