#!/usr/bin/env python3
"""
Fine-tune Gemma 3n E2B (effective 2B params) for scam-call classification with Unsloth + TRL.
Optimized for Kaggle T4×2 (free) or any single GPU with ≥16 GB VRAM.

REQUIREMENTS:
    pip install unsloth transformers datasets trl peft accelerate

USAGE:
    python train_sft_unsloth.py --output s23deepak/grandgemma-scam-sft

NOTES:
- Uses 4-bit quantization + LoRA (r=16) to fit on 16 GB VRAM.
- Targets all linear layers for maximum fine-tuning capacity.
- Expect ~3–5 min/epoch on T4×2 with batch=2, grad_accum=4.
"""
import argparse

# Import Unsloth first so its patches apply before transformers/TRL load.
from unsloth import FastLanguageModel, is_bfloat16_supported

from datasets import load_dataset, concatenate_datasets
from trl import SFTConfig, SFTTrainer


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="google/gemma-3n-E2B-it")
    p.add_argument("--max_seq_length", type=int, default=2048)
    p.add_argument("--lora_r", type=int, default=16)
    p.add_argument("--lora_alpha", type=int, default=32)
    p.add_argument("--batch_size", type=int, default=2)
    p.add_argument("--gradient_accumulation_steps", type=int, default=4)
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--warmup_ratio", type=float, default=0.1)
    p.add_argument("--weight_decay", type=float, default=0.01)
    p.add_argument("--output", default="grandgemma-scam-sft")
    p.add_argument("--push_to_hub", default=None,
                   help="HF repo id, e.g. username/model-name")
    return p.parse_args()


SYSTEM = "You are a phone scam detection expert."
PROMPT = (
    "Read this phone call transcript and classify it:\n\n"
    "{transcript}\n\n"
    "Answer with exactly ONE word: SCAM or LEGITIMATE."
)


def main():
    args = parse_args()

    # ── Load model (4-bit via Unsloth) ──────────────────────────────
    print(f"Loading {args.model} with Unsloth 4-bit …")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        dtype=None,        # auto-detect bf16 / fp16
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=args.lora_alpha,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )

    # ── Prepare dataset ─────────────────────────────────────────────
    print("Loading & formatting scam-dialogue …")
    ds_train = load_dataset("BothBosu/scam-dialogue", split="train")
    ds_test = load_dataset("BothBosu/scam-dialogue", split="test")

    # Optional: merge BothBosu/Scammer-Conversation as extra data
    try:
        ds_extra = load_dataset("BothBosu/Scammer-Conversation", split="train")
        ds_train = concatenate_datasets([ds_train, ds_extra])
        print(f"Merged extra data → train size = {len(ds_train)}")
    except Exception as e:
        print(f"Could not load extra data: {e}")

    def format_example(example):
        answer = "SCAM" if example["label"] == 1 else "LEGITIMATE"
        messages = [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": PROMPT.format(transcript=example["dialogue"])},
            {"role": "assistant", "content": answer},
        ]
        return {"text": tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False)}

    ds_train = ds_train.map(format_example, remove_columns=ds_train.column_names)
    ds_test = ds_test.map(format_example, remove_columns=ds_test.column_names)
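    # Quick sanity check (an added sketch, not part of the original pipeline):
    # print one rendered sample so the chat-template output can be eyeballed
    # before a full training run spends GPU time on malformed text.
    print("Sample formatted example:\n", ds_train[0]["text"][:400])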
    # ── Training arguments ──────────────────────────────────────────
    # T4 GPUs (Turing) lack bfloat16 support, so pick bf16/fp16 dynamically
    # instead of hard-coding bf16=True.
    training_args = SFTConfig(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        logging_strategy="steps",
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        bf16=is_bfloat16_supported(),
        fp16=not is_bfloat16_supported(),
        optim="adamw_8bit",
        seed=42,
        report_to="none",
        max_seq_length=args.max_seq_length,
        dataset_text_field="text",  # SFTConfig owns this field in recent TRL
        push_to_hub=bool(args.push_to_hub),
        hub_model_id=args.push_to_hub,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds_train,
        eval_dataset=ds_test,
        args=training_args,
    )

    # ── Train ───────────────────────────────────────────────────────
    print("\nStarting training …")
    trainer.train()

    # ── Save ────────────────────────────────────────────────────────
    print(f"\nSaving adapter to {args.output} …")
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)

    if args.push_to_hub:
        print(f"Pushing to https://huggingface.co/{args.push_to_hub}")
        model.push_to_hub(args.push_to_hub)
        tokenizer.push_to_hub(args.push_to_hub)

    print("Done.")


if __name__ == "__main__":
    main()
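
# ── Optional: quick inference check ─────────────────────────────────
# A minimal post-training sketch (not part of the training run; the adapter
# path, transcript, and generation settings below are illustrative). Load the
# saved LoRA adapter and classify a single transcript, e.g. in a notebook:
#
#   from unsloth import FastLanguageModel
#   model, tokenizer = FastLanguageModel.from_pretrained(
#       "grandgemma-scam-sft", max_seq_length=2048, load_in_4bit=True)
#   FastLanguageModel.for_inference(model)  # enable Unsloth's fast generation path
#   messages = [
#       {"role": "system", "content": SYSTEM},
#       {"role": "user", "content": PROMPT.format(transcript="Hello, this is the IRS …")},
#   ]
#   input_ids = tokenizer.apply_chat_template(
#       messages, add_generation_prompt=True, return_tensors="pt").to("cuda")
#   out = model.generate(input_ids, max_new_tokens=4)
#   print(tokenizer.decode(out[0][input_ids.shape[-1]:], skip_special_tokens=True))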