"""Consolidated Training: all 3 speculative-action models in one A100 job.

Trains sequentially with explicit cleanup between models:
1. Qwen3-1.7B proposer (SFT on action prediction)
2. Qwen3-4B verifier (SFT on ACCEPT/REJECT)
3. Qwen3-8B proposer (SFT on action prediction)

Each model is: loaded -> trained -> evaluated -> saved+pushed -> deleted -> CUDA cache cleared.

Requires: transformers>=4.51, trl, torch, datasets, accelerate, peft, huggingface_hub
"""


import gc
import torch


# Hugging Face Hub namespace: all datasets are loaded from here and all
# trained models are pushed back under this account.
HUB = "narcolepticchicken"
|
|
| |
| |
| |
# ------------------------------------------------------------------
# 1/3: Qwen3-1.7B proposer — SFT on next-action prediction.
# ------------------------------------------------------------------
rule = "=" * 60
print(f"\n{rule}")
print("1/3: TRAINING 1.7B PROPOSER")
print(rule)

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Shared proposer corpus; the 8B proposer in section 3 reuses `sft_ds`.
sft_ds = load_dataset(f"{HUB}/speculative-sft-v3-main")
print(f"SFT data: {len(sft_ds['train'])} train, {len(sft_ds['test'])} test")

# Hyperparameters for the 1.7B run (effective batch = 4 * 4 = 16).
cfg_1b = SFTConfig(
    output_dir="./out_1b",
    hub_model_id=f"{HUB}/speculative-proposer-v3-1.7b",
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_length=2048,
    packing=False,
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=5,
    logging_first_step=True,
    save_strategy="epoch",
    push_to_hub=True,
    disable_tqdm=True,
    report_to="none",
)

trainer_1b = SFTTrainer(
    model="Qwen/Qwen3-1.7B",
    args=cfg_1b,
    train_dataset=sft_ds["train"],
    eval_dataset=sft_ds["test"],
)

trainer_1b.train()
trainer_1b.save_model()
trainer_1b.push_to_hub()

eval_1b = trainer_1b.evaluate()
print(f"1.7B eval loss: {eval_1b.get('eval_loss', 'N/A')}")

# Drop the 1.7B model and reclaim GPU memory before loading the 4B verifier.
del trainer_1b
gc.collect()
torch.cuda.empty_cache()
print("1.7B proposer β")
|
|
| |
| |
| |
print("\n" + "=" * 60)
print("2/3: TRAINING 4B VERIFIER")
print("=" * 60)


from datasets import Dataset


# System prompt for the verifier model: given context plus a proposed next
# action, it must answer with exactly ACCEPT or REJECT (single token target).
VERIFIER_SYSTEM = (
    "You are an action verifier. Given conversation context and a proposed next action, "
    "determine if the proposal is correct. Respond with exactly ACCEPT or REJECT."
)


# Raw verifier data: rows with "context" (message list), "proposal" (string),
# and "label" (1 = correct proposal) — assumed schema, TODO confirm upstream.
verif_raw = load_dataset(f"{HUB}/speculative-verifier-v3-main")
|
|
def _verifier_messages(ctx, proposal, label):
    """Format one verifier row as a chat transcript ending in the gold answer.

    Keeps only the last 6 context turns, truncating each turn's content to
    400 characters, then appends the ACCEPT/REJECT question and the answer
    derived from ``label`` (1 -> ACCEPT, anything else -> REJECT).
    """
    msgs = [{"role": "system", "content": VERIFIER_SYSTEM}]
    for m in ctx[-6:]:
        msgs.append({"role": m["role"], "content": str(m["content"])[:400]})
    msgs.append({
        "role": "user",
        "content": f"Proposed next action: {proposal}\n\nIs this the correct next action? ACCEPT or REJECT?"
    })
    msgs.append({"role": "assistant", "content": "ACCEPT" if label == 1 else "REJECT"})
    return msgs


# Build each split directly, instead of tagging every row with a throwaway
# "split" key into one combined list and then filtering it twice (the
# original made two extra full passes over all rows).
_split_rows = {
    name: [
        {"messages": _verifier_messages(r["context"], r["proposal"], r["label"])}
        for r in verif_raw[name]
    ]
    for name in ("train", "test")
}
v_train = Dataset.from_list(_split_rows["train"])
v_test = Dataset.from_list(_split_rows["test"])
print(f"Verifier SFT: {len(v_train)} train, {len(v_test)} test")
|
|
# Hyperparameters for the 4B verifier run (effective batch = 4 * 4 = 16;
# 2 epochs — the binary ACCEPT/REJECT target converges faster than the
# proposer objective).
cfg_4b = SFTConfig(
    output_dir="./out_4b",
    hub_model_id=f"{HUB}/speculative-verifier-v3-4b",
    num_train_epochs=2,
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_length=2048,
    packing=False,
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=10,
    logging_first_step=True,
    save_strategy="epoch",
    push_to_hub=True,
    disable_tqdm=True,
    report_to="none",
)

trainer_4b = SFTTrainer(
    model="Qwen/Qwen3-4B",
    args=cfg_4b,
    train_dataset=v_train,
    eval_dataset=v_test,
)

trainer_4b.train()
trainer_4b.save_model()
trainer_4b.push_to_hub()

eval_4b = trainer_4b.evaluate()
print(f"4B verifier eval loss: {eval_4b.get('eval_loss', 'N/A')}")

# Drop the 4B model and reclaim GPU memory before loading the 8B proposer.
del trainer_4b
gc.collect()
torch.cuda.empty_cache()
print("4B verifier β")
|
|
| |
| |
| |
print("\n" + "=" * 60)
print("3/3: TRAINING 8B PROPOSER")
print("=" * 60)


# 8B proposer: same corpus as the 1.7B proposer; smaller per-device batch
# with more accumulation keeps the effective batch at 16 (2 * 8) while
# fitting the larger model in memory.
args_8b = SFTConfig(
    output_dir="./out_8b",
    hub_model_id=f"{HUB}/speculative-proposer-v3-8b",
    max_length=2048,
    packing=False,
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=5,
    logging_first_step=True,
    save_strategy="epoch",
    push_to_hub=True,
    disable_tqdm=True,
    report_to="none",
)


trainer = SFTTrainer(
    model="Qwen/Qwen3-8B",
    args=args_8b,
    train_dataset=sft_ds["train"],
    eval_dataset=sft_ds["test"],
)


trainer.train()
trainer.save_model()
trainer.push_to_hub()


metrics = trainer.evaluate()
print(f"8B eval loss: {metrics.get('eval_loss', 'N/A')}")


# Fix: the module docstring promises every model is deleted and the CUDA
# cache cleared, and sections 1 and 2 do exactly that — this section was
# missing the cleanup, leaving the 8B weights resident after training.
del trainer
gc.collect()
torch.cuda.empty_cache()
print("8B proposer β")


print("\n" + "=" * 60)
print("ALL THREE MODELS TRAINED SUCCESSFULLY!")
print(f" {HUB}/speculative-proposer-v3-1.7b")
print(f" {HUB}/speculative-verifier-v3-4b")
print(f" {HUB}/speculative-proposer-v3-8b")
print("=" * 60)
|
|