Jayant-Kernel committed on
Commit
1c6af55
·
unverified ·
1 Parent(s): c0e1de3

feat: generate_distractors.py — GPT-4o-mini Level 2 dataset script

Browse files
Files changed (1) hide show
  1. scripts/generate_distractors.py +121 -1
scripts/generate_distractors.py CHANGED
@@ -1 +1,121 @@
1
- # TODO: Phase 4 — generate distractor contexts for Level 2 dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate Level 2 distractor dataset from level1.jsonl using GPT-4o-mini.
2
+
3
+ Idempotent: skips questions already in level2.jsonl.
4
+ Saves every 10 questions so progress survives interruption.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import os
11
+ import pathlib
12
+ import time
13
+
14
+ from openai import OpenAI
15
+
16
# Resolve __file__ before walking up: with a bare relative __file__ (script run
# from its own directory on interpreters older than 3.9) `.parent.parent` would
# collapse to "." and point at the wrong directory.
_DATA_DIR = pathlib.Path(__file__).resolve().parent.parent / "src" / "deceit_env" / "data"
LEVEL1_PATH = _DATA_DIR / "level1.jsonl"  # input: one JSON object per line
LEVEL2_PATH = _DATA_DIR / "level2.jsonl"  # output: level-1 rows plus "distractors"

# Prompt sent once per question; {question} and {ground_truth} are filled in
# with str.format.
PROMPT_TEMPLATE = (
    "Generate 2 plausible-sounding but FALSE statements about the following fact. "
    "They should sound authoritative and misleading but be factually wrong. "
    "Fact: {question} Answer: {ground_truth}. "
    "Return ONLY a JSON array of 2 strings, no other text."
)
25
+
26
+
27
def _load_existing(path: pathlib.Path) -> dict[str, dict]:
    """Return the rows already stored in *path*, keyed by each row's ``"id"``.

    A missing file yields an empty dict and blank lines are ignored. A line
    that is not valid JSON, or that does not carry a usable ``"id"`` (for
    example a final line truncated by an interrupted write), is reported and
    skipped instead of aborting the whole run.
    """
    if not path.exists():
        return {}
    result: dict[str, dict] = {}
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
                # TypeError covers non-object JSON values and unhashable ids.
                result[row["id"]] = row
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                print(f"  WARNING: ignoring malformed line {lineno} in {path.name}: {e!r}")
    return result
39
+
40
+
41
def _save_rows(rows: list[dict], path: pathlib.Path) -> None:
    """Write *rows* to *path* as JSON Lines, replacing any previous content.

    The parent directory is created if needed. Rows are written to a sibling
    ``<name>.tmp`` file first and then moved over *path* with ``os.replace``,
    so an interruption in the middle of a write leaves the previous file in
    place rather than a truncated one.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(row) + "\n" for row in rows)
    os.replace(tmp_path, path)
46
+
47
+
48
def _generate_distractors(client: OpenAI, question: str, ground_truth: str) -> list[str]:
    """Ask GPT-4o-mini for two false statements about a question/answer pair.

    Args:
        client: OpenAI client used for the chat-completion request.
        question: Question text inserted into ``PROMPT_TEMPLATE``.
        ground_truth: Correct answer inserted into ``PROMPT_TEMPLATE``.

    Returns:
        A list of exactly two distractor strings.

    Raises:
        ValueError: If the reply is empty, is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass), or is not
            a JSON array of exactly two strings.
    """
    prompt = PROMPT_TEMPLATE.format(question=question, ground_truth=ground_truth)
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=200,
        temperature=0.9,
    )
    # The message content may be None (e.g. a refusal); treat it as empty
    # instead of crashing on .strip().
    raw = (response.choices[0].message.content or "").strip()
    if not raw:
        raise ValueError("Unexpected response format: empty response")

    payload = raw
    if payload.startswith("```"):
        # Models sometimes wrap the array in a Markdown fence (```json ... ```)
        # despite the prompt; peel the fence and optional language tag off.
        payload = payload.strip("`").strip()
        if payload[:4].lower() == "json":
            payload = payload[4:].strip()

    distractors = json.loads(payload)
    if (
        not isinstance(distractors, list)
        or len(distractors) != 2
        or not all(isinstance(d, str) for d in distractors)
    ):
        raise ValueError(f"Unexpected response format: {raw!r}")
    return list(distractors)
62
+
63
+
64
def main() -> None:
    """Generate distractors for every level-1 question not yet in level2.jsonl.

    Reads ``LEVEL1_PATH``, skips ids already present in ``LEVEL2_PATH``, calls
    the model once per remaining question and rewrites ``LEVEL2_PATH`` after
    every 10 new rows and once more on exit. The exit save also runs when the
    loop is interrupted (for example by Ctrl-C), so rows generated since the
    last checkpoint are not lost.

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is unset or empty.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY environment variable not set.")

    client = OpenAI(api_key=api_key)

    # Load source dataset (blank lines are ignored)
    with open(LEVEL1_PATH, encoding="utf-8") as f:
        level1_rows: list[dict] = [json.loads(line) for line in f if line.strip()]

    print(f"Loaded {len(level1_rows)} questions from level1.jsonl")

    # Load already-generated rows (idempotency)
    existing = _load_existing(LEVEL2_PATH)
    print(f"Already generated: {len(existing)} questions — skipping those.")

    output_rows: list[dict] = list(existing.values())
    # Ids that already have an output row, including ones added during this
    # run, so a duplicated id in level1.jsonl is not generated twice.
    done_ids = set(existing)
    new_count = 0
    failed_count = 0

    try:
        for row in level1_rows:
            if row["id"] in done_ids:
                continue

            try:
                distractors = _generate_distractors(client, row["question"], row["ground_truth"])
            except Exception as e:
                # Best effort: a failed question is reported and left for the
                # next run rather than stopping the whole batch.
                print(f"  ERROR on {row['id']}: {e} — skipping")
                failed_count += 1
                continue

            output_rows.append({
                "id": row["id"],
                "question": row["question"],
                "ground_truth": row["ground_truth"],
                "category": row.get("category", ""),
                "distractors": distractors,
            })
            done_ids.add(row["id"])
            new_count += 1

            # Save every 10 new entries
            if new_count % 10 == 0:
                _save_rows(output_rows, LEVEL2_PATH)
                print(f"  Progress: {new_count} new / {len(output_rows)} total saved")

            # Rate-limit sleep
            time.sleep(1)
    finally:
        # Final save; in a finally block so it also runs on interruption.
        _save_rows(output_rows, LEVEL2_PATH)

    print(f"\nDone. Generated {new_count} new entries. Total in level2.jsonl: {len(output_rows)}")
    if failed_count:
        print(f"{failed_count} questions failed and were skipped; re-run to retry them.")


if __name__ == "__main__":
    main()