# grandgemma-eval / format_dataset.py
# Uploaded by s23deepak (commit 0f23a85, verified)
#!/usr/bin/env python3
"""
Format BothBosu scam-dialogue (and optionally Scammer-Conversation)
into standardized chat-template JSONL for SFT.
REQUIREMENTS:
pip install datasets transformers
USAGE:
python format_dataset.py --out_dir ./formatted_scam_data
OUTPUT:
formatted_scam_data/
train.jsonl (chat-format messages per line)
test.jsonl
README.md (dataset card fragment)
Each JSONL line:
{"messages": [
{"role": "system", "content": "You are a phone scam detection expert."},
{"role": "user", "content": "Read this transcript...\n\n{transcript}"},
{"role": "assistant", "content": "SCAM"}
]}
"""
import argparse
import json
from pathlib import Path
from datasets import load_dataset, concatenate_datasets
# User-turn prompt; {transcript} is filled in per row by row_to_chat().
PROMPT_TEMPLATE = (
    "Read this phone call transcript and classify it:\n\n"
    "{transcript}\n\n"
    "Answer with exactly ONE word: SCAM or LEGITIMATE."
)
# System message prepended to every example.
SYSTEM = "You are a phone scam detection expert."
def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            falls back to ``sys.argv[1:]``, so existing callers are
            unaffected; passing an explicit list makes the parser testable.

    Returns:
        argparse.Namespace with ``primary``, ``secondary`` and ``out_dir``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--primary", default="BothBosu/scam-dialogue")
    p.add_argument("--secondary", default="BothBosu/Scammer-Conversation",
                   help="Optional extra dataset to merge into train")
    p.add_argument("--out_dir", default="./formatted_scam_data")
    return p.parse_args(argv)
def row_to_chat(row):
    """Convert a raw dataset row into a ChatML messages dict.

    Args:
        row: Mapping with a 0/1 ``label`` and the transcript under either
            ``dialogue`` or ``conversation`` (column names differ across
            the source datasets).

    Returns:
        ``{"messages": [system, user, assistant]}`` where the assistant
        turn is exactly ``"SCAM"`` (label == 1) or ``"LEGITIMATE"``.
    """
    answer = "SCAM" if row["label"] == 1 else "LEGITIMATE"
    # Handle different column names across datasets. Explicit None checks
    # instead of `a or b`: an empty-string "dialogue" is a present (if
    # empty) transcript and must not be silently replaced by the
    # "conversation" column.
    transcript = row.get("dialogue")
    if transcript is None:
        transcript = row.get("conversation")
    return {
        "messages": [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": PROMPT_TEMPLATE.format(transcript=transcript)},
            {"role": "assistant", "content": answer},
        ]
    }
def save_jsonl(rows, path: Path):
    """Write *rows* to *path* in JSON Lines format (one object per line).

    Creates missing parent directories and overwrites any existing file,
    then prints a one-line summary.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in rows
        )
    print(f"Saved {len(rows)} rows → {path}")
def main():
    """Entry point: load, merge, convert, and save the datasets.

    Steps:
      1. Load the train/test splits of the primary dataset.
      2. Optionally merge a secondary dataset into train (best-effort:
         failures are reported and skipped, not fatal).
      3. Convert every row to ChatML and write JSONL files.
      4. Write label statistics and a README fragment next to the data.
    """
    args = parse_args()
    out_dir = Path(args.out_dir)

    # Load primary (downloads from the Hub on first run).
    print(f"Loading primary dataset: {args.primary}")
    ds_train = load_dataset(args.primary, split="train")
    ds_test = load_dataset(args.primary, split="test")

    # Optional secondary merge — deliberately best-effort: a missing or
    # incompatible secondary dataset must not abort the whole run.
    if args.secondary:
        try:
            ds_extra = load_dataset(args.secondary, split="train")
            n_before = len(ds_train)
            ds_train = concatenate_datasets([ds_train, ds_extra])
            # BUG FIX: the message previously rendered the two counts with
            # no separator (e.g. "100200 train rows"); restore the arrow.
            print(f"Merged {args.secondary}: {n_before} → {len(ds_train)} train rows")
        except Exception as e:
            print(f"Skipped secondary dataset: {e}")

    # Convert every raw row into the ChatML messages format.
    train_rows = [row_to_chat(r) for r in ds_train]
    test_rows = [row_to_chat(r) for r in ds_test]

    # Save
    save_jsonl(train_rows, out_dir / "train.jsonl")
    save_jsonl(test_rows, out_dir / "test.jsonl")

    # Stats (messages[2] is the assistant turn holding the gold label).
    n_scam_train = sum(1 for r in train_rows if r["messages"][2]["content"] == "SCAM")
    n_scam_test = sum(1 for r in test_rows if r["messages"][2]["content"] == "SCAM")
    stats = {
        "train": {"total": len(train_rows), "scam": n_scam_train, "legit": len(train_rows) - n_scam_train},
        "test": {"total": len(test_rows), "scam": n_scam_test, "legit": len(test_rows) - n_scam_test},
    }
    # Pin UTF-8 explicitly: Path.write_text otherwise uses the locale
    # encoding, which is platform-dependent (and inconsistent with save_jsonl).
    (out_dir / "stats.json").write_text(json.dumps(stats, indent=2), encoding="utf-8")
    print(f"\nStats:\n{json.dumps(stats, indent=2)}")

    # README fragment documenting provenance, statistics, and schema.
    readme = f"""# Formatted Scam-Call Dataset (ChatML)
Generated by `format_dataset.py`.
## Sources
- Primary: {args.primary}
- Secondary: {args.secondary or "None"}
## Statistics
```json
{json.dumps(stats, indent=2)}
```
## Schema
Each `.jsonl` line is a ChatML message list compatible with TRL / Unsloth SFTTrainer.
"""
    (out_dir / "README.md").write_text(readme, encoding="utf-8")
    print(f"\nDone. Output directory: {out_dir.absolute()}")


if __name__ == "__main__":
    main()