cfhot-weights / code /training_pipelines /07c_qwen3b_CONTINUE.py
LoganResearch's picture
🧠 Full weight release: 9 probes × 3 architectures + production adapter + training code
297244f verified
raw
history blame contribute delete
988 Bytes
#!/usr/bin/env python3
"""
CONTINUE QWEN TRAINING: 3000 → 6000 steps
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from datasets import load_dataset
import os, time, random, json
from dataclasses import dataclass, field
from typing import List
# Checkpoint saved at step 3000 by the previous run; training resumes from here.
CKPT = "./results/qwen3b_repetition_v2_fixed/ckpt_3000"
# Output directory for the continued run (steps 3000 -> 6000).
OUT = "./results/qwen3b_repetition_v2_continued"
@dataclass
class Config:
    """Hyperparameters for continuing Qwen2.5-3B probe training from step 3000 to 6000.

    NOTE(review): this class may have additional fields/methods below the
    visible portion of the file — confirm against the full source.
    """

    # Base model to load (the LoRA/probe checkpoint in CKPT is applied on top).
    model_path: str = "Qwen/Qwen2.5-3B"
    # Hidden layers whose activations are probed; presumably chosen to be
    # evenly spaced through the network — TODO confirm against the model depth.
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])
    # Project-specific probe dimensions ("fiber"/"control" spaces) — semantics
    # defined elsewhere in the training code; verify against the probe modules.
    d_fiber: int = 16
    d_control: int = 64
    # Resume/stop boundaries: continue from step 3000 up to 6000 total steps.
    start_step: int = 3000
    max_steps: int = 6000
    # Effective batch = batch_size * grad_accum = 8 sequences per optimizer step.
    batch_size: int = 1
    grad_accum: int = 8
    # Maximum tokenized sequence length.
    max_length: int = 256
    # Separate learning rates: lower for the LoRA adapter, higher for the
    # probe/predictor heads.
    lr_lora: float = 1e-5
    lr_predictor: float = 5e-5
    weight_decay: float = 0.01
    # Window size (in tokens, presumably) used by the repetition objective —
    # TODO confirm how rep_window is consumed by the loss code.
    rep_window: int = 32
    # Cadence (in steps) for logging, checkpointing, and evaluation.
    log_every: int = 10
    save_every: int = 500
    eval_every: int = 200