# grandgemma-eval / train_sft_unsloth.py
# Uploaded by s23deepak ("Upload train_sft_unsloth.py", commit 9b731b7, verified).
# NOTE(review): these four lines were Hugging Face file-viewer page residue, not
# Python — converted to comments so the script is runnable again.
#!/usr/bin/env python3
"""
Fine-tune Gemma 4 E2B (2B) for scam-call classification with Unsloth + TRL.
Optimized for Kaggle T4Γ—2 (free) or any single GPU with β‰₯16 GB VRAM.
REQUIREMENTS:
pip install unsloth transformers datasets trl peft accelerate
USAGE:
python train_sft_unsloth.py --output s23deepak/grandgemma-scam-sft
NOTES:
- Uses 4-bit quantization + LoRA (r=16) to fit on 16 GB VRAM.
- Targets all linear layers for maximum fine-tuning capacity.
- Expect ~3–5 min/epoch on T4Γ—2 with batch=2, grad_accum=4.
"""
import argparse
from datasets import load_dataset, concatenate_datasets
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
def parse_args():
    """Build and parse the command-line interface for this script.

    Returns:
        argparse.Namespace holding model, LoRA, optimizer and output settings.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    # Model & sequence settings.
    # NOTE(review): "google/gemma-4-E2B-it" does not look like a published
    # checkpoint id (cf. "google/gemma-3n-E2B-it") — confirm before running.
    add("--model", default="google/gemma-4-E2B-it")
    add("--max_seq_length", type=int, default=2048)
    # LoRA adapter capacity.
    add("--lora_r", type=int, default=16)
    add("--lora_alpha", type=int, default=32)
    # Optimization schedule.
    add("--batch_size", type=int, default=2)
    add("--gradient_accumulation_steps", type=int, default=4)
    add("--epochs", type=int, default=3)
    add("--lr", type=float, default=2e-4)
    add("--warmup_ratio", type=float, default=0.1)
    add("--weight_decay", type=float, default=0.01)
    # Output directory / optional Hub publishing target.
    add("--output", default="grandgemma-scam-sft")
    add("--push_to_hub", default=None, help="HF repo id, e.g. username/model-name")
    return parser.parse_args()
# Persona injected as the chat "system" turn for every training example.
SYSTEM = "You are a phone scam detection expert."
# User-turn template; {transcript} is filled with the raw call dialogue.
# The single-word answer format keeps the assistant target trivially parseable.
PROMPT = (
    "Read this phone call transcript and classify it:\n\n"
    "{transcript}\n\n"
    "Answer with exactly ONE word: SCAM or LEGITIMATE."
)
def main():
    """Run the scam-classification SFT pipeline end to end.

    Steps: load the base model in 4-bit via Unsloth, attach LoRA adapters,
    build chat-formatted train/eval splits, train with TRL's SFTTrainer,
    then save the adapter locally and optionally push it to the HF Hub.
    """
    args = parse_args()
    import torch  # function-scope import: only needed to probe GPU precision support

    # ── Load model (4-bit via Unsloth) ────────────────────────────
    print(f"Loading {args.model} with Unsloth 4-bit …")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        dtype=None,  # auto-detect bf16 / fp16
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        # All attention + MLP projections, for maximum adapter capacity.
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=args.lora_alpha,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
    # ── Prepare dataset ─────────────────────────────────────────────
    print("Loading & formatting scam-dialogue …")
    ds_train = load_dataset("BothBosu/scam-dialogue", split="train")
    ds_test = load_dataset("BothBosu/scam-dialogue", split="test")
    # Optional: merge BothBosu/Scammer-Conversation as extra data.
    # Best-effort by design — training proceeds on the base dataset if it fails.
    try:
        ds_extra = load_dataset("BothBosu/Scammer-Conversation", split="train")
        ds_train = concatenate_datasets([ds_train, ds_extra])
        print(f"Merged extra data → train size = {len(ds_train)}")
    except Exception as e:
        print(f"Could not load extra data: {e}")

    def format_example(example):
        """Render one row as a chat-template string ending in the gold answer.

        Assumes rows carry an integer `label` (1 == scam) and a string
        `dialogue` — TODO confirm both datasets share this schema.
        """
        answer = "SCAM" if example["label"] == 1 else "LEGITIMATE"
        messages = [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": PROMPT.format(transcript=example["dialogue"])},
            {"role": "assistant", "content": answer},
        ]
        return {"text": tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False)}

    ds_train = ds_train.map(format_example, remove_columns=ds_train.column_names)
    ds_test = ds_test.map(format_example, remove_columns=ds_test.column_names)
    # ── Training arguments ──────────────────────────────────────────
    # FIX: the original hard-coded bf16=True, but the module docstring targets
    # Kaggle T4 GPUs (compute capability 7.5), which do not support bf16 and
    # would fail at trainer init. Probe the hardware instead and fall back to
    # fp16 on older GPUs (full precision on CPU).
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    use_fp16 = torch.cuda.is_available() and not use_bf16
    training_args = SFTConfig(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        logging_strategy="steps",
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,  # keep only the two most recent checkpoints
        bf16=use_bf16,
        fp16=use_fp16,
        optim="adamw_8bit",  # 8-bit optimizer states to save VRAM
        seed=42,
        report_to="none",
        max_seq_length=args.max_seq_length,
        push_to_hub=bool(args.push_to_hub),
        hub_model_id=args.push_to_hub,
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds_train,
        eval_dataset=ds_test,
        args=training_args,
        dataset_text_field="text",
    )
    # ── Train ─────────────────────────────────────────────────────
    print("\nStarting training …")
    trainer.train()
    # ── Save ──────────────────────────────────────────────────────
    print(f"\nSaving adapter to {args.output} …")
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)
    if args.push_to_hub:
        print(f"Pushing to https://huggingface.co/{args.push_to_hub}")
        # FIX: push model and tokenizer separately. `tokenizer=` is not a
        # documented keyword of push_to_hub, so the original call would not
        # upload the tokenizer (and may raise on newer library versions).
        model.push_to_hub(args.push_to_hub)
        tokenizer.push_to_hub(args.push_to_hub)
    print("Done.")
# Entry-point guard: allows importing this module without starting a training run.
if __name__ == "__main__":
    main()