Unsloth code
#2
by PSM24 - opened
Would you mind sharing the Unsloth code you used to fine-tune GPT-OSS this way?
Here you go, let me know if you have any questions.
# %%
# Load the GPT-OSS 20B checkpoint through Unsloth, quantized to 4 bits, and
# get back the model together with its tokenizer.
from unsloth import FastLanguageModel
import torch

max_seq_length = 2048  # any length works, including long-context settings
dtype = None  # None means: let Unsloth pick the dtype automatically

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gpt-oss-20b",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=True,  # 4-bit quantization keeps memory use low
    full_finetuning=False,  # no full fine-tuning; LoRA adapters are added below
    # For gated models, also pass token="hf_..." here.
)
# %%
# Attach LoRA adapters to the attention and MLP projection layers of the
# model loaded above.
model = FastLanguageModel.get_peft_model(
    model,
    r=8,  # LoRA rank: any value > 0; 8, 16, 32, 64 and 128 are typical
    lora_alpha=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_dropout=0,  # any value is accepted; 0 takes the optimized path
    bias="none",  # any value is accepted; "none" takes the optimized path
    # True or "unsloth"; the "unsloth" variant uses ~30% less VRAM and fits
    # ~2x larger batches, which helps with very long contexts.
    use_gradient_checkpointing="unsloth",
    use_rslora=False,  # True switches to rank-stabilized LoRA
    loftq_config=None,  # a LoftQ config may be supplied here instead
    random_state=3407,
)
# %%
def formatting_prompts_func(examples):
    """Render every conversation in a batch to a single training string.

    Written for ``Dataset.map(..., batched=True)``: ``examples["messages"]``
    holds one conversation per row, and each is passed through the
    module-level ``tokenizer``'s chat template without tokenizing and without
    appending a generation prompt.

    Returns:
        ``{"text": [...]}`` with one rendered string per input conversation,
        in the same order.
    """
    rendered = []
    for conversation in examples["messages"]:
        rendered.append(
            tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=False,
                tokenize=False,
            )
        )
    return {"text": rendered}
# %%
# Fetch the training split of the dataset from the Hugging Face Hub.
from datasets import load_dataset

dataset = load_dataset("TeichAI/glm-4.6-250x", split="train")

# %%
# Pass the rows through Unsloth's standardize_sharegpt helper, then add a
# "text" column built by formatting_prompts_func (batched map).
from unsloth.chat_templates import standardize_sharegpt

dataset = standardize_sharegpt(dataset).map(formatting_prompts_func, batched=True)
# %%
# Supervised fine-tuning setup. Each optimizer step sees
# per_device_train_batch_size * gradient_accumulation_steps = 1 * 4 = 4
# examples per device.
# NOTE(review): the dataset name suggests ~250 rows; 1250 steps * 4 examples
# would then be roughly 20 passes over the data — confirm this is intended.
from trl import SFTConfig, SFTTrainer

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=SFTConfig(
        # Batching
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        # Schedule: a fixed step budget. For whole-epoch training use e.g.
        # num_train_epochs=1 in place of max_steps.
        max_steps=1250,
        warmup_steps=5,
        lr_scheduler_type="linear",
        learning_rate=2e-4,
        # Optimizer
        optim="adamw_8bit",
        weight_decay=0.01,
        # Reproducibility, logging and output
        seed=3407,
        logging_steps=1,
        output_dir="outputs",
        report_to="none",  # switch to a tracker such as WandB or TrackIO
    ),
)
# %%
# Inspection helper for the rendered "text" column. The trainer above was not
# set up with assistant-only loss masking, so there is no "labels" column to
# read assistant turns from; instead they are pulled out of the rendered text
# with a regex. This is best-effort and less reliable than Unsloth's
# assistant-only loss, which could not be made to work for this setup.
import re

# One assistant turn. GPT-OSS renders chats in OpenAI's harmony format, where
# the role is followed by a header before <|message|>, e.g.
#   <|start|>assistant<|channel|>final<|message|>...<|return|>
#   <|start|>assistant to=functions.x<|channel|>commentary json<|message|>...<|call|>
# The header group also matches the empty header of the bare form
# <|start|>assistant<|message|>...<|end|>. It may not cross another
# <|message|> or <|start|>, so it cannot swallow a following turn.
_ASSISTANT_MESSAGE_RE = re.compile(
    r"<\|start\|>assistant"
    r"(?P<header>(?:(?!<\|message\|>|<\|start\|>).)*)"
    r"<\|message\|>"
    r"(?P<body>.*?)"
    r"(?:<\|end\|>|<\|return\|>|<\|call\|>)",
    re.DOTALL,
)
# Channel name inside a header, e.g. "final" in "<|channel|>final".
_CHANNEL_RE = re.compile(r"<\|channel\|>(\w+)")


def extract_assistant_message(
    text, *, channel=None, default="Assistant message not found."
):
    """Return the body of an assistant message found in rendered chat text.

    Accepts both the bare form ``<|start|>assistant<|message|>...`` and the
    harmony form used by GPT-OSS, where a header such as
    ``<|channel|>final`` sits between the role and ``<|message|>``. A body
    ends at the first ``<|end|>``, ``<|return|>`` or ``<|call|>``.

    Args:
        text: Rendered conversation string.
        channel: If given (e.g. ``"final"`` or ``"analysis"``), only
            assistant messages whose header names that channel are
            considered; messages without a channel are skipped. If None,
            the first assistant message of any channel is returned.
        default: Value returned when no matching message exists.

    Returns:
        The matching message body with surrounding whitespace stripped, or
        ``default`` when there is no match.
    """
    for match in _ASSISTANT_MESSAGE_RE.finditer(text):
        if channel is not None:
            found = _CHANNEL_RE.search(match.group("header"))
            if found is None or found.group(1) != channel:
                continue
        return match.group("body").strip()
    return default
# %%
# Launch training with the trainer configured above; whatever trainer.train()
# returns is kept in trainer_stats for later inspection.
trainer_stats = trainer.train()
armand0e changed discussion status to closed