Jayant-Kernel Claude Sonnet 4.6 commited on
Commit
d5d723b
·
unverified ·
1 Parent(s): ba97ba8

fix: propagate fatal API errors, strip markdown fences, cleanup

Browse files
Files changed (1) hide show
  1. scripts/generate_distractors.py +12 -4
scripts/generate_distractors.py CHANGED
@@ -11,11 +11,14 @@ import os
11
  import pathlib
12
  import time
13
 
 
14
  from openai import OpenAI
15
 
16
  LEVEL1_PATH = pathlib.Path(__file__).parent.parent / "src" / "deceit_env" / "data" / "level1.jsonl"
17
  LEVEL2_PATH = pathlib.Path(__file__).parent.parent / "src" / "deceit_env" / "data" / "level2.jsonl"
18
 
 
 
19
  PROMPT_TEMPLATE = (
20
  "Generate 2 plausible-sounding but FALSE statements about the following fact. "
21
  "They should sound authoritative and misleading but be factually wrong. "
@@ -49,12 +52,14 @@ def _generate_distractors(client: OpenAI, question: str, ground_truth: str) -> l
49
  """Call GPT-4o-mini; return list of 2 distractor strings."""
50
  prompt = PROMPT_TEMPLATE.format(question=question, ground_truth=ground_truth)
51
  response = client.chat.completions.create(
52
- model="gpt-4o-mini",
53
  messages=[{"role": "user", "content": prompt}],
54
  max_tokens=200,
55
  temperature=0.9,
56
  )
57
  raw = response.choices[0].message.content.strip()
 
 
58
  distractors = json.loads(raw)
59
  if not isinstance(distractors, list) or len(distractors) != 2:
60
  raise ValueError(f"Unexpected response format: {raw!r}")
@@ -86,7 +91,7 @@ def main() -> None:
86
  new_count = 0
87
  iteration_count = 0
88
 
89
- for i, row in enumerate(level1_rows):
90
  iteration_count += 1
91
 
92
  if row["id"] in existing:
@@ -94,9 +99,12 @@ def main() -> None:
94
 
95
  try:
96
  distractors = _generate_distractors(client, row["question"], row["ground_truth"])
 
 
 
 
97
  except Exception as e:
98
  print(f" ERROR on {row['id']}: {e} — skipping")
99
- # Rate-limit sleep after failed API call
100
  time.sleep(1)
101
  continue
102
 
@@ -112,7 +120,7 @@ def main() -> None:
112
  # Save and print progress every 10 loop iterations
113
  if iteration_count % 10 == 0:
114
  _save_rows(output_rows, LEVEL2_PATH)
115
- print(f" Progress: {iteration_count} processed / {new_count} new / {len(output_rows)} total saved")
116
 
117
  # Rate-limit sleep after successful API call
118
  time.sleep(1)
 
11
  import pathlib
12
  import time
13
 
14
+ import openai
15
  from openai import OpenAI
16
 
17
  LEVEL1_PATH = pathlib.Path(__file__).parent.parent / "src" / "deceit_env" / "data" / "level1.jsonl"
18
  LEVEL2_PATH = pathlib.Path(__file__).parent.parent / "src" / "deceit_env" / "data" / "level2.jsonl"
19
 
20
+ MODEL = "gpt-4o-mini"
21
+
22
  PROMPT_TEMPLATE = (
23
  "Generate 2 plausible-sounding but FALSE statements about the following fact. "
24
  "They should sound authoritative and misleading but be factually wrong. "
 
52
  """Call GPT-4o-mini; return list of 2 distractor strings."""
53
  prompt = PROMPT_TEMPLATE.format(question=question, ground_truth=ground_truth)
54
  response = client.chat.completions.create(
55
+ model=MODEL,
56
  messages=[{"role": "user", "content": prompt}],
57
  max_tokens=200,
58
  temperature=0.9,
59
  )
60
  raw = response.choices[0].message.content.strip()
61
+ if raw.startswith("```"):
62
+ raw = raw.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
63
  distractors = json.loads(raw)
64
  if not isinstance(distractors, list) or len(distractors) != 2:
65
  raise ValueError(f"Unexpected response format: {raw!r}")
 
91
  new_count = 0
92
  iteration_count = 0
93
 
94
+ for row in level1_rows:
95
  iteration_count += 1
96
 
97
  if row["id"] in existing:
 
99
 
100
  try:
101
  distractors = _generate_distractors(client, row["question"], row["ground_truth"])
102
+ except openai.AuthenticationError as e:
103
+ raise RuntimeError(f"Unrecoverable API error (check OPENAI_API_KEY): {e}") from e
104
+ except openai.RateLimitError as e:
105
+ raise RuntimeError(f"Unrecoverable rate limit error: {e}") from e
106
  except Exception as e:
107
  print(f" ERROR on {row['id']}: {e} — skipping")
 
108
  time.sleep(1)
109
  continue
110
 
 
120
  # Save and print progress every 10 loop iterations
121
  if iteration_count % 10 == 0:
122
  _save_rows(output_rows, LEVEL2_PATH)
123
+ print(f" Progress: {iteration_count} seen / {new_count} new / {len(output_rows)} total saved")
124
 
125
  # Rate-limit sleep after successful API call
126
  time.sleep(1)