#!/usr/bin/env python3 """ Zero-shot evaluation: Gemma 4 E2B (2B) on scam call transcripts. Tests base-model capability before deciding if fine-tuning is needed. REQUIREMENTS: pip install transformers datasets torch huggingface_hub USAGE: python eval_zero_shot.py --model google/gemma-4-E2B-it \ --dataset BothBosu/scam-dialogue \ --split test \ --limit 100 # Or full eval (test split is ~400 rows): python eval_zero_shot.py --limit -1 OUTPUT: - results_zero_shot.json (per-example predictions + overall metrics) - Console report with accuracy, confusion matrix, per-class F1 """ import argparse import json import time from pathlib import Path import torch from datasets import load_dataset from transformers import AutoProcessor, Gemma4ForConditionalGeneration # ── Prompt engineering ────────────────────────────────────────────── SYS = ( "You are a phone scam detection expert. " "Your job is to read a call transcript and decide if it is a scam." ) USER_TEMPLATE = ( "Read this phone call transcript and classify it:\n\n" "{transcript}\n\n" "Answer with exactly ONE of these two words: SCAM or LEGITIMATE. " "Do not explain." ) def parse_args(): p = argparse.ArgumentParser() p.add_argument("--model", default="google/gemma-4-E2B-it", help="HuggingFace model id (E2B text-only)") p.add_argument("--dataset", default="BothBosu/scam-dialogue", help="HF dataset with 'dialogue' and 'label' columns") p.add_argument("--split", default="test") p.add_argument("--limit", type=int, default=100, help="Max rows to eval (-1 = all)") p.add_argument("--device", default="auto", help="cuda / cpu / auto") p.add_argument("--dtype", default="bf16", choices=["bf16","fp16","fp32"]) p.add_argument("--out", default="results_zero_shot.json") return p.parse_args() def load_model(model_id: str, device: str, dtype: str): torch_dtype = { "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32, }[dtype] print(f"Loading {model_id} (dtype={dtype}, device={device}) …") model = Gemma4ForConditionalGeneration.from_pretrained( model_id, torch_dtype=torch_dtype, device_map="auto" if device == "auto" else None, ) if device != "auto": model = model.to(device) processor = AutoProcessor.from_pretrained(model_id) model.eval() return model, processor @torch.inference_mode() def classify(model, processor, transcript: str) -> str: messages = [ {"role": "system", "content": [{"type": "text", "text": SYS}]}, {"role": "user", "content": [{"type": "text", "text": USER_TEMPLATE.format(transcript=transcript)}]}, ] inputs = processor.apply_chat_template( messages, tokenize=True, return_dict=True, return_tensors="pt", add_generation_prompt=True, ) inputs = {k: v.to(model.device) for k, v in inputs.items()} gen_ids = model.generate( **inputs, max_new_tokens=5, do_sample=False, pad_token_id=processor.tokenizer.pad_token_id, ) # slice off prompt tokens new_ids = gen_ids[:, inputs["input_ids"].shape[-1]:] text = processor.batch_decode(new_ids, skip_special_tokens=True)[0] return text.strip().upper() def normalize(pred_raw: str) -> str: if "SCAM" in pred_raw: return "SCAM" if any(w in pred_raw for w in ["LEGIT", "NOT", "SAFE", "NO", "NORMAL"]): return "LEGITIMATE" return pred_raw # unknown → will count as wrong def gold_label(label_int: int) -> str: return "SCAM" if label_int == 1 else "LEGITIMATE" def compute_metrics(items): total = len(items) tp = sum(1 for it in items if it["pred"] == "SCAM" and it["gold"] == "SCAM") fp = sum(1 for it in items if it["pred"] == "SCAM" and it["gold"] == "LEGITIMATE") fn = sum(1 for it in items if it["pred"] == "LEGITIMATE" and it["gold"] == "SCAM") tn = sum(1 for it in items if it["pred"] == "LEGITIMATE" and it["gold"] == "LEGITIMATE") accuracy = (tp + tn) / total precision = tp / (tp + fp) if (tp + fp) else 0 recall = tp / (tp + fn) if (tp + fn) else 0 f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0 return { "total": total, "accuracy": accuracy, "precision_scam": precision, "recall_scam": recall, "f1_scam": f1, "confusion": {"TP": tp, "FP": fp, "FN": fn, "TN": tn}, } def main(): args = parse_args() model, processor = load_model(args.model, args.device, args.dtype) ds = load_dataset(args.dataset, split=args.split) n = len(ds) if args.limit < 0 else min(args.limit, len(ds)) items = [] t0 = time.time() for i in range(n): row = ds[i] gold = gold_label(row["label"]) pred_raw = classify(model, processor, row["dialogue"]) pred = normalize(pred_raw) correct = pred == gold items.append({ "index": i, "dialogue": row["dialogue"][:500] + "…" if len(row["dialogue"]) > 500 else row["dialogue"], "gold": gold, "pred_raw": pred_raw, "pred": pred, "correct": correct, }) mark = "✓" if correct else "✗" print(f"[{i+1:3}/{n}] gold={gold:11} pred='{pred_raw:15}' → {pred:11} {mark}") elapsed = time.time() - t0 metrics = compute_metrics(items) metrics["time_sec"] = elapsed metrics["throughput"] = n / elapsed print("\n" + "=" * 60) print("ZERO-SHOT EVALUATION REPORT") print("=" * 60) print(f"Model : {args.model}") print(f"Dataset : {args.dataset} / {args.split}") print(f"Samples : {n}") print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.1f} ex/s)") print(f"Accuracy : {metrics['accuracy']:.2%}") print(f"Precision : {metrics['precision_scam']:.2%} (SCAM class)") print(f"Recall : {metrics['recall_scam']:.2%} (SCAM class)") print(f"F1 (SCAM) : {metrics['f1_scam']:.2%}") print(f"Confusion : TP={metrics['confusion']['TP']} FP={metrics['confusion']['FP']} " f"FN={metrics['confusion']['FN']} TN={metrics['confusion']['TN']}") print("=" * 60) # Save out = {"args": vars(args), "metrics": metrics, "items": items} Path(args.out).write_text(json.dumps(out, indent=2)) print(f"\nSaved full results → {args.out}") # ── Rubric / decision ───────────────────────────────────────── print("\n" + "=" * 60) print("DECISION RUBRIC") print("=" * 60) acc = metrics["accuracy"] f1 = metrics["f1_scam"] if acc >= 0.90 and f1 >= 0.85: print("✅ PASS — Base model is strong enough. Fine-tuning is OPTIONAL.") elif acc >= 0.75 and f1 >= 0.70: print("⚠️ MARGINAL — Fine-tuning recommended for production use.") print(" Expected uplift from fine-tuning: +5–15 pp accuracy.") else: print("❌ FAIL — Base model is not reliable.") print(" Fine-tuning REQUIRED before shipping.") print(" (Run the Unsloth SFT script provided next.)") if __name__ == "__main__": main()