File size: 7,527 Bytes
7fb4235 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 | #!/usr/bin/env python3
"""
Zero-shot evaluation: Gemma 4 E2B (2B) on scam call transcripts.
Tests base-model capability before deciding if fine-tuning is needed.
REQUIREMENTS:
pip install transformers datasets torch huggingface_hub
USAGE:
python eval_zero_shot.py --model google/gemma-4-E2B-it \
--dataset BothBosu/scam-dialogue \
--split test \
--limit 100
# Or full eval (test split is ~400 rows):
python eval_zero_shot.py --limit -1
OUTPUT:
- results_zero_shot.json (per-example predictions + overall metrics)
- Console report with accuracy, confusion matrix, per-class F1
"""
import argparse
import json
import time
from pathlib import Path
import torch
from datasets import load_dataset
from transformers import AutoProcessor, Gemma4ForConditionalGeneration
# ββ Prompt engineering ββββββββββββββββββββββββββββββββββββββββββββββ
SYS = (
"You are a phone scam detection expert. "
"Your job is to read a call transcript and decide if it is a scam."
)
USER_TEMPLATE = (
"Read this phone call transcript and classify it:\n\n"
"{transcript}\n\n"
"Answer with exactly ONE of these two words: SCAM or LEGITIMATE. "
"Do not explain."
)
def parse_args():
p = argparse.ArgumentParser()
p.add_argument("--model", default="google/gemma-4-E2B-it",
help="HuggingFace model id (E2B text-only)")
p.add_argument("--dataset", default="BothBosu/scam-dialogue",
help="HF dataset with 'dialogue' and 'label' columns")
p.add_argument("--split", default="test")
p.add_argument("--limit", type=int, default=100,
help="Max rows to eval (-1 = all)")
p.add_argument("--device", default="auto", help="cuda / cpu / auto")
p.add_argument("--dtype", default="bf16", choices=["bf16","fp16","fp32"])
p.add_argument("--out", default="results_zero_shot.json")
return p.parse_args()
def load_model(model_id: str, device: str, dtype: str):
torch_dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[dtype]
print(f"Loading {model_id} (dtype={dtype}, device={device}) β¦")
model = Gemma4ForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch_dtype,
device_map="auto" if device == "auto" else None,
)
if device != "auto":
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
model.eval()
return model, processor
@torch.inference_mode()
def classify(model, processor, transcript: str) -> str:
messages = [
{"role": "system", "content": [{"type": "text", "text": SYS}]},
{"role": "user", "content": [{"type": "text", "text": USER_TEMPLATE.format(transcript=transcript)}]},
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
gen_ids = model.generate(
**inputs,
max_new_tokens=5,
do_sample=False,
pad_token_id=processor.tokenizer.pad_token_id,
)
# slice off prompt tokens
new_ids = gen_ids[:, inputs["input_ids"].shape[-1]:]
text = processor.batch_decode(new_ids, skip_special_tokens=True)[0]
return text.strip().upper()
def normalize(pred_raw: str) -> str:
if "SCAM" in pred_raw:
return "SCAM"
if any(w in pred_raw for w in ["LEGIT", "NOT", "SAFE", "NO", "NORMAL"]):
return "LEGITIMATE"
return pred_raw # unknown β will count as wrong
def gold_label(label_int: int) -> str:
return "SCAM" if label_int == 1 else "LEGITIMATE"
def compute_metrics(items):
total = len(items)
tp = sum(1 for it in items if it["pred"] == "SCAM" and it["gold"] == "SCAM")
fp = sum(1 for it in items if it["pred"] == "SCAM" and it["gold"] == "LEGITIMATE")
fn = sum(1 for it in items if it["pred"] == "LEGITIMATE" and it["gold"] == "SCAM")
tn = sum(1 for it in items if it["pred"] == "LEGITIMATE" and it["gold"] == "LEGITIMATE")
accuracy = (tp + tn) / total
precision = tp / (tp + fp) if (tp + fp) else 0
recall = tp / (tp + fn) if (tp + fn) else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0
return {
"total": total,
"accuracy": accuracy,
"precision_scam": precision,
"recall_scam": recall,
"f1_scam": f1,
"confusion": {"TP": tp, "FP": fp, "FN": fn, "TN": tn},
}
def main():
args = parse_args()
model, processor = load_model(args.model, args.device, args.dtype)
ds = load_dataset(args.dataset, split=args.split)
n = len(ds) if args.limit < 0 else min(args.limit, len(ds))
items = []
t0 = time.time()
for i in range(n):
row = ds[i]
gold = gold_label(row["label"])
pred_raw = classify(model, processor, row["dialogue"])
pred = normalize(pred_raw)
correct = pred == gold
items.append({
"index": i,
"dialogue": row["dialogue"][:500] + "β¦" if len(row["dialogue"]) > 500 else row["dialogue"],
"gold": gold,
"pred_raw": pred_raw,
"pred": pred,
"correct": correct,
})
mark = "β" if correct else "β"
print(f"[{i+1:3}/{n}] gold={gold:11} pred='{pred_raw:15}' β {pred:11} {mark}")
elapsed = time.time() - t0
metrics = compute_metrics(items)
metrics["time_sec"] = elapsed
metrics["throughput"] = n / elapsed
print("\n" + "=" * 60)
print("ZERO-SHOT EVALUATION REPORT")
print("=" * 60)
print(f"Model : {args.model}")
print(f"Dataset : {args.dataset} / {args.split}")
print(f"Samples : {n}")
print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.1f} ex/s)")
print(f"Accuracy : {metrics['accuracy']:.2%}")
print(f"Precision : {metrics['precision_scam']:.2%} (SCAM class)")
print(f"Recall : {metrics['recall_scam']:.2%} (SCAM class)")
print(f"F1 (SCAM) : {metrics['f1_scam']:.2%}")
print(f"Confusion : TP={metrics['confusion']['TP']} FP={metrics['confusion']['FP']} "
f"FN={metrics['confusion']['FN']} TN={metrics['confusion']['TN']}")
print("=" * 60)
# Save
out = {"args": vars(args), "metrics": metrics, "items": items}
Path(args.out).write_text(json.dumps(out, indent=2))
print(f"\nSaved full results β {args.out}")
# ββ Rubric / decision βββββββββββββββββββββββββββββββββββββββββ
print("\n" + "=" * 60)
print("DECISION RUBRIC")
print("=" * 60)
acc = metrics["accuracy"]
f1 = metrics["f1_scam"]
if acc >= 0.90 and f1 >= 0.85:
print("β
PASS β Base model is strong enough. Fine-tuning is OPTIONAL.")
elif acc >= 0.75 and f1 >= 0.70:
print("β οΈ MARGINAL β Fine-tuning recommended for production use.")
print(" Expected uplift from fine-tuning: +5β15 pp accuracy.")
else:
print("β FAIL β Base model is not reliable.")
print(" Fine-tuning REQUIRED before shipping.")
print(" (Run the Unsloth SFT script provided next.)")
if __name__ == "__main__":
main()
|