s23deepak committed on
Commit
9b731b7
·
verified ·
1 Parent(s): 7fb4235

Upload train_sft_unsloth.py

Files changed (1)
  1. train_sft_unsloth.py +150 -0
train_sft_unsloth.py ADDED
#!/usr/bin/env python3
"""
Fine-tune Gemma 3n E2B (~2B effective parameters) for scam-call classification
with Unsloth + TRL. Optimized for Kaggle T4×2 (free) or any single GPU with
≥16 GB VRAM.

REQUIREMENTS:
    pip install unsloth transformers datasets trl peft accelerate

USAGE:
    python train_sft_unsloth.py --output grandgemma-scam-sft \
        --push_to_hub s23deepak/grandgemma-scam-sft

NOTES:
- Uses 4-bit quantization + LoRA (r=16) to fit in 16 GB VRAM.
- Targets all attention and MLP projection layers for maximum fine-tuning capacity.
- Expect ~3–5 min/epoch on T4×2 with batch=2, grad_accum=4.
"""

import argparse

# Import unsloth before transformers/trl so its runtime patches are applied.
from unsloth import FastLanguageModel, is_bfloat16_supported

from datasets import load_dataset, concatenate_datasets
from trl import SFTConfig, SFTTrainer


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="google/gemma-3n-E2B-it")
    p.add_argument("--max_seq_length", type=int, default=2048)
    p.add_argument("--lora_r", type=int, default=16)
    p.add_argument("--lora_alpha", type=int, default=32)
    p.add_argument("--batch_size", type=int, default=2)
    p.add_argument("--gradient_accumulation_steps", type=int, default=4)
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--warmup_ratio", type=float, default=0.1)
    p.add_argument("--weight_decay", type=float, default=0.01)
    p.add_argument("--output", default="grandgemma-scam-sft")
    p.add_argument("--push_to_hub", default=None, help="HF repo id, e.g. username/model-name")
    return p.parse_args()


SYSTEM = "You are a phone scam detection expert."
PROMPT = (
    "Read this phone call transcript and classify it:\n\n"
    "{transcript}\n\n"
    "Answer with exactly ONE word: SCAM or LEGITIMATE."
)
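# A single-word target keeps downstream parsing trivial: a caller only needs to
# match the first generated word against "SCAM" / "LEGITIMATE".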


def main():
    args = parse_args()

    # ── Load model (4-bit via Unsloth) ──────────────────────────────
    print(f"Loading {args.model} with Unsloth 4-bit …")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        dtype=None,  # auto-detect bf16 / fp16
        load_in_4bit=True,
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=args.lora_alpha,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
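    # Note: lora_alpha = 2 × r (32 vs. 16 by default) is a common LoRA heuristic,
    # and the seven target modules are the standard attention + MLP projections.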

    # ── Prepare dataset ──────────────────────────────────────────────
    print("Loading & formatting scam-dialogue …")
    ds_train = load_dataset("BothBosu/scam-dialogue", split="train")
    ds_test = load_dataset("BothBosu/scam-dialogue", split="test")

    # Optional: merge BothBosu/Scammer-Conversation as extra data
    try:
        ds_extra = load_dataset("BothBosu/Scammer-Conversation", split="train")
        ds_train = concatenate_datasets([ds_train, ds_extra])
        print(f"Merged extra data → train size = {len(ds_train)}")
    except Exception as e:
        print(f"Could not load extra data: {e}")

    def format_example(example):
        answer = "SCAM" if example["label"] == 1 else "LEGITIMATE"
        messages = [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": PROMPT.format(transcript=example["dialogue"])},
            {"role": "assistant", "content": answer},
        ]
        return {"text": tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False)}

    ds_train = ds_train.map(format_example, remove_columns=ds_train.column_names)
    ds_test = ds_test.map(format_example, remove_columns=ds_test.column_names)
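
    # Illustrative sanity check: inspect one rendered example to confirm the
    # chat template output looks as expected before training.
    print("Sample formatted example:\n", ds_train[0]["text"][:500])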

    # ── Training arguments ───────────────────────────────────────────
    training_args = SFTConfig(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        logging_strategy="steps",
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        bf16=is_bfloat16_supported(),       # T4s lack bf16; fall back to fp16
        fp16=not is_bfloat16_supported(),
        optim="adamw_8bit",
        seed=42,
        report_to="none",
        max_seq_length=args.max_seq_length,
        dataset_text_field="text",          # lives on the config in recent TRL
        push_to_hub=bool(args.push_to_hub),
        hub_model_id=args.push_to_hub,
    )
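    # Effective batch size per device = batch_size × gradient_accumulation_steps
    # (2 × 4 = 8 with the defaults above).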

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds_train,
        eval_dataset=ds_test,
        args=training_args,
    )

    # ── Train ────────────────────────────────────────────────────────
    print("\nStarting training …")
    trainer.train()

    # ── Save ─────────────────────────────────────────────────────────
    print(f"\nSaving adapter to {args.output} …")
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)

    if args.push_to_hub:
        print(f"Pushing to https://huggingface.co/{args.push_to_hub}")
        model.push_to_hub(args.push_to_hub)      # LoRA adapter
        tokenizer.push_to_hub(args.push_to_hub)

    print("Done.")


if __name__ == "__main__":
    main()
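
For quick verification after training, a minimal inference sketch (illustrative, not part of the committed file; it assumes the default adapter directory grandgemma-scam-sft saved above and uses Unsloth's FastLanguageModel.for_inference switch):

#!/usr/bin/env python3
"""Sketch: load the saved LoRA adapter and classify one transcript."""
from unsloth import FastLanguageModel

# "grandgemma-scam-sft" is an assumption: the default --output of train_sft_unsloth.py.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="grandgemma-scam-sft",
    max_seq_length=2048,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # enable Unsloth's fast generation path

transcript = "Hello, this is the IRS. You owe back taxes and must pay today."
messages = [
    {"role": "system", "content": "You are a phone scam detection expert."},
    {"role": "user", "content": (
        "Read this phone call transcript and classify it:\n\n"
        f"{transcript}\n\n"
        "Answer with exactly ONE word: SCAM or LEGITIMATE."
    )},
]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
out = model.generate(input_ids, max_new_tokens=4, do_sample=False)
print(tokenizer.decode(out[0, input_ids.shape[-1]:], skip_special_tokens=True))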