Unsloth code

Discussion #2, opened by PSM24

Would you mind sharing your Unsloth code for training GPT-OSS like this?

Here you go, let me know if you have any questions.

# %%
from unsloth import FastLanguageModel  # imported first so Unsloth can patch later imports
import torch

# Loader settings (module-level so later cells can reuse them).
max_seq_length = 2048  # sequence length the model is loaded with
dtype = None           # None -> let Unsloth detect the dtype automatically

# Load the GPT-OSS 20B checkpoint with 4-bit quantised weights. Full fine-tuning is
# switched off; trainable adapters are attached in a later cell.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gpt-oss-20b",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=True,        # 4-bit quantisation to cut memory use
    full_finetuning=False,    # set True to train every weight instead
    # token="hf_...",         # a Hugging Face token, only needed for gated models
)
# %%
# Wrap the loaded model with LoRA adapters (rank 8, alpha 16) on the attention
# and MLP projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r=8,            # adapter rank; any value > 0 works, 8/16/32/64/128 are typical
    lora_alpha=16,
    target_modules=(
        ["q_proj", "k_proj", "v_proj", "o_proj"]   # attention projections
        + ["gate_proj", "up_proj", "down_proj"]    # MLP / feed-forward projections
    ),
    lora_dropout=0,   # any value is accepted; 0 is the optimised path
    bias="none",      # any value is accepted; "none" is the optimised path
    # "unsloth" checkpointing trades compute for VRAM (True also works); the
    # Unsloth authors quote ~30% less VRAM and room for 2x larger batches.
    use_gradient_checkpointing="unsloth",
    use_rslora=False,   # rank-stabilised LoRA is available but disabled here
    loftq_config=None,  # LoftQ initialisation is available but disabled here
    random_state=3407,
)

# %%
def formatting_prompts_func(examples):
    """Render each conversation in a batch to a single training string.

    Meant for ``dataset.map(..., batched=True)``: ``examples["messages"]`` holds
    one conversation per row. Each one is passed through the module-level
    ``tokenizer``'s chat template untokenized, with no trailing generation prompt.

    Returns:
        ``{"text": [...]}`` with one rendered string per input conversation.
    """
    rendered = []
    for conversation in examples["messages"]:
        rendered.append(
            tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=False,
            )
        )
    return {"text": rendered}

# %%
from datasets import load_dataset

# Training split of the GLM-4.6 distillation conversations.
dataset = load_dataset(path="TeichAI/glm-4.6-250x", split="train")

# %%
from unsloth.chat_templates import standardize_sharegpt

# Normalise the conversation format, then add a "text" column holding each
# conversation rendered through the chat template (see formatting_prompts_func).
# NOTE(review): formatting_prompts_func reads a "messages" column; confirm that
# standardize_sharegpt leaves the conversations under that name for this dataset.
dataset = standardize_sharegpt(dataset)
dataset = dataset.map(function=formatting_prompts_func, batched=True)

# %%
from trl import SFTConfig, SFTTrainer

# Effective batch: 1 example per device x 4 accumulation steps = 4 per optimizer
# step, so max_steps=1250 consumes 5,000 examples per device.
# NOTE(review): if the dataset really has ~250 rows (as its name suggests; verify),
# that is about 20 epochs. Consider num_train_epochs instead of a fixed step count.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=SFTConfig(
        output_dir="outputs",
        seed=3407,
        # Optimiser and schedule
        optim="adamw_8bit",
        learning_rate=2e-4,
        weight_decay=0.01,
        lr_scheduler_type="linear",
        warmup_steps=5,
        # Batching and run length
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        max_steps=1250,
        # num_train_epochs=1,  # use this instead of max_steps for one full pass
        # Logging
        logging_steps=1,
        report_to="none",  # point at a tracker (e.g. TrackIO, WandB) to log runs
    ),
)

# %%
# The dataset has no 'labels' key because SFTTrainer's response-only masking was
# not used. As a fallback, the assistant's reply can be extracted from the rendered
# 'text' column with a regex. This is less reliable than Unsloth's assistant-only
# loss, which could not be made to work here.
import re

# One assistant turn in GPT-OSS "harmony"-formatted text. Between "assistant" and
# "<|message|>" the header may carry a recipient (" to=functions.x"), a
# "<|channel|>NAME" tag and a "<|constrain|>TYPE" tag. Per the harmony format, the
# GPT-OSS chat template renders assistant turns as
# "<|start|>assistant<|channel|>final<|message|>...". A pattern that demands
# "<|message|>" right after "assistant" therefore never matches that output.
_ASSISTANT_MESSAGE_RE = re.compile(
    r"<\|start\|>assistant[^<]*"               # role (+ optional recipient)
    r"(?:<\|channel\|>(?P<channel>[^<]*))?"    # optional channel, e.g. "final"
    r"(?:<\|constrain\|>[^<]*)?"               # optional content-type constraint
    r"<\|message\|>(?P<body>.*?)"              # body, shortest possible match
    r"(?:<\|end\|>|<\|return\|>|<\|call\|>)",  # any message terminator
    re.DOTALL,  # message bodies may span several lines
)


def extract_assistant_message(text: str, channel: "str | None" = None) -> str:
    """Return the first assistant message body found in rendered chat text.

    Args:
        text: A conversation rendered with the chat template (a 'text' value).
        channel: If given, only assistant turns whose channel tag starts with this
            name (e.g. "final" or "analysis") are considered. ``None`` (the
            default) accepts any assistant turn, with or without a channel tag.

    Returns:
        The matched message body with surrounding whitespace stripped, or the
        sentinel string "Assistant message not found." when nothing matches.
    """
    for match in _ASSISTANT_MESSAGE_RE.finditer(text):
        # A channel header can carry extra words ("commentary to=functions.f"),
        # so compare only its first token against the requested channel.
        channel_tokens = (match.group("channel") or "").split()
        if channel is None or channel_tokens[:1] == [channel]:
            return match.group("body").strip()
    return "Assistant message not found."

# %%
# Run fine-tuning with the trainer built above. The return value is kept in
# trainer_stats (presumably HF's TrainOutput: global_step, training_loss, metrics;
# verify) but is not used further in the visible script.
trainer_stats = trainer.train()
armand0e changed the discussion status to closed.

Sign up or log in to comment