Code

#2
by YsK-dev - opened

You shared the model weights and data, but what about the pipeline you used for finetuning — will you share that as well? Nice job, btw.

TeichAI org

Let me get the files in order

TeichAI org

Ok so my file had a bunch of sanity checks for the dataset and a lot of bloated checking to make sure the dataset looked good. I had Claude remake the file, removing all of this extra stuff (just to explain the AI-like format of the file in advance). But more or less this is the exact file I used for training this model on both TeichAI/Gemini-3-Flash-Preview-VIBE and TeichAI/MiniMax-M2.1-Code-SFT:

import os
import re
import json
import hashlib
import multiprocessing as mp
from collections import Counter

# Environment setup (a couple of these are windows specific).
# These must be set BEFORE the unsloth/transformers imports below take effect.
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"  # skip unsloth's torch.compile path
os.environ["HF_DATASETS_DISABLE_MULTIPROCESSING"] = "1"  # single-process dataset ops
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence tokenizers fork/parallelism warning
# NOTE(review): newer PyTorch reads PYTORCH_ALLOC_CONF, older versions read
# PYTORCH_CUDA_ALLOC_CONF — confirm this matches the installed torch version.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:true"

from datasets import load_dataset, concatenate_datasets
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from transformers import AutoTokenizer
from trl import SFTTrainer, SFTConfig
import torch


# ============================================================================
# CONFIGURATION
# ============================================================================

# Model configuration
INPUT_MODEL = "unsloth/Qwen3-4B-Thinking-2507"  # base model to fine-tune (4-bit load below)
CHAT_TEMPLATE = "qwen3-thinking"  # unsloth chat-template name applied to the tokenizer
MAX_SEQ_LENGTH = 32768  # tokens; also used as SFTConfig.max_length

# Dataset configuration (can be a single string or list of dataset names)
HF_DATASETS = ["TeichAI/MiniMax-M2.1-Code-SFT"]
DATASET_FILES = []  # Optional: local JSONL files

# Training configuration
MAX_STEPS = 2000  # hard step cap (training stops here regardless of epochs)
BATCH_SIZE = 2  # per-device batch size
GRAD_ACCUMULATION = 4  # effective batch = BATCH_SIZE * GRAD_ACCUMULATION = 8
LEARNING_RATE = 2e-4
WARMUP_STEPS = 5
SAVE_STEPS = 200  # checkpoint every N steps
SAVE_TOTAL_LIMIT = 20  # keep at most this many checkpoints

# LoRA configuration
LORA_RANK = 32
LORA_ALPHA = 32  # alpha == rank -> effective scaling factor of 1.0
LORA_DROPOUT = 0

# Output configuration
HF_ACCOUNT = "TeichAI"
OUTPUT_MODEL_REPO = "Qwen3-4B-Thinking-Tools-SFT"
# NOTE(review): nothing in this file reads a token from the environment; with
# None the HF libraries fall back to the cached `huggingface-cli login` token.
HF_TOKEN = None  # Set your token here or via environment variable
PRIVATE_UPLOAD = True  # upload repos as private
RESUME_FROM_CHECKPOINT = False  # passed straight to trainer.train()


# ============================================================================
# DATASET LOADING & PREPROCESSING
# ============================================================================

def load_raw_dataset():
    """Build one combined training dataset from HF hub repos and local JSONL files.

    Returns:
        A single `datasets.Dataset` (train split) concatenated from every
        configured source.

    Raises:
        ValueError: if neither HF_DATASETS nor DATASET_FILES yields a source.
    """
    parts = []

    # Hub datasets — config accepts a single name or a list of names.
    names = HF_DATASETS if isinstance(HF_DATASETS, list) else [HF_DATASETS]
    for name in names:
        if name and name.strip():
            print(f"Loading HuggingFace dataset: {name}")
            parts.append(load_dataset(name.strip(), split="train"))

    # Local JSONL files, if any were configured.
    raw_files = DATASET_FILES if isinstance(DATASET_FILES, list) else [DATASET_FILES]
    files = [path.strip() for path in raw_files if path and path.strip()]
    if files:
        print(f"Loading local files: {files}")
        parts.append(load_dataset("json", data_files=files, split="train"))

    if not parts:
        raise ValueError("No datasets provided! Set HF_DATASETS or DATASET_FILES.")

    combined = parts[0] if len(parts) == 1 else concatenate_datasets(parts)
    print(f"Total rows loaded: {len(combined)}")

    return combined


def deduplicate_dataset(dataset):
    """Drop rows whose prompt has already been seen, keeping the first occurrence.

    The dedup key is every message EXCEPT a trailing assistant reply, so two
    rows with the same prompt but different completions count as duplicates.
    Relies on sequential filtering (num_proc=1) since `observed` is shared
    closure state.
    """
    print(f"Rows before deduplication: {len(dataset)}")

    observed = set()

    def keep_first(example):
        msgs = example.get("messages", [])
        if not isinstance(msgs, list) or not msgs:
            return False

        # Strip a trailing assistant turn so only the prompt is hashed.
        prompt = msgs
        if isinstance(msgs[-1], dict) and msgs[-1].get("role") == "assistant":
            prompt = msgs[:-1]

        # MD5 is fine here — used purely as a dedup fingerprint, not security.
        digest = hashlib.md5(
            json.dumps(prompt, sort_keys=True).encode("utf-8")
        ).hexdigest()

        if digest in observed:
            return False
        observed.add(digest)
        return True

    deduplicated = dataset.filter(keep_first, num_proc=1, load_from_cache_file=False)
    print(f"Rows after deduplication: {len(deduplicated)}")

    return deduplicated


def validate_messages(messages):
    """Check a conversation's structure; return "" if valid, else an error tag.

    Valid means: a non-empty list of dicts whose last non-system entry has
    role "assistant" (i.e. there is a completion to train on).
    """
    if not (isinstance(messages, list) and messages):
        return "messages_not_list_or_empty"

    final_role = None
    for entry in messages:
        if not isinstance(entry, dict):
            return "message_not_dict"
        # System turns are ignored when deciding what the conversation ends on.
        current = entry.get("role")
        if current != "system":
            final_role = current

    return "" if final_role == "assistant" else "does_not_end_with_assistant"


def filter_invalid_messages(dataset):
    """Drop rows whose message list fails structural validation."""
    # Tag every row with its failure reason first ("" means valid).
    dataset = dataset.map(
        lambda ex: {"bad_reason": validate_messages(ex.get("messages"))},
        num_proc=1,
    )

    original_size = len(dataset)
    dataset = dataset.filter(lambda ex: ex["bad_reason"] == "", num_proc=1)
    dataset = dataset.remove_columns(["bad_reason"])

    dropped = original_size - len(dataset)
    if dropped > 0:
        print(f"Filtered out {dropped} invalid rows")

    return dataset


def prepare_dataset():
    """Run the full data pipeline: load, deduplicate, shuffle, then validate."""
    print("=" * 80)
    print("LOADING DATASET")
    print("=" * 80)

    # Order matters: dedup before shuffle keeps "first occurrence" deterministic
    # with respect to the source ordering.
    dataset = filter_invalid_messages(
        deduplicate_dataset(load_raw_dataset()).shuffle(seed=42)
    )

    print(f"\nFinal dataset size: {len(dataset)} rows")
    return dataset


# ============================================================================
# MODEL LOADING & TRAINING
# ============================================================================

def load_model_and_tokenizer():
    """Load the base model in 4-bit, attach LoRA adapters, and set the chat template.

    Returns:
        (model, tokenizer) ready for SFT training.
    """
    print("\n" + "=" * 80)
    print("LOADING MODEL")
    print("=" * 80)

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=INPUT_MODEL,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=True,
        load_in_8bit=False,
        full_finetuning=False,
        token=HF_TOKEN,
        attn_implementation="eager",
    )

    # LoRA on every attention projection plus the MLP projections.
    adapter_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_RANK,
        target_modules=adapter_targets,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )

    # Swap in the configured chat template so formatting matches training data.
    tokenizer = get_chat_template(tokenizer, chat_template=CHAT_TEMPLATE)

    print(f"Model loaded: {INPUT_MODEL}")
    print(f"LoRA rank: {LORA_RANK}, alpha: {LORA_ALPHA}")
    print(f"Max sequence length: {MAX_SEQ_LENGTH}")

    return model, tokenizer


def formatting_prompts_func(examples, tokenizer):
    """Render each conversation in a batch into a single training string.

    Args:
        examples: batched mapping with a "messages" column (list of
            conversations) and optionally a "tools" column holding per-row
            tool schemas (entries may be None or empty).
        tokenizer: tokenizer whose chat template is applied.

    Returns:
        dict with a "text" column of rendered strings, aligned with the input rows.
    """
    convos = examples["messages"]
    # Rows without tool definitions get None so zip stays aligned with convos.
    tools_list = examples.get("tools", [None] * len(convos))

    texts = []
    for convo, tools in zip(convos, tools_list):
        # apply_chat_template's `tools` parameter defaults to None, so passing
        # None explicitly is identical to omitting it — a single call covers
        # both the with-tools and without-tools rows.
        texts.append(
            tokenizer.apply_chat_template(
                convo,
                tools=tools if tools else None,
                tokenize=False,
                add_generation_prompt=False,
            )
        )

    return {"text": texts}


def train_model(model, tokenizer, dataset):
    """Train the model using SFTTrainer.

    Renders the dataset to text via the chat template, builds the trainer,
    wires the tokenizer's special-token ids into the model configs, runs
    training, and prints GPU memory statistics.

    Args:
        model: PEFT-wrapped model returned by load_model_and_tokenizer().
        tokenizer: tokenizer with the chat template already applied.
        dataset: cleaned dataset with a "messages" column (and optional "tools").

    Returns:
        (model, tokenizer) after training.
    """
    print("\n" + "=" * 80)
    print("PREPARING TRAINING")
    print("=" * 80)

    # Format dataset: batched map adds a "text" column that SFTConfig's
    # dataset_text_field below points at.
    train_dataset = dataset.map(
        lambda ex: formatting_prompts_func(ex, tokenizer),
        batched=True,
    )

    # Create trainer
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=None,
        dataset_num_proc=1,
        args=SFTConfig(
            dataset_text_field="text",
            max_length=MAX_SEQ_LENGTH,
            per_device_train_batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRAD_ACCUMULATION,
            warmup_steps=WARMUP_STEPS,
            max_steps=MAX_STEPS,
            learning_rate=LEARNING_RATE,
            logging_steps=1,
            optim="paged_adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            # NOTE(review): 3447 differs from the LoRA random_state (3407)
            # used in load_model_and_tokenizer — possibly a typo; confirm.
            seed=3447,
            report_to="none",
            dataloader_num_workers=0,
            output_dir="outputs",
            save_strategy="steps",
            save_steps=SAVE_STEPS,
            save_total_limit=SAVE_TOTAL_LIMIT,
        ),
    )

    # Print GPU stats. start_gpu_memory is the pre-training baseline so the
    # delta after train() isolates memory used by training itself.
    gpu_stats = torch.cuda.get_device_properties(0)
    start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
    print(f"GPU: {gpu_stats.name}")
    print(f"Max memory: {max_memory} GB")
    print(f"Reserved memory: {start_gpu_memory} GB")

    # Configure token IDs: mirror the tokenizer's special tokens into both the
    # model config and generation config — presumably because get_chat_template
    # can change them; verify against the unsloth template used.
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.generation_config.eos_token_id = tokenizer.eos_token_id
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.generation_config.bos_token_id = tokenizer.bos_token_id

    print("\n" + "=" * 80)
    print("STARTING TRAINING")
    print("=" * 80)

    # Train
    trainer_stats = trainer.train(resume_from_checkpoint=RESUME_FROM_CHECKPOINT)

    # Print training stats (all figures are reserved CUDA memory, in GiB)
    used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
    used_percentage = round(used_memory / max_memory * 100, 3)
    lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

    print("\n" + "=" * 80)
    print("TRAINING COMPLETE")
    print("=" * 80)
    print(f"Training time: {trainer_stats.metrics['train_runtime']:.2f} seconds")
    print(f"Training time: {trainer_stats.metrics['train_runtime']/60:.2f} minutes")
    print(f"Peak reserved memory: {used_memory} GB")
    print(f"Peak reserved memory for training: {used_memory_for_lora} GB")
    print(f"Peak reserved memory % of max: {used_percentage}%")
    print(f"Peak reserved memory for training % of max: {lora_percentage}%")

    return model, tokenizer


def save_model(model, tokenizer):
    """Push the merged 16-bit model and GGUF quantizations to the HF Hub."""
    print("\n" + "=" * 80)
    print("SAVING MODEL")
    print("=" * 80)

    merged_repo = f"{HF_ACCOUNT}/{OUTPUT_MODEL_REPO}"
    gguf_repo = f"{HF_ACCOUNT}/{OUTPUT_MODEL_REPO}-GGUF"

    # LoRA adapters merged back into the base weights at 16-bit precision.
    print(f"Pushing merged model to {merged_repo}...")
    model.push_to_hub_merged(
        merged_repo,
        tokenizer,
        save_method="merged_16bit",
        token=HF_TOKEN,
        private=PRIVATE_UPLOAD,
    )
    print("✓ Merged model saved")

    # llama.cpp-compatible quantized exports.
    print(f"Pushing GGUF models to {gguf_repo}...")
    model.push_to_hub_gguf(
        gguf_repo,
        tokenizer,
        quantization_method=["bf16", "f16", "q8_0"],
        token=HF_TOKEN,
        private=PRIVATE_UPLOAD,
    )
    print("✓ GGUF models saved")

    print("\n" + "=" * 80)
    print("ALL DONE!")
    print("=" * 80)


# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main():
    """Main training pipeline: data prep -> model load -> train -> upload."""
    dataset = prepare_dataset()
    model, tokenizer = load_model_and_tokenizer()
    model, tokenizer = train_model(model, tokenizer, dataset)
    save_model(model, tokenizer)


if __name__ == "__main__":
    # Required for multiprocessing in frozen Windows executables; harmless elsewhere.
    mp.freeze_support()
    main()

Maybe you could upload the code you use for each finetune to docs.teichai.com

Sign up or log in to comment