#!/usr/bin/env python3
""" CONTINUE QWEN TRAINING: 3000 → 6000 steps """

# --- Imports -----------------------------------------------------------------
# Core numerics / model stack.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from datasets import load_dataset

# Standard library.
import os
import time
import random
import json
from dataclasses import dataclass, field
from typing import List

# --- Paths -------------------------------------------------------------------
# Checkpoint to resume from (step 3000) and the output directory for the
# continued run.
CKPT = "./results/qwen3b_repetition_v2_fixed/ckpt_3000"
OUT = "./results/qwen3b_repetition_v2_continued"


@dataclass
class Config:
    """Hyperparameters for continuing the Qwen repetition-probe training run.

    Resumes at ``start_step`` and trains until ``max_steps``; all other
    values mirror the original run's settings.
    """

    # Base model and probe placement.
    model_path: str = "Qwen/Qwen2.5-3B"
    # Hidden layers tapped by the probes — assumes a 36-layer backbone; TODO confirm.
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])
    d_fiber: int = 16
    d_control: int = 64

    # Schedule: resume at 3000, stop at 6000.
    start_step: int = 3000
    max_steps: int = 6000

    # Batching: effective batch = batch_size * grad_accum = 8 sequences.
    batch_size: int = 1
    grad_accum: int = 8
    max_length: int = 256

    # Optimizer: separate learning rates for the LoRA adapter and predictor head.
    lr_lora: float = 1e-5
    lr_predictor: float = 5e-5
    weight_decay: float = 0.01

    # Window (in tokens) used by the repetition metric — presumably; verify at use site.
    rep_window: int = 32

    # Bookkeeping cadence (in optimizer steps).
    log_every: int = 10
    save_every: int = 500
    eval_every: int = 200