| import os |
| import torch |
| from datasets import load_dataset, Dataset |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| ) |
| from peft import LoraConfig |
| from trl.trainer.sft_trainer import SFTTrainer |
| from trl.trainer.sft_config import SFTConfig |
| import argparse |
| import pandas as pd |
|
|
| |
# Module-level tokenizer handle. It is assigned inside main() so that
# format_instruction() — used as a `datasets.map` callback that takes only the
# sample dict — can reach the tokenizer without extra arguments.
tokenizer = None
|
|
|
|
def format_instruction(sample):
    """Render one dataset row as a chat-formatted SFT training string.

    Expects ``sample`` to provide a binary ``"phishing"`` flag (1 means
    phishing) and the raw email body under ``"text"``.  Returns a dict with a
    single ``"text"`` key containing the tokenizer's chat-template rendering
    of a user prompt plus the one-word assistant answer.  If the module-level
    tokenizer has not been set yet, an empty string is returned instead.
    """
    # Translate the numeric label into the literal answer word the model
    # is trained to emit.
    if sample["phishing"] == 1:
        answer = "Phishing"
    else:
        answer = "Safe"

    conversation = [
        {
            "role": "user",
            "content": f"Classify the following email text as either 'Safe' or 'Phishing'. Respond with only one word: 'Safe' or 'Phishing'.\n\nEmail text: {sample['text']}\n\nClassification:",
        },
        {"role": "assistant", "content": answer},
    ]

    # Guard clause: without a tokenizer there is nothing to render.
    # NOTE(review): this silently yields an empty training example — confirm
    # it can only happen before main() assigns the tokenizer.
    if not tokenizer:
        return {"text": ""}
    return {"text": tokenizer.apply_chat_template(conversation, tokenize=False)}
|
|
|
|
def main(args: argparse.Namespace) -> None:
    """Fine-tune a causal LM with LoRA to classify emails as Safe/Phishing.

    Loads the base model and tokenizer named by ``args.model_id``, reads
    train/validation splits either from a local CSV directory or from the
    Hugging Face Hub, maps every row to a chat-formatted text sample, and
    runs TRL's ``SFTTrainer`` with a LoRA adapter built from the CLI
    hyperparameters.  The trained adapter is saved to ``args.output_dir``
    and optionally pushed to the Hub.
    """
    # format_instruction() (a `datasets.map` callback) reads the module-level
    # tokenizer, so it must be assigned here before the datasets are mapped.
    global tokenizer
    # Device preference: CUDA, then Apple MPS, then CPU.
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using device: {device}")

    model_id = args.model_id
    print(f"Loading tokenizer and model: {model_id}")

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Many decoder-only checkpoints ship without a pad token; reuse EOS so
    # batched padding works.
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # bfloat16 only on CUDA; MPS/CPU fall back to full precision.
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        # device_map is not used for MPS; placement is done manually below.
        device_map=device if device != "mps" else None,
    )
    if device == "mps":
        model.to("mps")

    # LoRA adapter over the attention and MLP projection matrices.
    # NOTE(review): these are Llama-style module names — confirm they match
    # the chosen base model's architecture before swapping model_id.
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Data: a local path means pre-split CSVs (train.csv / val.csv); anything
    # else is treated as a Hugging Face Hub dataset id.
    print(f"Loading data from {args.dataset_name}...")
    if os.path.exists(args.dataset_name):
        train_df = pd.read_csv(os.path.join(args.dataset_name, "train.csv"))
        val_df = pd.read_csv(os.path.join(args.dataset_name, "val.csv"))
        if args.quick_test:
            # Tiny subset for a fast smoke-test run.
            train_df = train_df.head(100)
            val_df = val_df.head(20)
        train_dataset = Dataset.from_pandas(train_df)
        val_dataset = Dataset.from_pandas(val_df)
    else:
        dataset = load_dataset(args.dataset_name)
        train_dataset = dataset["train"]
        # Hub datasets may lack a validation split; evaluation is disabled
        # below when None.
        val_dataset = dataset["validation"] if "validation" in dataset else None

    # Render every row into the single "text" field SFTTrainer consumes.
    print("Formatting datasets...")
    train_dataset = train_dataset.map(format_instruction)
    if val_dataset:
        val_dataset = val_dataset.map(format_instruction)

    # NOTE(review): `eval_strategy`, `dataset_text_field`, and `max_length`
    # match recent TRL SFTConfig; older trl/transformers versions used
    # `evaluation_strategy` / `max_seq_length` — verify the pinned versions.
    sft_config = SFTConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=10,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps,  # -1 means "use epochs instead"
        eval_strategy="steps" if val_dataset else "no",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        bf16=torch.cuda.is_available(),
        push_to_hub=args.push_to_hub,
        report_to="tensorboard" if not args.no_report else "none",
        remove_unused_columns=False,
        dataset_text_field="text",
        max_length=args.max_seq_length,
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=peft_config,
        processing_class=tokenizer,
        args=sft_config,
    )

    print("Starting training...")
    trainer.train()

    # Persist the (adapter) weights locally, and optionally on the Hub.
    print(f"Saving model to {args.output_dir}")
    trainer.save_model(args.output_dir)
    if args.push_to_hub:
        trainer.push_to_hub()
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument( |
| "--model_id", type=str, default="HuggingFaceTB/SmolLM2-135M-Instruct" |
| ) |
| parser.add_argument("--dataset_name", type=str, default="data/") |
| parser.add_argument("--output_dir", type=str, default="models/smollm2-phish-sft") |
| parser.add_argument("--batch_size", type=int, default=4) |
| parser.add_argument("--grad_accum", type=int, default=4) |
| parser.add_argument("--lr", type=float, default=2e-4) |
| parser.add_argument("--epochs", type=int, default=1) |
| parser.add_argument("--max_steps", type=int, default=-1) |
| parser.add_argument("--max_seq_length", type=int, default=512) |
| parser.add_argument("--lora_r", type=int, default=16) |
| parser.add_argument("--lora_alpha", type=int, default=32) |
| parser.add_argument("--lora_dropout", type=float, default=0.05) |
| parser.add_argument("--quick_test", action="store_true") |
| parser.add_argument("--push_to_hub", action="store_true") |
| parser.add_argument("--no_report", action="store_true") |
| args = parser.parse_args() |
| main(args) |
|
|