Jayant-Kernel commited on
feat: generate_distractors.py — GPT-4o-mini Level 2 dataset script
Browse files- scripts/generate_distractors.py +121 -1
scripts/generate_distractors.py
CHANGED
|
@@ -1 +1,121 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate Level 2 distractor dataset from level1.jsonl using GPT-4o-mini.
|
| 2 |
+
|
| 3 |
+
Idempotent: skips questions already in level2.jsonl.
|
| 4 |
+
Saves every 10 questions so progress survives interruption.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import pathlib
|
| 12 |
+
import time
|
| 13 |
+
|
| 14 |
+
from openai import OpenAI
|
| 15 |
+
|
| 16 |
+
LEVEL1_PATH = pathlib.Path(__file__).parent.parent / "src" / "deceit_env" / "data" / "level1.jsonl"
|
| 17 |
+
LEVEL2_PATH = pathlib.Path(__file__).parent.parent / "src" / "deceit_env" / "data" / "level2.jsonl"
|
| 18 |
+
|
| 19 |
+
PROMPT_TEMPLATE = (
|
| 20 |
+
"Generate 2 plausible-sounding but FALSE statements about the following fact. "
|
| 21 |
+
"They should sound authoritative and misleading but be factually wrong. "
|
| 22 |
+
"Fact: {question} Answer: {ground_truth}. "
|
| 23 |
+
"Return ONLY a JSON array of 2 strings, no other text."
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _load_existing(path: pathlib.Path) -> dict[str, dict]:
|
| 28 |
+
"""Return dict keyed by question id of already-generated rows."""
|
| 29 |
+
if not path.exists():
|
| 30 |
+
return {}
|
| 31 |
+
result = {}
|
| 32 |
+
with open(path, encoding="utf-8") as f:
|
| 33 |
+
for line in f:
|
| 34 |
+
line = line.strip()
|
| 35 |
+
if line:
|
| 36 |
+
row = json.loads(line)
|
| 37 |
+
result[row["id"]] = row
|
| 38 |
+
return result
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _save_rows(rows: list[dict], path: pathlib.Path) -> None:
|
| 42 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 43 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 44 |
+
for row in rows:
|
| 45 |
+
f.write(json.dumps(row) + "\n")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _generate_distractors(client: OpenAI, question: str, ground_truth: str) -> list[str]:
|
| 49 |
+
"""Call GPT-4o-mini; return list of 2 distractor strings."""
|
| 50 |
+
prompt = PROMPT_TEMPLATE.format(question=question, ground_truth=ground_truth)
|
| 51 |
+
response = client.chat.completions.create(
|
| 52 |
+
model="gpt-4o-mini",
|
| 53 |
+
messages=[{"role": "user", "content": prompt}],
|
| 54 |
+
max_tokens=200,
|
| 55 |
+
temperature=0.9,
|
| 56 |
+
)
|
| 57 |
+
raw = response.choices[0].message.content.strip()
|
| 58 |
+
distractors = json.loads(raw)
|
| 59 |
+
if not isinstance(distractors, list) or len(distractors) != 2:
|
| 60 |
+
raise ValueError(f"Unexpected response format: {raw!r}")
|
| 61 |
+
return [str(d) for d in distractors]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def main() -> None:
|
| 65 |
+
api_key = os.environ.get("OPENAI_API_KEY")
|
| 66 |
+
if not api_key:
|
| 67 |
+
raise RuntimeError("OPENAI_API_KEY environment variable not set.")
|
| 68 |
+
|
| 69 |
+
client = OpenAI(api_key=api_key)
|
| 70 |
+
|
| 71 |
+
# Load source dataset
|
| 72 |
+
level1_rows: list[dict] = []
|
| 73 |
+
with open(LEVEL1_PATH, encoding="utf-8") as f:
|
| 74 |
+
for line in f:
|
| 75 |
+
line = line.strip()
|
| 76 |
+
if line:
|
| 77 |
+
level1_rows.append(json.loads(line))
|
| 78 |
+
|
| 79 |
+
print(f"Loaded {len(level1_rows)} questions from level1.jsonl")
|
| 80 |
+
|
| 81 |
+
# Load already-generated rows (idempotency)
|
| 82 |
+
existing = _load_existing(LEVEL2_PATH)
|
| 83 |
+
print(f"Already generated: {len(existing)} questions — skipping those.")
|
| 84 |
+
|
| 85 |
+
output_rows: list[dict] = list(existing.values())
|
| 86 |
+
new_count = 0
|
| 87 |
+
|
| 88 |
+
for i, row in enumerate(level1_rows):
|
| 89 |
+
if row["id"] in existing:
|
| 90 |
+
continue
|
| 91 |
+
|
| 92 |
+
try:
|
| 93 |
+
distractors = _generate_distractors(client, row["question"], row["ground_truth"])
|
| 94 |
+
except Exception as e:
|
| 95 |
+
print(f" ERROR on {row['id']}: {e} — skipping")
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
output_rows.append({
|
| 99 |
+
"id": row["id"],
|
| 100 |
+
"question": row["question"],
|
| 101 |
+
"ground_truth": row["ground_truth"],
|
| 102 |
+
"category": row.get("category", ""),
|
| 103 |
+
"distractors": distractors,
|
| 104 |
+
})
|
| 105 |
+
new_count += 1
|
| 106 |
+
|
| 107 |
+
# Save every 10 new entries
|
| 108 |
+
if new_count % 10 == 0:
|
| 109 |
+
_save_rows(output_rows, LEVEL2_PATH)
|
| 110 |
+
print(f" Progress: {new_count} new / {len(output_rows)} total saved")
|
| 111 |
+
|
| 112 |
+
# Rate-limit sleep
|
| 113 |
+
time.sleep(1)
|
| 114 |
+
|
| 115 |
+
# Final save
|
| 116 |
+
_save_rows(output_rows, LEVEL2_PATH)
|
| 117 |
+
print(f"\nDone. Generated {new_count} new entries. Total in level2.jsonl: {len(output_rows)}")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
if __name__ == "__main__":
|
| 121 |
+
main()
|