Spaces:
Sleeping
Sleeping
"""
SFT prime: teach Qwen 2.5-3B the teacher's CoT-then-answer format.

This is Stage 2 of Algorithm Distillation. We've already collected
teacher trajectories (Stage 1). Here we fine-tune the student on the
teacher's full responses — `<reasoning>...</reasoning>\nS M W ACTION_NAME` —
so the student learns BOTH the format and the reasoning pattern that
produced each answer.

After this stage, the student should beat heuristic baselines on the
v2 grader (which awards 0.20 for belief_accuracy). GRPO refinement is
optional — only needed if the SFT'd model regresses on something.

Usage (from rhythm_env root):
    python training/sft_prime.py \
        --teacher_jsonls data/teacher_30ep_validation.jsonl \
                         data/teacher_indist_30_99.jsonl \
                         data/teacher_ood_10000_10049.jsonl \
        --output_dir outputs/rhythm-env-sft-primed \
        --max_steps 600 \
        --epochs 2

Designed to run on HF Jobs with the a10g-large flavor.
"""
| import argparse | |
| import json | |
| import os | |
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
# The teacher's system prompt is the canonical contract — the student must learn
# to respond to this exact prompt. It is imported from the teacher script so the
# prompt has a single source of truth (SSOT).
| from scripts.generate_teacher_trajectories import TEACHER_SYSTEM_PROMPT | |
def load_teacher_dataset(
    jsonl_paths: list[str],
    drop_parse_fails: bool = True,
    *,
    system_prompt: "str | None" = None,
) -> list[dict]:
    """Read teacher JSONL files and build chat-format SFT examples.

    Each non-blank input line is one JSON object describing one step of a
    teacher episode. Kept rows must carry ``user_prompt``;
    ``teacher_response`` and ``parse_failed`` are read with defaults. Each
    kept row becomes::

        {"messages": [
            {"role": "system", "content": <system prompt>},
            {"role": "user", "content": row["user_prompt"]},
            {"role": "assistant", "content": row["teacher_response"]},
        ]}

    A row is dropped (and counted) when ``drop_parse_fails`` is true and its
    ``parse_failed`` flag is truthy, or when its response is missing or
    blank. We don't want to teach the student bad outputs.

    Args:
        jsonl_paths: Paths of UTF-8 JSONL files, read in the given order.
        drop_parse_fails: Skip rows whose ``parse_failed`` flag is truthy.
        system_prompt: System message used for every example. ``None``
            (the default) means ``TEACHER_SYSTEM_PROMPT``, so the student is
            trained on exactly the prompt the teacher saw.

    Returns:
        The kept examples, in file/line order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a kept row has no ``user_prompt``.
    """
    if system_prompt is None:
        system_prompt = TEACHER_SYSTEM_PROMPT
    pairs: list[dict] = []
    n_total = 0
    n_dropped = 0
    for path in jsonl_paths:
        # JSON text is UTF-8; don't depend on the platform default encoding
        # (e.g. cp1252 on Windows), which can fail on non-ASCII responses.
        with open(path, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. an extra newline at EOF);
                    # json.loads would raise on them.
                    continue
                row = json.loads(line)
                n_total += 1
                if drop_parse_fails and row.get("parse_failed"):
                    n_dropped += 1
                    continue
                resp = row.get("teacher_response", "")
                if not resp or not resp.strip():
                    n_dropped += 1
                    continue
                pairs.append({
                    "messages": [
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": row["user_prompt"]},
                        {"role": "assistant", "content": resp},
                    ],
                })
    print(f"Loaded {len(pairs)} SFT examples ({n_dropped}/{n_total} dropped: "
          f"parse-failed or empty)")
    return pairs
def _banner(title: str, *, first: bool = False) -> None:
    """Print a console section header: *title* framed by two 60-char '=' rules.

    Every header except the first is preceded by a blank line.
    """
    rule = "=" * 60
    print(rule if first else "\n" + rule)
    print(title)
    print(rule)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--teacher_jsonls", nargs="+", required=True,
                        help="One or more teacher trajectory JSONL files")
    parser.add_argument("--output_dir", type=str, default="outputs/rhythm-env-sft-primed")
    parser.add_argument("--model_name", type=str, default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--epochs", type=int, default=2,
                        help="SFT epochs over the dataset (2 is plenty for ~3000 examples)")
    parser.add_argument("--max_steps", type=int, default=-1,
                        help="Override epochs with a step count (-1 = use epochs)")
    parser.add_argument("--lora_rank", type=int, default=16)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--max_seq_length", type=int, default=2048,
                        help="Must fit system + user + CoT response. ~600 user + ~120 CoT + ~10 ans + slack")
    parser.add_argument("--per_device_batch_size", type=int, default=1)
    parser.add_argument("--grad_accum", type=int, default=8,
                        help="Effective batch size = per_device * grad_accum")
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--save_method", type=str, default="merged_16bit",
                        choices=["lora", "merged_16bit", "merged_4bit"])
    return parser


def _save_outputs(model, tokenizer, trainer, args: argparse.Namespace) -> tuple[str, str]:
    """Write weights, the trainer's log history and the CLI config to ``args.output_dir``.

    Creates the output directory if it does not exist.

    Returns:
        ``(log_history_path, training_config_path)``.
    """
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    if args.save_method == "lora":
        # Adapter-style save: the PEFT model plus its tokenizer, no merge step.
        model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
    else:
        # Unsloth's merged save; args.save_method is "merged_16bit" or "merged_4bit".
        model.save_pretrained_merged(
            args.output_dir,
            tokenizer,
            save_method=args.save_method,
        )

    # Save log_history for plot_from_log.py
    log_path = os.path.join(args.output_dir, "log_history.json")
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(trainer.state.log_history, f, indent=2)

    # Save training config (every CLI argument) for reproducibility.
    config_path = os.path.join(args.output_dir, "training_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=2)
    return log_path, config_path


def main() -> None:
    """Stage-2 SFT: fine-tune a LoRA-wrapped base model on teacher transcripts.

    The pipeline:

    1. Load the teacher JSONL.
    2. Load the base model in 4-bit and attach LoRA adapters.
    3. Render each conversation with the tokenizer's chat template.
    4. Configure TRL's SFTTrainer.
    5. Train.
    6. Save the weights, log history and CLI config to ``--output_dir``.

    Exits via ``sys.exit`` with an error message if no examples are loaded.
    """
    args = _build_arg_parser().parse_args()

    # ---- 1. Load + format the dataset ----
    _banner("Step 1: Loading teacher dataset", first=True)
    pairs = load_teacher_dataset(args.teacher_jsonls)
    if not pairs:
        sys.exit("ERROR: no SFT examples loaded -- check JSONL paths")

    # ML libraries are imported only after argument parsing, so `--help` and
    # CLI errors return without loading them.
    from datasets import Dataset

    raw_ds = Dataset.from_list(pairs)
    print(f"Dataset size: {len(raw_ds)} examples")

    # ---- 2. Load base model via Unsloth ----
    _banner(f"Step 2: Loading base model {args.model_name}")
    # Keep this import ahead of the `trl` import below: Unsloth expects to be
    # imported before trl/transformers so its patches take effect.
    from unsloth import FastLanguageModel

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        load_in_4bit=True,
        max_seq_length=args.max_seq_length,
    )
    lora_alpha = args.lora_rank * 2
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=lora_alpha,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    print(f"LoRA rank {args.lora_rank}, alpha {lora_alpha}")

    # ---- 3. Map to chat-template strings ----
    _banner("Step 3: Preparing dataset")

    def format_example(ex):
        # Render the system/user/assistant turns into one training string. No
        # generation prompt is appended because the assistant turn is present.
        text = tokenizer.apply_chat_template(
            ex["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
        return {"text": text}

    ds = raw_ds.map(format_example, remove_columns=raw_ds.column_names)
    print("Sample formatted text (first 800 chars):")
    print(ds[0]["text"][:800])
    print("...")

    # ---- 4. SFTTrainer ----
    _banner("Step 4: Configuring SFTTrainer")
    from trl import SFTConfig, SFTTrainer

    effective_bs = args.per_device_batch_size * args.grad_accum
    sft_kwargs = dict(
        per_device_train_batch_size=args.per_device_batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        optim="adamw_8bit",
        weight_decay=0.001,
        logging_steps=5,
        save_strategy="no",
        report_to="none",
        output_dir=args.output_dir,
        max_seq_length=args.max_seq_length,
        dataset_text_field="text",
        packing=False,
    )
    # A positive --max_steps takes precedence; otherwise train for --epochs.
    if args.max_steps > 0:
        sft_kwargs["max_steps"] = args.max_steps
    else:
        sft_kwargs["num_train_epochs"] = args.epochs
    sft_config = SFTConfig(**sft_kwargs)
    # NOTE(review): the "text" field holds the whole rendered conversation and no
    # completion-only masking is configured here, so the loss presumably covers
    # the system/user tokens as well as the teacher's response -- confirm intended.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds,
        args=sft_config,
    )
    print(f"Effective batch size: {effective_bs}")
    if args.max_steps > 0:
        print(f"max_steps: {args.max_steps}")
    else:
        est_steps = len(ds) * args.epochs // effective_bs
        print(f"epochs: {args.epochs} -> ~{est_steps} steps")

    # ---- 5. Train ----
    _banner("Step 5: Training")
    trainer.train()

    # ---- 6. Save ----
    _banner("Step 6: Saving model")
    log_path, config_path = _save_outputs(model, tokenizer, trainer, args)
    print(f"\nSaved SFT-primed model to: {args.output_dir}")
    print(f"Log history: {log_path}")
    print(f"Training config: {config_path}")
    print()
    print("Next: python training/inference_eval.py --model_path " + args.output_dir)
if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    main()