🧠 Full weight release: 9 probes × 3 architectures + production adapter + training code
#!/usr/bin/env python3
"""
CONTINUE QWEN TRAINING: 3000 → 6000 steps
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from datasets import load_dataset
import os, time, random, json
from dataclasses import dataclass, field
from typing import List
# Checkpoint directory to resume from (end of the first 3000-step run).
CKPT = "./results/qwen3b_repetition_v2_fixed/ckpt_3000"
# Output directory for the continued (3000 → 6000 step) run.
OUT = "./results/qwen3b_repetition_v2_continued"
@dataclass
class Config:
    """Hyperparameters for continuing Qwen training from step 3000 to 6000.

    NOTE(review): the original class used ``field(default_factory=...)``
    without the ``@dataclass`` decorator, so ``probe_layers`` was a bare
    ``dataclasses.Field`` object (not a list) and ``Config(...)`` accepted
    no keyword overrides. Adding ``@dataclass`` fixes both while keeping
    every attribute name and default value unchanged.
    """
    model_path: str = "Qwen/Qwen2.5-3B"
    # Layers to attach probes to — presumably evenly spaced through the
    # model's depth; confirm against the model's layer count.
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])
    d_fiber: int = 16            # probe fiber dimension
    d_control: int = 64          # probe control dimension
    start_step: int = 3000       # resume point (matches CKPT)
    max_steps: int = 6000        # total steps after continuation
    batch_size: int = 1
    grad_accum: int = 8          # effective batch = batch_size * grad_accum
    max_length: int = 256        # sequence truncation length (tokens)
    lr_lora: float = 1e-5        # learning rate for the LoRA adapter params
    lr_predictor: float = 5e-5   # learning rate for the predictor params
    weight_decay: float = 0.01
    rep_window: int = 32         # repetition window size — TODO confirm units
    log_every: int = 10          # steps between log lines
    save_every: int = 500        # steps between checkpoint saves
    eval_every: int = 200        # steps between evaluations