#!/usr/bin/env python3
"""
Fine-tune Gemma 4 E2B (2B) for scam-call classification with Unsloth + TRL.
Optimized for Kaggle T4Γ2 (free) or any single GPU with β₯16 GB VRAM.
REQUIREMENTS:
pip install unsloth transformers datasets trl peft accelerate
USAGE:
python train_sft_unsloth.py --output s23deepak/grandgemma-scam-sft
NOTES:
- Uses 4-bit quantization + LoRA (r=16) to fit on 16 GB VRAM.
- Targets all linear layers for maximum fine-tuning capacity.
- Expect ~3β5 min/epoch on T4Γ2 with batch=2, grad_accum=4.
"""
import argparse

# Import unsloth before transformers/TRL so its runtime patches are applied.
from unsloth import FastLanguageModel, is_bfloat16_supported

from datasets import load_dataset, concatenate_datasets
from trl import SFTConfig, SFTTrainer


def parse_args():
    p = argparse.ArgumentParser()
p.add_argument("--model", default="google/gemma-4-E2B-it")
p.add_argument("--max_seq_length", type=int, default=2048)
p.add_argument("--lora_r", type=int, default=16)
p.add_argument("--lora_alpha", type=int, default=32)
p.add_argument("--batch_size", type=int, default=2)
p.add_argument("--gradient_accumulation_steps", type=int, default=4)
p.add_argument("--epochs", type=int, default=3)
p.add_argument("--lr", type=float, default=2e-4)
p.add_argument("--warmup_ratio", type=float, default=0.1)
p.add_argument("--weight_decay", type=float, default=0.01)
p.add_argument("--output", default="grandgemma-scam-sft")
p.add_argument("--push_to_hub", default=None, help="HF repo id, e.g. username/model-name")
return p.parse_args()
SYSTEM = "You are a phone scam detection expert."
PROMPT = (
"Read this phone call transcript and classify it:\n\n"
"{transcript}\n\n"
"Answer with exactly ONE word: SCAM or LEGITIMATE."
)


def main():
    args = parse_args()

    # ── Load model (4-bit via Unsloth) ───────────────────────────────
    print(f"Loading {args.model} with Unsloth 4-bit …")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        dtype=None,  # auto-detect bf16 / fp16
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=args.lora_alpha,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
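
    # Sanity check (optional addition): the object returned by get_peft_model is a
    # PEFT-wrapped model, so print_trainable_parameters() reports how many weights
    # the LoRA adapters actually train (typically well under 1% of the base model).
    model.print_trainable_parameters()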

    # ── Prepare dataset ──────────────────────────────────────────────
    print("Loading & formatting scam-dialogue …")
    ds_train = load_dataset("BothBosu/scam-dialogue", split="train")
    ds_test = load_dataset("BothBosu/scam-dialogue", split="test")

    # Optional: merge BothBosu/Scammer-Conversation as extra data
    try:
        ds_extra = load_dataset("BothBosu/Scammer-Conversation", split="train")
        ds_train = concatenate_datasets([ds_train, ds_extra])
        print(f"Merged extra data → train size = {len(ds_train)}")
    except Exception as e:
        print(f"Could not load extra data: {e}")

    def format_example(example):
        answer = "SCAM" if example["label"] == 1 else "LEGITIMATE"
        messages = [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": PROMPT.format(transcript=example["dialogue"])},
            {"role": "assistant", "content": answer},
        ]
        return {"text": tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False)}

    ds_train = ds_train.map(format_example, remove_columns=ds_train.column_names)
    ds_test = ds_test.map(format_example, remove_columns=ds_test.column_names)
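
    # Quick sanity peek (added sketch): after map() each split holds a single "text"
    # column containing the full chat-templated string, so printing a slice of one
    # row lets the template and the one-word answer be checked before spending GPU time.
    print("Sample formatted example:\n", ds_train[0]["text"][:500])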

    # ── Training arguments ───────────────────────────────────────────
    training_args = SFTConfig(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        logging_strategy="steps",
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        # T4 GPUs do not support bf16, so pick the precision at runtime.
        bf16=is_bfloat16_supported(),
        fp16=not is_bfloat16_supported(),
        optim="adamw_8bit",
        seed=42,
        report_to="none",
        max_seq_length=args.max_seq_length,
        dataset_text_field="text",
        push_to_hub=bool(args.push_to_hub),
        hub_model_id=args.push_to_hub,
    )
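    # With the defaults above, each optimizer step accumulates
    # batch_size × gradient_accumulation_steps = 2 × 4 = 8 samples per device
    # (assuming a single training process; multiply by world size under DDP).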
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds_train,
        eval_dataset=ds_test,
        args=training_args,
dataset_text_field="text",
)

    # ── Train ────────────────────────────────────────────────────────
    print("\nStarting training …")
    trainer.train()
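
    # Optional extra pass over the held-out split once training finishes;
    # trainer.evaluate() returns a metrics dict (eval_loss, runtime, etc.).
    final_metrics = trainer.evaluate()
    print("Final eval metrics:", final_metrics)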

    # ── Save ─────────────────────────────────────────────────────────
    print(f"\nSaving adapter to {args.output} …")
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)
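
    # Smoke test (a sketch, not part of the original pipeline): switch to Unsloth's
    # inference mode and classify one made-up transcript. The transcript below is
    # hypothetical; swap in a real held-out call for a meaningful spot check.
    FastLanguageModel.for_inference(model)
    demo_messages = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": PROMPT.format(
            transcript="Caller: This is the tax office. Pay the fine with gift cards "
                       "today or you will be arrested.")},
    ]
    demo_ids = tokenizer.apply_chat_template(
        demo_messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
    demo_out = model.generate(demo_ids, max_new_tokens=4)
    print("Smoke-test prediction:",
          tokenizer.decode(demo_out[0][demo_ids.shape[-1]:], skip_special_tokens=True).strip())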

    if args.push_to_hub:
        print(f"Pushing to https://huggingface.co/{args.push_to_hub}")
        model.push_to_hub(args.push_to_hub)
        tokenizer.push_to_hub(args.push_to_hub)

    print("Done.")
if __name__ == "__main__":
main()