| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Fine-tune Mistral-7B-Instruct-v0.3 on the n8n-workflows-thinking dataset with supervised fine-tuning (SFT). |
| This script trains the model to generate n8n workflows with chain-of-thought reasoning. |
| """ |
|
|
| import os |
| import torch |
| from datasets import load_dataset |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
| from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training |
| from trl import SFTTrainer, SFTConfig |
| import trackio |
|
|
| |
| |
| |
# Inputs: the base checkpoint that gets fine-tuned and the Hub dataset repo
# whose SFT JSONL files are loaded further down.
DATASET_NAME = "stmasson/n8n-workflows-thinking"
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"

# Output: the Hub repository id the trainer pushes to (used as hub_model_id).
OUTPUT_MODEL = "stmasson/mistral-7b-n8n-workflows"

# Maximum sequence length, handed to SFTConfig as `max_length`.
MAX_SEQ_LENGTH = 4096
|
|
| |
# Open the trackio project for this run before any model or data loading.
trackio.init(project="mistral-7b-n8n-sft")

print(f"Loading tokenizer from {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# If the tokenizer defines no pad token, reuse the end-of-sequence token so
# that a pad token is always available downstream.
pad_token_missing = tokenizer.pad_token is None
if pad_token_missing:
    tokenizer.pad_token = tokenizer.eos_token
|
|
| |
print(f"Loading dataset {DATASET_NAME}...")

# Both splits are JSONL files under data/sft/ in the dataset repo; they are
# addressed with hf:// paths and parsed with the generic "json" loader.
sft_root = f"hf://datasets/{DATASET_NAME}/data/sft"
split_files = {
    "train": f"{sft_root}/train.jsonl",
    "validation": f"{sft_root}/validation.jsonl",
}
dataset = load_dataset("json", data_files=split_files)

# The "validation" split is used as the evaluation set from here on.
train_dataset, eval_dataset = dataset["train"], dataset["validation"]
print(f"Dataset loaded: {len(train_dataset)} train, {len(eval_dataset)} eval examples")
|
|
| |
print("Preprocessing dataset with chat template...")


def preprocess_function(example):
    """Render one example's conversation into a single text string.

    Args:
        example: Mapping with a "messages" entry, which is passed unchanged
            to the tokenizer's chat template (presumably a list of
            role/content messages -- confirm against the dataset schema).

    Returns:
        A dict with a single "text" key holding the templated conversation.
    """
    # tokenize=False keeps the result as a string rather than token ids, and
    # add_generation_prompt=False appends no trailing generation prompt.
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
    }
|
|
# Run the chat-template step over the train split and then the eval split.
# All original columns are removed, leaving only the "text" column that
# preprocess_function returns.
train_dataset, eval_dataset = (
    split.map(
        preprocess_function,
        remove_columns=split.column_names,
        desc=progress_label,
    )
    for progress_label, split in (
        ("Applying chat template to train", train_dataset),
        ("Applying chat template to eval", eval_dataset),
    )
)
print(f"Preprocessed: {len(train_dataset)} train, {len(eval_dataset)} eval")
|
|
| |
# bfloat16 is used both as the 4-bit compute dtype and as the load dtype.
compute_dtype = torch.bfloat16

# 4-bit NF4 quantization with double quantization enabled.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype,
)

print(f"Loading model {MODEL_NAME} with 4-bit quantization...")
model_load_options = {
    "quantization_config": bnb_config,
    "device_map": "auto",
    "torch_dtype": compute_dtype,
    # NOTE(review): trust_remote_code=True allows code shipped in the model
    # repo to run locally -- confirm it is actually needed for this model.
    "trust_remote_code": True,
    "attn_implementation": "sdpa",
}
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_load_options)

# PEFT helper that prepares the quantized model for training.
model = prepare_model_for_kbit_training(model)
|
|
| |
# Modules that receive LoRA adapters: the four attention projections followed
# by the three MLP projections.
attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
mlp_projections = ["gate_proj", "up_proj", "down_proj"]

# Rank 64 with alpha 128, i.e. an alpha/r ratio of 2; bias terms are not trained.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    bias="none",
    target_modules=attention_projections + mlp_projections,
)

# Wrap the model with the adapters and report the trainable parameter count.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
|
|
| |
# The SFTConfig options are grouped by concern and merged in the single
# constructor call below.

# Optimisation: batch size 1 with 16 accumulation steps gives 16 sequences
# per device per optimizer step.
optimisation_options = {
    "num_train_epochs": 2,
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "gradient_accumulation_steps": 16,
    "learning_rate": 1e-4,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.05,
    "weight_decay": 0.01,
}

# Logging every 10 steps; save and evaluate every 200 steps, keeping at most
# three checkpoints.
checkpointing_options = {
    "logging_steps": 10,
    "save_strategy": "steps",
    "save_steps": 200,
    "eval_strategy": "steps",
    "eval_steps": 200,
    "save_total_limit": 3,
}

# Precision and activation checkpointing.
memory_options = {
    "bf16": True,
    "gradient_checkpointing": True,
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
}

# Training reads the "text" column, one example per sequence (no packing).
data_options = {
    "max_length": MAX_SEQ_LENGTH,
    "packing": False,
    "dataset_text_field": "text",
}

# Hub upload: public repo, "checkpoint" strategy.
hub_options = {
    "push_to_hub": True,
    "hub_model_id": OUTPUT_MODEL,
    "hub_strategy": "checkpoint",
    "hub_private_repo": False,
}

# NOTE(review): the output directory and run name say "ministral-3b" although
# MODEL_NAME is a Mistral-7B checkpoint; the names are kept as they are so
# existing local checkpoints and tracked runs still line up.
tracking_options = {
    "report_to": "trackio",
    "run_name": "ministral-3b-n8n-sft",
}

training_args = SFTConfig(
    output_dir="./ministral-3b-n8n-sft",
    **optimisation_options,
    **checkpointing_options,
    **memory_options,
    **data_options,
    **hub_options,
    **tracking_options,
)
|
|
| |
print("Initializing SFTTrainer...")

# The model passed in is already wrapped with LoRA adapters; the tokenizer is
# supplied as the trainer's processing class.
trainer_inputs = {
    "model": model,
    "args": training_args,
    "train_dataset": train_dataset,
    "eval_dataset": eval_dataset,
    "processing_class": tokenizer,
}
trainer = SFTTrainer(**trainer_inputs)
|
|
| |
print("Starting training...")

# Resume support.
#
# With hub_strategy="checkpoint" the Trainer uploads its most recent
# checkpoint into a "last-checkpoint/" folder of the Hub repo. That folder
# (not the repo root) is where trainer_state.json and the rest of the resume
# state live, so it is downloaded as a whole.
#
# Checkpoint discovery is kept separate from the training call: a Hub or
# network problem only disables resuming, and a failure of a fresh training
# run is never mistaken for a resume problem and silently run a second time.
RESUME_ROOT = "./resume-checkpoint"
resume_dir = None
try:
    from huggingface_hub import list_repo_files, snapshot_download

    hub_files = list_repo_files(OUTPUT_MODEL)
    if any(path.startswith("last-checkpoint/") for path in hub_files):
        print("Found existing checkpoint on Hub, downloading to resume...")
        snapshot_download(
            OUTPUT_MODEL,
            allow_patterns=["last-checkpoint/*"],
            local_dir=RESUME_ROOT,
        )
        candidate = os.path.join(RESUME_ROOT, "last-checkpoint")
        # Only resume when the trainer state actually arrived; without it
        # there is no step counter to continue from.
        if os.path.exists(os.path.join(candidate, "trainer_state.json")):
            resume_dir = candidate
except Exception as e:
    # Best effort: any failure while looking for a checkpoint means training
    # starts from scratch, but the reason is reported rather than swallowed.
    print(f"Could not fetch checkpoint from Hub: {e}, starting fresh...")

if resume_dir is None:
    # No usable checkpoint: train from scratch and let errors propagate.
    trainer.train()
else:
    try:
        trainer.train(resume_from_checkpoint=resume_dir)
    except Exception as e:
        # Fall back to a fresh run if the resume attempt fails.
        print(f"Could not resume from checkpoint: {e}, starting fresh...")
        trainer.train()
|
|
| |
print("Saving final model...")
# Write the final model to the output directory, then upload it to the Hub.
trainer.save_model()
trainer.push_to_hub()

# Closing summary: where the model and the training metrics can be found.
summary_lines = (
    "\nTraining complete!",
    f"Model saved to: https://huggingface.co/{OUTPUT_MODEL}",
    "Training metrics: https://huggingface.co/spaces/stmasson/trackio",
)
for summary_line in summary_lines:
    print(summary_line)
|
|