# Source: grandgemma-eval / eval_zero_shot.py (uploaded by s23deepak, commit 7fb4235)
#!/usr/bin/env python3
"""
Zero-shot evaluation: Gemma 4 E2B (2B) on scam call transcripts.
Tests base-model capability before deciding if fine-tuning is needed.
REQUIREMENTS:
pip install transformers datasets torch huggingface_hub
USAGE:
python eval_zero_shot.py --model google/gemma-4-E2B-it \
--dataset BothBosu/scam-dialogue \
--split test \
--limit 100
# Or full eval (test split is ~400 rows):
python eval_zero_shot.py --limit -1
OUTPUT:
- results_zero_shot.json (per-example predictions + overall metrics)
- Console report with accuracy, confusion matrix, and precision / recall / F1 for the SCAM class
"""
import argparse
import json
import time
from pathlib import Path
import torch
from datasets import load_dataset
from transformers import AutoProcessor, Gemma4ForConditionalGeneration
# ── Prompt engineering ──────────────────────────────────────────────
# System turn: frames the model as a scam-detection specialist.
SYS = " ".join((
    "You are a phone scam detection expert.",
    "Your job is to read a call transcript and decide if it is a scam.",
))

# User turn. The `{transcript}` placeholder is filled via str.format() in
# classify(); the closing instruction constrains the answer to one word so
# that a handful of generated tokens is enough.
USER_TEMPLATE = (
    "Read this phone call transcript and classify it:"
    "\n\n{transcript}\n\n"
    "Answer with exactly ONE of these two words: SCAM or LEGITIMATE. Do not explain."
)
def parse_args():
p = argparse.ArgumentParser()
p.add_argument("--model", default="google/gemma-4-E2B-it",
help="HuggingFace model id (E2B text-only)")
p.add_argument("--dataset", default="BothBosu/scam-dialogue",
help="HF dataset with 'dialogue' and 'label' columns")
p.add_argument("--split", default="test")
p.add_argument("--limit", type=int, default=100,
help="Max rows to eval (-1 = all)")
p.add_argument("--device", default="auto", help="cuda / cpu / auto")
p.add_argument("--dtype", default="bf16", choices=["bf16","fp16","fp32"])
p.add_argument("--out", default="results_zero_shot.json")
return p.parse_args()
def load_model(model_id: str, device: str, dtype: str):
    """Load the model and its processor for inference.

    Args:
        model_id: HuggingFace model id passed to ``from_pretrained``.
        device: ``"auto"`` lets ``from_pretrained`` place weights via
            ``device_map="auto"``; any other value (e.g. ``"cuda"``, ``"cpu"``)
            loads without a device map and then moves the model with ``.to``.
        dtype: One of ``"bf16"``, ``"fp16"``, ``"fp32"`` (other values raise
            ``KeyError``; the CLI restricts the choices).

    Returns:
        ``(model, processor)``, with the model switched to eval mode.
    """
    dtype_by_name = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
    resolved_dtype = dtype_by_name[dtype]
    print(f"Loading {model_id} (dtype={dtype}, device={device}) …")

    auto_place = device == "auto"
    model = Gemma4ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=resolved_dtype,
        device_map="auto" if auto_place else None,
    )
    if not auto_place:
        # No device map was used, so placement is done explicitly here.
        model = model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)
    model.eval()
    return model, processor
@torch.inference_mode()
def classify(model, processor, transcript: str) -> str:
    """Ask the model for a one-word verdict on a single transcript.

    Builds a system + user chat from ``SYS`` and ``USER_TEMPLATE``, runs
    greedy decoding for at most 5 new tokens, and decodes only the tokens
    generated after the prompt.

    Returns:
        The generated text, stripped of surrounding whitespace and upper-cased
        (e.g. ``"SCAM"``); not yet normalised to a label.
    """
    prompt = USER_TEMPLATE.format(transcript=transcript)
    conversation = [
        {"role": "system", "content": [{"type": "text", "text": SYS}]},
        {"role": "user", "content": [{"type": "text", "text": prompt}]},
    ]
    encoded = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    target = model.device
    encoded = {name: tensor.to(target) for name, tensor in encoded.items()}
    prompt_len = encoded["input_ids"].shape[-1]

    output = model.generate(
        **encoded,
        max_new_tokens=5,
        do_sample=False,  # greedy: deterministic verdicts
        pad_token_id=processor.tokenizer.pad_token_id,
    )
    # generate() returns prompt + completion; keep only the completion.
    completion_ids = output[:, prompt_len:]
    decoded = processor.batch_decode(completion_ids, skip_special_tokens=True)
    return decoded[0].strip().upper()
def normalize(pred_raw: str) -> str:
    """Map the model's raw answer onto ``"SCAM"`` or ``"LEGITIMATE"``.

    The text is split into alphabetic words, so markdown such as ``**SCAM**``
    and punctuation are ignored. Matching is done per word rather than by
    substring, so "UNSAFE" no longer matches "SAFE" and "NOTHING" /
    "UNKNOWN" no longer match "NO".

    Rules:
      * a word containing "SCAM" gives ``"SCAM"``, unless a negation word
        ("NOT", "NO", "NON", "ISN", "ISNT") is also present, which gives
        ``"LEGITIMATE"`` (e.g. "NOT A SCAM");
      * otherwise a word starting with "LEGIT", or "SAFE", "NORMAL", "NOT"
        or "NO", gives ``"LEGITIMATE"``;
      * anything else is returned unchanged. It then never equals a gold
        label and is scored as wrong.

    Args:
        pred_raw: Decoded model output. ``classify`` already upper-cases it,
            and it is upper-cased here as well before matching.

    Returns:
        ``"SCAM"``, ``"LEGITIMATE"``, or *pred_raw* itself when unrecognised.
    """
    negations = {"NOT", "NO", "NON", "ISN", "ISNT"}  # "ISN'T" splits into ISN + T
    legit_words = {"SAFE", "NORMAL"}
    words = "".join(ch if ch.isalpha() else " " for ch in pred_raw.upper()).split()
    negated = any(w in negations for w in words)

    if any("SCAM" in w for w in words):
        return "LEGITIMATE" if negated else "SCAM"
    if negated or any(w.startswith("LEGIT") or w in legit_words for w in words):
        return "LEGITIMATE"
    return pred_raw  # unrecognised → will count as wrong
def gold_label(label_int: int) -> str:
    """Translate the dataset's integer label into a class name.

    ``1`` maps to ``"SCAM"``; every other value maps to ``"LEGITIMATE"``.
    """
    if label_int != 1:
        return "LEGITIMATE"
    return "SCAM"
def compute_metrics(items):
    """Compute accuracy and SCAM-class precision / recall / F1.

    Args:
        items: Sequence of dicts with at least ``"pred"`` and ``"gold"`` keys.
            ``gold`` is ``"SCAM"`` or ``"LEGITIMATE"``. ``pred`` is one of
            those, or an unparsed raw string.

    Returns:
        Dict with ``total``, ``accuracy``, ``precision_scam``, ``recall_scam``,
        ``f1_scam``, ``confusion`` (TP/FP/FN/TN for the SCAM class) and
        ``unparsed`` (number of predictions that are neither label).

    Notes:
        * Every gold-SCAM example not predicted as SCAM counts as a false
          negative, including unparsed outputs. Previously these were dropped
          from FN, which inflated recall.
        * Unparsed predictions on gold-LEGITIMATE rows count as neither FP
          nor TN, but they are still wrong for accuracy.
        * Undefined ratios (e.g. empty input, no SCAM predictions) are 0.0.
    """
    total = len(items)
    gold_scam = [it for it in items if it["gold"] == "SCAM"]
    gold_legit = [it for it in items if it["gold"] == "LEGITIMATE"]

    tp = sum(it["pred"] == "SCAM" for it in gold_scam)
    fn = len(gold_scam) - tp  # every missed scam, parsed or not
    fp = sum(it["pred"] == "SCAM" for it in gold_legit)
    tn = sum(it["pred"] == "LEGITIMATE" for it in gold_legit)
    unparsed = sum(it["pred"] not in ("SCAM", "LEGITIMATE") for it in items)

    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {
        "total": total,
        "accuracy": accuracy,
        "precision_scam": precision,
        "recall_scam": recall,
        "f1_scam": f1,
        "confusion": {"TP": tp, "FP": fp, "FN": fn, "TN": tn},
        "unparsed": unparsed,
    }
def _run_eval(model, processor, ds, n):
    """Classify the first *n* rows of *ds*, printing one progress line each.

    Returns:
        List of per-example dicts (index, truncated dialogue, gold, raw and
        normalised prediction, correctness flag).
    """
    items = []
    for i in range(n):
        row = ds[i]
        dialogue = row["dialogue"]
        gold = gold_label(row["label"])
        pred_raw = classify(model, processor, dialogue)
        pred = normalize(pred_raw)
        correct = pred == gold
        items.append({
            "index": i,
            # Truncate long transcripts so the saved JSON stays readable.
            "dialogue": dialogue[:500] + "…" if len(dialogue) > 500 else dialogue,
            "gold": gold,
            "pred_raw": pred_raw,
            "pred": pred,
            "correct": correct,
        })
        mark = "✓" if correct else "✗"
        print(f"[{i+1:3}/{n}] gold={gold:11} pred='{pred_raw:15}' → {pred:11} {mark}")
    return items


def _print_report(args, metrics, n, elapsed, unparsed):
    """Print the human-readable metrics summary to the console."""
    conf = metrics["confusion"]
    print("\n" + "=" * 60)
    print("ZERO-SHOT EVALUATION REPORT")
    print("=" * 60)
    print(f"Model : {args.model}")
    print(f"Dataset : {args.dataset} / {args.split}")
    print(f"Samples : {n}")
    print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.1f} ex/s)")
    print(f"Accuracy : {metrics['accuracy']:.2%}")
    print(f"Precision : {metrics['precision_scam']:.2%} (SCAM class)")
    print(f"Recall : {metrics['recall_scam']:.2%} (SCAM class)")
    print(f"F1 (SCAM) : {metrics['f1_scam']:.2%}")
    print(f"Confusion : TP={conf['TP']} FP={conf['FP']} "
          f"FN={conf['FN']} TN={conf['TN']}")
    # Outputs that normalize() could not map to a label; scored as wrong.
    print(f"Unparsed : {unparsed}")
    print("=" * 60)


def _print_rubric(metrics):
    """Print the pass / marginal / fail decision based on accuracy and SCAM F1."""
    print("\n" + "=" * 60)
    print("DECISION RUBRIC")
    print("=" * 60)
    acc = metrics["accuracy"]
    f1 = metrics["f1_scam"]
    if acc >= 0.90 and f1 >= 0.85:
        print("✅ PASS — Base model is strong enough. Fine-tuning is OPTIONAL.")
    elif acc >= 0.75 and f1 >= 0.70:
        print("⚠️ MARGINAL — Fine-tuning recommended for production use.")
        print(" Expected uplift from fine-tuning: +5–15 pp accuracy.")
    else:
        print("❌ FAIL — Base model is not reliable.")
        print(" Fine-tuning REQUIRED before shipping.")
        print(" (Run the Unsloth SFT script provided next.)")


def main():
    """Run the zero-shot evaluation.

    Steps: parse CLI args, load the dataset split and the model, classify each
    row, print the report, save JSON results to ``--out``, and print the
    decision rubric.

    Raises:
        SystemExit: if there are no rows to evaluate (``--limit 0`` or an
            empty split). Without this check the metrics and throughput
            computations would divide by zero.
    """
    args = parse_args()
    # Load the data before the model so an empty selection fails fast.
    ds = load_dataset(args.dataset, split=args.split)
    n = len(ds) if args.limit < 0 else min(args.limit, len(ds))
    if n == 0:
        raise SystemExit(
            f"No rows to evaluate (limit={args.limit}, split size={len(ds)})."
        )
    model, processor = load_model(args.model, args.device, args.dtype)

    t0 = time.perf_counter()
    items = _run_eval(model, processor, ds, n)
    elapsed = time.perf_counter() - t0

    metrics = compute_metrics(items)
    metrics["time_sec"] = elapsed
    metrics["throughput"] = n / elapsed if elapsed > 0 else 0.0
    unparsed = sum(it["pred"] not in ("SCAM", "LEGITIMATE") for it in items)
    _print_report(args, metrics, n, elapsed, unparsed)

    # Save. Write UTF-8 explicitly so transcripts are stored as readable text.
    out = {"args": vars(args), "metrics": metrics, "items": items}
    Path(args.out).write_text(
        json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"\nSaved full results → {args.out}")

    _print_rubric(metrics)


if __name__ == "__main__":
    main()