r"""
Zero-shot evaluation: Gemma 4 E2B (2B) on scam call transcripts.
Tests base-model capability before deciding whether fine-tuning is needed.

REQUIREMENTS:
    pip install transformers datasets torch huggingface_hub

USAGE:
    python eval_zero_shot.py --model google/gemma-4-E2B-it \
        --dataset BothBosu/scam-dialogue \
        --split test \
        --limit 100

    # Or full eval (test split is ~400 rows):
    python eval_zero_shot.py --limit -1

OUTPUT:
    - results_zero_shot.json (per-example predictions + overall metrics)
    - Console report with accuracy, TP/FP/FN/TN confusion counts, and
      precision / recall / F1 for the SCAM class
    - A PASS / MARGINAL / FAIL decision rubric on whether to fine-tune
"""


import argparse
import json
import re
import time
from pathlib import Path

import torch
from datasets import load_dataset
from transformers import AutoProcessor, Gemma4ForConditionalGeneration


# System prompt: frames the model as a scam-detection classifier.
SYS = " ".join((
    "You are a phone scam detection expert.",
    "Your job is to read a call transcript and decide if it is a scam.",
))

# User turn. The ``{transcript}`` placeholder is filled with str.format in
# classify(); the answer is constrained to a single label word.
USER_TEMPLATE = "\n\n".join((
    "Read this phone call transcript and classify it:",
    "{transcript}",
    "Answer with exactly ONE of these two words: SCAM or LEGITIMATE. Do not explain.",
))


def parse_args():
    """Parse the command-line options for a zero-shot evaluation run.

    Returns:
        argparse.Namespace with ``model``, ``dataset``, ``split``, ``limit``,
        ``device``, ``dtype`` and ``out`` attributes.
    """
    parser = argparse.ArgumentParser()
    opt = parser.add_argument
    opt("--model", default="google/gemma-4-E2B-it",
        help="HuggingFace model id (E2B text-only)")
    opt("--dataset", default="BothBosu/scam-dialogue",
        help="HF dataset with 'dialogue' and 'label' columns")
    opt("--split", default="test")
    # A negative limit is interpreted by main() as "evaluate the whole split".
    opt("--limit", type=int, default=100, help="Max rows to eval (-1 = all)")
    opt("--device", default="auto", help="cuda / cpu / auto")
    opt("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
    opt("--out", default="results_zero_shot.json")
    return parser.parse_args()


def load_model(model_id: str, device: str, dtype: str):
    """Load the Gemma 4 model and its processor, ready for inference.

    Args:
        model_id: Hugging Face model id (or local path) passed to
            ``from_pretrained``.
        device: ``"auto"`` lets ``device_map="auto"`` place the weights; any
            other value (e.g. ``"cuda"``, ``"cpu"``) loads without a device
            map and then moves the model with ``model.to(device)``.
        dtype: One of ``"bf16"``, ``"fp16"``, ``"fp32"``.

    Returns:
        ``(model, processor)`` with the model switched to eval mode.

    Raises:
        KeyError: if ``dtype`` is not one of the supported names.
    """
    torch_dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[dtype]

    print(f"Loading {model_id} (dtype={dtype}, device={device}) …")
    model = Gemma4ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        # With an explicit device, load without a device map and move below.
        device_map="auto" if device == "auto" else None,
    )
    if device != "auto":
        model = model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    model.eval()
    return model, processor


@torch.inference_mode()
def classify(model, processor, transcript: str) -> str:
    """Ask the model for a one-word verdict on ``transcript``.

    Builds a system + user chat from ``SYS`` and ``USER_TEMPLATE``, decodes
    greedily (``do_sample=False``) for at most 5 new tokens, and returns only
    the newly generated text, stripped and upper-cased. No label mapping is
    done here; the raw answer is returned for the caller to interpret.
    """
    user_text = USER_TEMPLATE.format(transcript=transcript)
    chat = [
        {"role": "system", "content": [{"type": "text", "text": SYS}]},
        {"role": "user", "content": [{"type": "text", "text": user_text}]},
    ]
    encoded = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    target = model.device
    encoded = {name: tensor.to(target) for name, tensor in encoded.items()}
    prompt_len = encoded["input_ids"].shape[-1]

    output = model.generate(
        **encoded,
        max_new_tokens=5,
        do_sample=False,
        pad_token_id=processor.tokenizer.pad_token_id,
    )
    # The generated sequence includes the prompt; keep only the continuation.
    continuation = output[:, prompt_len:]
    decoded = processor.batch_decode(continuation, skip_special_tokens=True)
    return decoded[0].strip().upper()


# Words that flip the label word following them ("NOT A SCAM", "NON-SCAM",
# "NOT LEGITIMATE").
_NEGATORS = frozenset({"NOT", "NO", "NON"})
# Bare answers that read as "not a scam" when no label word is present.
_LEGIT_HINTS = frozenset({"NO", "NOT", "SAFE", "NORMAL"})


def normalize(pred_raw: str) -> str:
    """Map a free-form model answer onto ``"SCAM"`` or ``"LEGITIMATE"``.

    Matching works on whole alphabetic words (case-insensitive) instead of
    substrings. As a result:

    * "NOT A SCAM" is read as LEGITIMATE.
    * "ILLEGITIMATE" is read as SCAM.
    * Words such as "CANNOT", "UNKNOWN" and "UNSAFE" no longer produce a
      spurious match.

    The first label word (starting with SCAM, LEGIT or ILLEGIT) decides. Its
    meaning is flipped when either of the two preceding words is a negator.
    If there is no label word, a bare NO / NOT / SAFE / NORMAL counts as
    LEGITIMATE.

    Args:
        pred_raw: Decoded model output.

    Returns:
        "SCAM", "LEGITIMATE", or ``pred_raw`` unchanged when the answer
        cannot be interpreted.
    """
    words = re.findall(r"[A-Z]+", pred_raw.upper())
    for i, word in enumerate(words):
        if word.startswith("SCAM") or word.startswith("ILLEGIT"):
            verdict = "SCAM"
        elif word.startswith("LEGIT"):
            verdict = "LEGITIMATE"
        else:
            continue
        # A negator just before the label word inverts it.
        if any(w in _NEGATORS for w in words[max(0, i - 2):i]):
            return "LEGITIMATE" if verdict == "SCAM" else "SCAM"
        return verdict
    if any(w in _LEGIT_HINTS for w in words):
        return "LEGITIMATE"
    return pred_raw


def gold_label(label_int: int) -> str:
    """Translate a dataset label into the string label space used here.

    Only a label equal to ``1`` is a scam; every other value maps to
    LEGITIMATE.
    NOTE(review): assumes the dataset encodes scams as 1 - confirm against
    the dataset card.
    """
    if label_int == 1:
        return "SCAM"
    return "LEGITIMATE"


def compute_metrics(items):
    """Compute accuracy plus precision / recall / F1 for the SCAM class.

    Args:
        items: Sized sequence of dicts with ``"pred"`` and ``"gold"`` keys.
            ``gold`` is "SCAM" or "LEGITIMATE". ``pred`` may be any other
            string when the model's answer could not be parsed.

    Returns:
        Dict with ``total``, ``accuracy``, ``precision_scam``,
        ``recall_scam``, ``f1_scam``, ``unparsed`` and a ``confusion`` dict
        of TP/FP/FN/TN counts. ``unparsed`` counts predictions that are
        neither label.

    Notes:
        Unparsed predictions count as wrong for accuracy. An unparsed answer
        on a SCAM transcript is a false negative, so recall is measured
        against every gold SCAM example. Every ratio is 0.0 when its
        denominator is zero, including for an empty ``items``.
    """
    labels = ("SCAM", "LEGITIMATE")
    tp = fp = fn = tn = unparsed = 0
    for it in items:
        pred, gold = it["pred"], it["gold"]
        if pred not in labels:
            unparsed += 1
        if gold == "SCAM":
            # Anything other than a SCAM prediction is a missed scam.
            if pred == "SCAM":
                tp += 1
            else:
                fn += 1
        elif gold == "LEGITIMATE":
            if pred == "SCAM":
                fp += 1
            elif pred == "LEGITIMATE":
                tn += 1

    total = len(items)

    def ratio(num, den):
        # 0.0 instead of ZeroDivisionError for empty / degenerate inputs.
        return num / den if den else 0.0

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    return {
        "total": total,
        "accuracy": ratio(tp + tn, total),
        "precision_scam": precision,
        "recall_scam": recall,
        "f1_scam": ratio(2 * precision * recall, precision + recall),
        "unparsed": unparsed,
        "confusion": {"TP": tp, "FP": fp, "FN": fn, "TN": tn},
    }


def _print_report(args, n, elapsed, metrics, items):
    """Print the human-readable metrics summary for one evaluation run."""
    conf = metrics["confusion"]
    # Predictions that normalize() could not map to either label.
    unparsed = sum(1 for it in items if it["pred"] not in ("SCAM", "LEGITIMATE"))
    print("\n" + "=" * 60)
    print("ZERO-SHOT EVALUATION REPORT")
    print("=" * 60)
    print(f"Model : {args.model}")
    print(f"Dataset : {args.dataset} / {args.split}")
    print(f"Samples : {n}")
    print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.1f} ex/s)")
    print(f"Accuracy : {metrics['accuracy']:.2%}")
    print(f"Precision : {metrics['precision_scam']:.2%} (SCAM class)")
    print(f"Recall : {metrics['recall_scam']:.2%} (SCAM class)")
    print(f"F1 (SCAM) : {metrics['f1_scam']:.2%}")
    print(f"Confusion : TP={conf['TP']} FP={conf['FP']} "
          f"FN={conf['FN']} TN={conf['TN']}")
    print(f"Unparsed : {unparsed}")
    print("=" * 60)


def _print_decision(metrics):
    """Print a PASS / MARGINAL / FAIL verdict from accuracy and SCAM-class F1."""
    print("\n" + "=" * 60)
    print("DECISION RUBRIC")
    print("=" * 60)
    acc = metrics["accuracy"]
    f1 = metrics["f1_scam"]
    if acc >= 0.90 and f1 >= 0.85:
        print("✅ PASS — Base model is strong enough. Fine-tuning is OPTIONAL.")
    elif acc >= 0.75 and f1 >= 0.70:
        print("⚠️ MARGINAL — Fine-tuning recommended for production use.")
        print(" Expected uplift from fine-tuning: +5–15 pp accuracy.")
    else:
        print("❌ FAIL — Base model is not reliable.")
        print(" Fine-tuning REQUIRED before shipping.")
        print(" (Run the Unsloth SFT script provided next.)")


def main():
    """Run the zero-shot evaluation end to end.

    Loads the model and the dataset split and classifies up to ``--limit``
    rows, printing one line per example. It then prints a metrics report and
    the fine-tuning decision rubric, and writes args, metrics and
    per-example items to ``--out`` as UTF-8 JSON.
    """
    args = parse_args()
    model, processor = load_model(args.model, args.device, args.dtype)
    ds = load_dataset(args.dataset, split=args.split)
    # A negative --limit means "use the whole split".
    n = len(ds) if args.limit < 0 else min(args.limit, len(ds))

    items = []
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    for i in range(n):
        row = ds[i]
        dialogue = row["dialogue"]
        gold = gold_label(row["label"])
        pred_raw = classify(model, processor, dialogue)
        pred = normalize(pred_raw)
        correct = pred == gold
        items.append({
            "index": i,
            # Keep the JSON compact: store at most 500 chars per transcript.
            "dialogue": (dialogue[:500] + "…") if len(dialogue) > 500 else dialogue,
            "gold": gold,
            "pred_raw": pred_raw,
            "pred": pred,
            "correct": correct,
        })
        mark = "✓" if correct else "✗"
        print(f"[{i+1:3}/{n}] gold={gold:11} pred='{pred_raw:15}' → {pred:11} {mark}")

    elapsed = time.perf_counter() - t0
    metrics = compute_metrics(items)
    metrics["time_sec"] = elapsed
    # Guard the division: elapsed can be 0.0 when nothing was evaluated.
    metrics["throughput"] = n / elapsed if elapsed > 0 else 0.0

    _print_report(args, n, elapsed, metrics, items)

    out = {"args": vars(args), "metrics": metrics, "items": items}
    Path(args.out).write_text(
        json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"\nSaved full results → {args.out}")

    _print_decision(metrics)


# Script entry point: the evaluation runs only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()