bochen2079's picture
download
raw
4.9 kB
"""
Stage 2 — DPO trainer for Katherine k0, on top of the SFT adapter.
180 curated chosen/rejected pairs from k0_dpo_only.jsonl (system-stripped).
The "chosen" are Katherine-y responses (direct, embodied, in-character).
The "rejected" are assistant-register responses (helpful, bullet lists, sympathetic).
Hyperparameters:
epochs = 2
lr = 5e-6 (much smaller than SFT — DPO needs gentle steps)
beta = 0.1 (KL strength to reference model; standard)
batch = 4 per device, grad_accum = 2 (effective 8)
max_seq = 1024
optim = adamw_8bit
Loads the SFT adapter as both policy + reference, then trains the policy to
prefer chosen responses over rejected.
"""
import argparse
import os
import sys
from unsloth import FastModel
from datasets import load_dataset
from trl import DPOTrainer, DPOConfig
def fmt_dpo_example(ex, tokenizer):
"""Convert {prompt: msgs, chosen: msgs, rejected: msgs} into the
string-form DPOTrainer expects: prompt str, chosen str, rejected str.
The chosen/rejected strings are JUST the assistant turn content (no role markers).
"""
prompt_str = tokenizer.apply_chat_template(
ex["prompt"],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
chosen_content = ex["chosen"][0].get("content", "") if ex["chosen"] else ""
rejected_content = ex["rejected"][0].get("content", "") if ex["rejected"] else ""
return {
"prompt": prompt_str,
"chosen": chosen_content,
"rejected": rejected_content,
}
def main():
p = argparse.ArgumentParser()
p.add_argument("--data", default="data/k0_dpo_curated.jsonl")
p.add_argument("--sft-adapter", default="adapters/k0_sft_adapter",
help="SFT adapter to load as policy starting point.")
p.add_argument("--output", default="adapters/k0_dpo_adapter")
p.add_argument("--max_seq", type=int, default=1024)
p.add_argument("--epochs", type=int, default=2)
p.add_argument("--lr", type=float, default=5e-6)
p.add_argument("--beta", type=float, default=0.1)
p.add_argument("--batch", type=int, default=4)
p.add_argument("--grad_accum", type=int, default=2)
p.add_argument("--skip-train", action="store_true")
args = p.parse_args()
if args.skip_train:
if not os.path.isdir(args.output):
print(f"[error] --skip-train set but DPO adapter dir not found: {args.output}", file=sys.stderr)
sys.exit(1)
print(f"[skip-train] DPO adapter at {args.output}")
return
if not os.path.isdir(args.sft_adapter):
print(f"[error] SFT adapter not found at {args.sft_adapter}", file=sys.stderr)
print(f" Run finetune_k0.py first.", file=sys.stderr)
sys.exit(1)
print(f"[load] base + SFT adapter: {args.sft_adapter}")
model, tokenizer = FastModel.from_pretrained(
model_name=args.sft_adapter,
max_seq_length=args.max_seq,
load_in_4bit=True,
full_finetuning=False,
)
# Re-attach LoRA params from SFT adapter as the policy.
# Reference model is the same SFT-snapshot frozen; DPOTrainer handles that internally.
print(f"[data] loading {args.data}")
ds = load_dataset("json", data_files=args.data, split="train")
print(f"[data] {len(ds)} preference pairs loaded")
ds = ds.map(lambda ex: fmt_dpo_example(ex, tokenizer))
print("[sample] first DPO example:")
print("-" * 60)
print("prompt: ", ds[0]["prompt"][:500])
print("chosen: ", ds[0]["chosen"][:300])
print("rejected:", ds[0]["rejected"][:300])
print("-" * 60)
dpo_config = DPOConfig(
output_dir=args.output,
per_device_train_batch_size=args.batch,
gradient_accumulation_steps=args.grad_accum,
warmup_ratio=0.1,
num_train_epochs=args.epochs,
learning_rate=args.lr,
logging_steps=5,
save_strategy="epoch",
save_total_limit=4,
optim="adamw_8bit",
weight_decay=0.01,
lr_scheduler_type="cosine",
seed=3407,
report_to="none",
max_length=args.max_seq,
max_prompt_length=args.max_seq // 2,
beta=args.beta,
bf16=True,
)
trainer = DPOTrainer(
model=model,
ref_model=None, # uses adapter-disabled forward as reference (PEFT trick)
args=dpo_config,
train_dataset=ds,
tokenizer=tokenizer,
)
print()
print(f"[train] DPO: {args.epochs} epochs × {len(ds)} pairs / "
f"effective_batch {args.batch * args.grad_accum}")
trainer.train()
print()
print(f"[save] writing DPO adapter to {args.output}")
model.save_pretrained(args.output)
tokenizer.save_pretrained(args.output)
print(f"[done] DPO stage complete")
if __name__ == "__main__":
main()

Xet Storage Details

Size:
4.9 kB
·
Xet hash:
415a0299afc6189c310c57471cc45545b8531da123e7cfca45154794aabfc357

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.