"""Consolidated Training: all 3 speculative-action models in one A100 job. Trains sequentially with explicit cleanup between models: 1. Qwen3-1.7B proposer (SFT on action prediction) 2. Qwen3-4B verifier (SFT on ACCEPT/REJECT) 3. Qwen3-8B proposer (SFT on action prediction) Each model is: loaded → trained → evaluated → saved+pushed → deleted → CUDA cache cleared. Requires: transformers>=4.51, trl, torch, datasets, accelerate, peft, huggingface_hub """ import gc import torch HUB = "narcolepticchicken" # ═══════════════════════════════════════════════════════════════ # 1. TRAIN 1.7B PROPOSER # ═══════════════════════════════════════════════════════════════ print("\n" + "=" * 60) print("1/3: TRAINING 1.7B PROPOSER") print("=" * 60) from datasets import load_dataset from trl import SFTConfig, SFTTrainer sft_ds = load_dataset(f"{HUB}/speculative-sft-v3-main") print(f"SFT data: {len(sft_ds['train'])} train, {len(sft_ds['test'])} test") args_1b = SFTConfig( output_dir="./out_1b", hub_model_id=f"{HUB}/speculative-proposer-v3-1.7b", max_length=2048, packing=False, learning_rate=2e-5, per_device_train_batch_size=4, gradient_accumulation_steps=4, num_train_epochs=3, bf16=True, gradient_checkpointing=True, logging_steps=5, logging_first_step=True, save_strategy="epoch", push_to_hub=True, disable_tqdm=True, report_to="none", ) trainer = SFTTrainer( model="Qwen/Qwen3-1.7B", args=args_1b, train_dataset=sft_ds["train"], eval_dataset=sft_ds["test"], ) trainer.train() trainer.save_model() trainer.push_to_hub() metrics = trainer.evaluate() print(f"1.7B eval loss: {metrics.get('eval_loss', 'N/A')}") del trainer gc.collect() torch.cuda.empty_cache() print("1.7B proposer ✓") # ═══════════════════════════════════════════════════════════════ # 2. TRAIN 4B VERIFIER # ═══════════════════════════════════════════════════════════════ print("\n" + "=" * 60) print("2/3: TRAINING 4B VERIFIER") print("=" * 60) from datasets import Dataset VERIFIER_SYSTEM = ( "You are an action verifier. Given conversation context and a proposed next action, " "determine if the proposal is correct. Respond with exactly ACCEPT or REJECT." ) verif_raw = load_dataset(f"{HUB}/speculative-verifier-v3-main") sft_rows = [] for split_name in ["train", "test"]: for row in verif_raw[split_name]: ctx = row["context"] proposal = row["proposal"] label = row["label"] answer = "ACCEPT" if label == 1 else "REJECT" msgs = [{"role": "system", "content": VERIFIER_SYSTEM}] for m in ctx[-6:]: msgs.append({"role": m["role"], "content": str(m["content"])[:400]}) msgs.append({ "role": "user", "content": f"Proposed next action: {proposal}\n\nIs this the correct next action? ACCEPT or REJECT?" 
v_train = Dataset.from_list([{"messages": r["messages"]} for r in sft_rows if r["split"] == "train"])
v_test = Dataset.from_list([{"messages": r["messages"]} for r in sft_rows if r["split"] == "test"])
print(f"Verifier SFT: {len(v_train)} train, {len(v_test)} test")

args_4b = SFTConfig(
    output_dir="./out_4b",
    hub_model_id=f"{HUB}/speculative-verifier-v3-4b",
    max_length=2048,
    packing=False,
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=10,
    logging_first_step=True,
    save_strategy="epoch",
    push_to_hub=True,
    disable_tqdm=True,
    report_to="none",
)
trainer = SFTTrainer(
    model="Qwen/Qwen3-4B",
    args=args_4b,
    train_dataset=v_train,
    eval_dataset=v_test,
)
trainer.train()
trainer.save_model()
trainer.push_to_hub()
metrics = trainer.evaluate()
print(f"4B verifier eval loss: {metrics.get('eval_loss', 'N/A')}")

# Free GPU memory before loading the next model.
del trainer
gc.collect()
torch.cuda.empty_cache()
print("4B verifier ✓")

# ═══════════════════════════════════════════════════════════════
# 3. TRAIN 8B PROPOSER
# ═══════════════════════════════════════════════════════════════
print("\n" + "=" * 60)
print("3/3: TRAINING 8B PROPOSER")
print("=" * 60)

args_8b = SFTConfig(
    output_dir="./out_8b",
    hub_model_id=f"{HUB}/speculative-proposer-v3-8b",
    max_length=2048,
    packing=False,
    learning_rate=2e-5,
    per_device_train_batch_size=2,   # smaller batch to fit 8B on 80 GB
    gradient_accumulation_steps=8,   # effective batch = 16 (same as others)
    num_train_epochs=3,
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=5,
    logging_first_step=True,
    save_strategy="epoch",
    push_to_hub=True,
    disable_tqdm=True,
    report_to="none",
)
trainer = SFTTrainer(
    model="Qwen/Qwen3-8B",
    args=args_8b,
    train_dataset=sft_ds["train"],
    eval_dataset=sft_ds["test"],
)
trainer.train()
trainer.save_model()
trainer.push_to_hub()
metrics = trainer.evaluate()
print(f"8B eval loss: {metrics.get('eval_loss', 'N/A')}")

# Final cleanup, for symmetry with the earlier stages.
del trainer
gc.collect()
torch.cuda.empty_cache()

print("\n" + "=" * 60)
print("ALL THREE MODELS TRAINED SUCCESSFULLY!")
print(f"  {HUB}/speculative-proposer-v3-1.7b")
print(f"  {HUB}/speculative-verifier-v3-4b")
print(f"  {HUB}/speculative-proposer-v3-8b")
print("=" * 60)
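
# ═══════════════════════════════════════════════════════════════
# USAGE SKETCH (defined here, never called by this job)
# ═══════════════════════════════════════════════════════════════
# A minimal sketch of one speculative step at inference time, assuming
# the proposer emits a single action string and the verifier answers
# ACCEPT/REJECT with the same prompt format used for training above.
# `proposer`/`verifier` are AutoModelForCausalLM instances loaded from
# the hub ids pushed above; `tokenizer` is the shared Qwen3 tokenizer.
def speculative_step(proposer, verifier, tokenizer, context_msgs):
    """Propose one action with the small model, then verify it with the 4B."""
    # Greedy-decode a candidate action from the proposer.
    prompt = tokenizer.apply_chat_template(
        context_msgs, add_generation_prompt=True, return_tensors="pt"
    ).to(proposer.device)
    out = proposer.generate(prompt, max_new_tokens=64, do_sample=False)
    proposal = tokenizer.decode(out[0, prompt.shape[1]:], skip_special_tokens=True)

    # Mirror the training prompt: system + last 6 truncated turns + question.
    verify_msgs = [{"role": "system", "content": VERIFIER_SYSTEM}]
    for m in context_msgs[-6:]:
        verify_msgs.append({"role": m["role"], "content": str(m["content"])[:400]})
    verify_msgs.append({
        "role": "user",
        "content": f"Proposed next action: {proposal}\n\n"
                   "Is this the correct next action? ACCEPT or REJECT?",
    })
    v_prompt = tokenizer.apply_chat_template(
        verify_msgs, add_generation_prompt=True, return_tensors="pt"
    ).to(verifier.device)
    v_out = verifier.generate(v_prompt, max_new_tokens=4, do_sample=False)
    verdict = tokenizer.decode(v_out[0, v_prompt.shape[1]:], skip_special_tokens=True)
    return proposal, verdict.strip().startswith("ACCEPT")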