| |
| """ |
| Fine-tune Gemma 4 E2B (2B) for scam-call classification with Unsloth + TRL. |
| Optimized for Kaggle T4Γ2 (free) or any single GPU with β₯16 GB VRAM. |
| |
| REQUIREMENTS: |
| pip install unsloth transformers datasets trl peft accelerate |
| |
| USAGE: |
| python train_sft_unsloth.py --output s23deepak/grandgemma-scam-sft |
| |
| NOTES: |
| - Uses 4-bit quantization + LoRA (r=16) to fit on 16 GB VRAM. |
| - Targets all linear layers for maximum fine-tuning capacity. |
| - Expect ~3β5 min/epoch on T4Γ2 with batch=2, grad_accum=4. |
| """ |
|
|
import argparse

import torch
from datasets import load_dataset, concatenate_datasets
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
|
|
|
|
def parse_args():
    """Build and parse the command-line options for the fine-tuning run."""
    parser = argparse.ArgumentParser()
    # Model / sequence settings.
    parser.add_argument("--model", default="google/gemma-4-E2B-it")
    parser.add_argument("--max_seq_length", type=int, default=2048)
    # LoRA hyperparameters.
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    # Optimization schedule.
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    # Output locations.
    parser.add_argument("--output", default="grandgemma-scam-sft")
    parser.add_argument("--push_to_hub", default=None, help="HF repo id, e.g. username/model-name")
    return parser.parse_args()
|
|
|
|
| SYSTEM = "You are a phone scam detection expert." |
| PROMPT = ( |
| "Read this phone call transcript and classify it:\n\n" |
| "{transcript}\n\n" |
| "Answer with exactly ONE word: SCAM or LEGITIMATE." |
| ) |
|
|
|
|
def main():
    """Run the full SFT pipeline: load model, build dataset, train, save/push.

    Side effects: downloads the base model and datasets, writes checkpoints to
    ``args.output``, and optionally pushes the adapter to the Hugging Face Hub.
    """
    args = parse_args()

    # ---- Model -------------------------------------------------------------
    print(f"Loading {args.model} with Unsloth 4-bit ...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        dtype=None,          # let Unsloth auto-pick fp16/bf16 for the detected GPU
        load_in_4bit=True,   # 4-bit quantization keeps the model within 16 GB VRAM
    )

    # Attach LoRA adapters on all linear projections for maximum capacity.
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=args.lora_alpha,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",  # Unsloth's memory-efficient variant
        random_state=42,
    )

    # ---- Data --------------------------------------------------------------
    print("Loading & formatting scam-dialogue ...")
    ds_train = load_dataset("BothBosu/scam-dialogue", split="train")
    ds_test = load_dataset("BothBosu/scam-dialogue", split="test")

    # Optional extra corpus: best-effort merge, training proceeds without it on
    # any failure (network error, schema mismatch between the two datasets, ...).
    try:
        ds_extra = load_dataset("BothBosu/Scammer-Conversation", split="train")
        ds_train = concatenate_datasets([ds_train, ds_extra])
        print(f"Merged extra data -> train size = {len(ds_train)}")
    except Exception as e:
        print(f"Could not load extra data: {e}")

    def format_example(example):
        """Render one dataset row as a chat-formatted training string.

        NOTE(review): assumes label == 1 means scam and the transcript lives in
        the 'dialogue' column -- confirm against the dataset card.
        """
        answer = "SCAM" if example["label"] == 1 else "LEGITIMATE"
        messages = [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": PROMPT.format(transcript=example["dialogue"])},
            {"role": "assistant", "content": answer},
        ]
        # add_generation_prompt=False: the assistant turn is already included,
        # so no trailing generation header is needed.
        return {"text": tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False)}

    ds_train = ds_train.map(format_example, remove_columns=ds_train.column_names)
    ds_test = ds_test.map(format_example, remove_columns=ds_test.column_names)

    # ---- Training ----------------------------------------------------------
    # BUGFIX: bf16 was hard-coded on, but the documented target GPU (T4,
    # compute capability 7.5) has no bfloat16 support -- select the mixed
    # precision mode the detected hardware actually supports.
    has_cuda = torch.cuda.is_available()
    use_bf16 = has_cuda and torch.cuda.is_bf16_supported()
    training_args = SFTConfig(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        logging_strategy="steps",
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,        # keep only the two most recent checkpoints
        bf16=use_bf16,
        fp16=has_cuda and not use_bf16,
        optim="adamw_8bit",        # 8-bit optimizer states save ~1.5 GB VRAM
        seed=42,
        report_to="none",
        max_seq_length=args.max_seq_length,
        push_to_hub=bool(args.push_to_hub),
        hub_model_id=args.push_to_hub,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds_train,
        eval_dataset=ds_test,
        args=training_args,
        dataset_text_field="text",
    )

    print("\nStarting training ...")
    trainer.train()

    # ---- Save / publish ----------------------------------------------------
    print(f"\nSaving adapter to {args.output} ...")
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)

    if args.push_to_hub:
        print(f"Pushing to https://huggingface.co/{args.push_to_hub}")
        model.push_to_hub(args.push_to_hub, tokenizer=tokenizer)

    print("Done.")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|