# reading-steiner-qwen3.5-2b / train_mini_test.py
# Uploaded by OmAlve
# Commit message: Upload train_mini_test.py
# Commit f578784 (verified)
"""
Quick smoke-test script to verify your environment works before full training.
Runs on just 20 training + 5 eval samples with 10 steps.
"""
import torch
from datasets import load_dataset
from unsloth import FastLanguageModel
from trl import SFTConfig, SFTTrainer
from transformers import AutoTokenizer, set_seed
# Seed once up front; the same value (3407) is reused for the LoRA
# initialisation and the trainer further down.
set_seed(3407)

# Tokenizer of the base checkpoint. preprocess() relies on its chat template
# to turn each conversation into a single training string.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3.5-2B",
    trust_remote_code=True,
)
def preprocess(example):
    """Render one dataset row into a single chat-formatted string.

    Args:
        example: A dataset row; its ``"messages"`` entry is passed to the
            module-level tokenizer's chat template.

    Returns:
        A dict with one key, ``"text"``, holding the rendered conversation.
        The template is applied with ``tokenize=False`` (the result is text,
        not token ids) and ``add_generation_prompt=False`` (no trailing
        assistant prompt is appended).
    """
    conversation = example["messages"]
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}
print("Loading dataset...")

# Pull only tiny slices (20 train rows, 5 eval rows) so the smoke test is fast.
# Both slices are fetched first, then both are mapped, in that order.
train_dataset, eval_dataset = (
    load_dataset("OmAlve/reading-steiner-data", split=subset)
    for subset in ("train[:20]", "eval[:5]")
)

# Replace the raw "messages" column with the rendered "text" column
# produced by preprocess().
train_dataset, eval_dataset = (
    subset.map(preprocess, remove_columns=["messages"])
    for subset in (train_dataset, eval_dataset)
)
print("Loading model with Unsloth...")

# Choose the load dtype with the same predicate the trainer config uses for its
# bf16/fp16 flags. Hard-coding float16 here while the trainer runs with
# bf16=True (which happens on every bf16-capable GPU) is a precision mismatch
# that Unsloth rejects before training starts.
compute_dtype = (
    torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
)

# The tokenizer returned by Unsloth is intentionally discarded: the script
# keeps using the AutoTokenizer instance created earlier in this file.
model, _ = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen3.5-2B",
    max_seq_length=4096,
    dtype=compute_dtype,
    load_in_4bit=True,
)

# Wrap the base model with LoRA adapters (rank 16, alpha 16, no dropout,
# no bias terms) on the attention and MLP projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing=True,
    random_state=3407,
)
# Query the GPU capability once; exactly one of bf16 / fp16 is enabled below.
bf16_available = torch.cuda.is_bf16_supported()

args = SFTConfig(
    # Output location; nothing is checkpointed or pushed during the smoke test.
    output_dir="./mini_out",
    save_strategy="no",
    push_to_hub=False,
    # Sequence length and batching: 1 sequence per device per forward pass,
    # accumulated over 4 passes per optimizer step.
    max_length=4096,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    # Only 10 optimizer steps: enough to confirm the pipeline runs end to end.
    max_steps=10,
    # Optimisation schedule.
    learning_rate=2e-4,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    # Mixed precision follows what the GPU supports.
    bf16=bf16_available,
    fp16=not bf16_available,
    # Logging and evaluation cadence (evaluate every 5 steps).
    logging_steps=2,
    logging_first_step=True,
    eval_strategy="steps",
    eval_steps=5,
    disable_tqdm=True,
    report_to=["trackio"],
    # Reproducibility.
    seed=3407,
)
# Assemble the supervised fine-tuning trainer from the pieces built above.
# NOTE(review): recent TRL releases name the tokenizer argument
# `processing_class`; `tokenizer=` is kept exactly as the script had it —
# confirm it is accepted by the installed TRL/Unsloth versions.
trainer = SFTTrainer(
    args=args,
    model=model,
    tokenizer=tokenizer,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
)

print("Starting mini training run...")
trainer.train()

# Reaching this line means the whole load -> preprocess -> train path ran
# without raising.
print("Mini test passed! Your environment is ready for full training.")