#!/usr/bin/env python3
"""
Fine-tune Gemma 3n E2B (effective 2B) for scam-call classification with Unsloth + TRL.
Optimized for Kaggle T4×2 (free) or any single GPU with ≥16 GB VRAM.

REQUIREMENTS:
  pip install unsloth transformers datasets trl peft accelerate

USAGE:
  python train_sft_unsloth.py --push_to_hub s23deepak/grandgemma-scam-sft

NOTES:
  - Uses 4-bit quantization + LoRA (r=16) to fit on 16 GB VRAM.
  - Targets all linear layers for maximum fine-tuning capacity.
  - Expect ~3–5 min/epoch on T4×2 with batch=2, grad_accum=4.
"""

import argparse

# Import Unsloth before TRL/transformers so its runtime patches are applied.
from unsloth import FastLanguageModel, is_bfloat16_supported

from datasets import load_dataset, concatenate_datasets
from trl import SFTConfig, SFTTrainer


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="google/gemma-4-E2B-it")
    p.add_argument("--max_seq_length", type=int, default=2048)
    p.add_argument("--lora_r", type=int, default=16)
    p.add_argument("--lora_alpha", type=int, default=32)
    p.add_argument("--batch_size", type=int, default=2)
    p.add_argument("--gradient_accumulation_steps", type=int, default=4)
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--warmup_ratio", type=float, default=0.1)
    p.add_argument("--weight_decay", type=float, default=0.01)
    p.add_argument("--output", default="grandgemma-scam-sft")
    p.add_argument("--push_to_hub", default=None, help="HF repo id, e.g. username/model-name")
    return p.parse_args()


SYSTEM = "You are a phone scam detection expert."
PROMPT = (
    "Read this phone call transcript and classify it:\n\n"
    "{transcript}\n\n"
    "Answer with exactly ONE word: SCAM or LEGITIMATE."
)
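
# Design note: the assistant target is a single word, so a prediction can be
# read off the first generated token(s) at inference time with a plain string
# comparison, no structured-output parsing needed.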


def main():
    args = parse_args()

    # ── Load model (4-bit via Unsloth) ────────────────────────────
    print(f"Loading {args.model} with Unsloth 4-bit …")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        dtype=None,          # auto-detect bf16 / fp16
        load_in_4bit=True,
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=args.lora_alpha,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )

    # ── Prepare dataset ─────────────────────────────────────────────
    print("Loading & formatting scam-dialogue …")
    ds_train = load_dataset("BothBosu/scam-dialogue", split="train")
    ds_test  = load_dataset("BothBosu/scam-dialogue", split="test")

    # Optional: merge BothBosu/Scammer-Conversation as extra data
    try:
        ds_extra = load_dataset("BothBosu/Scammer-Conversation", split="train")
        ds_train = concatenate_datasets([ds_train, ds_extra])
        print(f"Merged extra data β†’ train size = {len(ds_train)}")
    except Exception as e:
        print(f"Could not load extra data: {e}")

    def format_example(example):
        answer = "SCAM" if example["label"] == 1 else "LEGITIMATE"
        messages = [
            {"role": "system",    "content": SYSTEM},
            {"role": "user",      "content": PROMPT.format(transcript=example["dialogue"])},
            {"role": "assistant", "content": answer},
        ]
        return {"text": tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False)}

    ds_train = ds_train.map(format_example, remove_columns=ds_train.column_names)
    ds_test  = ds_test.map(format_example,  remove_columns=ds_test.column_names)

    # ── Training arguments ──────────────────────────────────────────
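    # With the default flags, each optimizer step sees
    # per_device_train_batch_size (2) x gradient_accumulation_steps (4) = 8
    # sequences per device.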
    training_args = SFTConfig(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        logging_strategy="steps",
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        bf16=is_bfloat16_supported(),   # T4 GPUs lack bf16; fall back to fp16
        fp16=not is_bfloat16_supported(),
        optim="adamw_8bit",
        seed=42,
        report_to="none",
        max_seq_length=args.max_seq_length,
        dataset_text_field="text",
        push_to_hub=bool(args.push_to_hub),
        hub_model_id=args.push_to_hub,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds_train,
        eval_dataset=ds_test,
        args=training_args,
    )

    # ── Train ─────────────────────────────────────────────────────
    print("\nStarting training …")
    trainer.train()

    # ── Save ──────────────────────────────────────────────────────
    print(f"\nSaving adapter to {args.output} …")
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)

    if args.push_to_hub:
        print(f"Pushing to https://huggingface.co/{args.push_to_hub}")
        model.push_to_hub(args.push_to_hub)
        tokenizer.push_to_hub(args.push_to_hub)

    print("Done.")


if __name__ == "__main__":
    main()