Jayant-Kernel Claude Sonnet 4.6 committed on
Commit
f2049f5
·
unverified ·
1 Parent(s): b44d7b0

feat: extend reset() to support level=2 with distractor context

Browse files
src/deceit_env/server/environment.py CHANGED
@@ -27,6 +27,9 @@ from deceit_env.server.grader import Grader
27
  _DEFAULT_DATASET = (
28
  pathlib.Path(__file__).parent.parent / "data" / "level1.jsonl"
29
  )
 
 
 
30
 
31
  STEP_PENALTY = -0.05
32
  MAX_TURNS = 3
@@ -62,11 +65,14 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
62
  def __init__(
63
  self,
64
  dataset_path: str | pathlib.Path = _DEFAULT_DATASET,
 
65
  grader: Optional[Grader] = None,
66
  seed: Optional[int] = None,
67
  ) -> None:
68
  super().__init__()
69
  self._dataset = self._load_dataset(pathlib.Path(dataset_path))
 
 
70
  self._grader = grader or Grader(
71
  openai_api_key=os.environ.get("OPENAI_API_KEY")
72
  )
@@ -82,18 +88,28 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
82
  self,
83
  seed: Optional[int] = None,
84
  episode_id: Optional[str] = None,
 
85
  **kwargs,
86
  ) -> DeceitObservation:
87
  """Pick a random question and initialize a new episode."""
88
  if seed is not None:
89
  self._rng = random.Random(seed)
90
 
91
- question_row = self._rng.choice(self._dataset)
 
 
 
 
 
 
 
 
 
92
  self._current_question = question_row["question"]
93
  self._state = DeceitState(
94
  episode_id=episode_id or str(uuid.uuid4()),
95
  step_count=0,
96
- level=1,
97
  ground_truth=question_row["ground_truth"],
98
  current_question_id=question_row["id"],
99
  episode_rewards=[],
@@ -102,10 +118,10 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
102
  )
103
  return DeceitObservation(
104
  question=self._current_question,
105
- context=[],
106
  turn_index=0,
107
  max_turns=MAX_TURNS,
108
- level=1,
109
  )
110
 
111
  def step(
@@ -207,3 +223,25 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
207
  if not rows:
208
  raise ValueError(f"Dataset at {path} is empty.")
209
  return rows
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  _DEFAULT_DATASET = (
28
  pathlib.Path(__file__).parent.parent / "data" / "level1.jsonl"
29
  )
30
+ _DEFAULT_LEVEL2_DATASET = (
31
+ pathlib.Path(__file__).parent.parent / "data" / "level2.jsonl"
32
+ )
33
 
34
  STEP_PENALTY = -0.05
35
  MAX_TURNS = 3
 
65
  def __init__(
66
  self,
67
  dataset_path: str | pathlib.Path = _DEFAULT_DATASET,
68
+ level2_dataset_path: str | pathlib.Path = _DEFAULT_LEVEL2_DATASET,
69
  grader: Optional[Grader] = None,
70
  seed: Optional[int] = None,
71
  ) -> None:
72
  super().__init__()
73
  self._dataset = self._load_dataset(pathlib.Path(dataset_path))
74
+ self._level2_dataset_path = pathlib.Path(level2_dataset_path)
75
+ self._level2_dataset: list[dict] | None = None
76
  self._grader = grader or Grader(
77
  openai_api_key=os.environ.get("OPENAI_API_KEY")
78
  )
 
88
  self,
89
  seed: Optional[int] = None,
90
  episode_id: Optional[str] = None,
91
+ level: int = 1,
92
  **kwargs,
93
  ) -> DeceitObservation:
94
  """Pick a random question and initialize a new episode."""
95
  if seed is not None:
96
  self._rng = random.Random(seed)
97
 
98
+ if level == 2:
99
+ dataset = self._get_level2_dataset()
100
+ question_row = self._rng.choice(dataset)
101
+ distractors: list[str] = list(question_row.get("distractors", []))
102
+ self._rng.shuffle(distractors)
103
+ context = distractors
104
+ else:
105
+ question_row = self._rng.choice(self._dataset)
106
+ context = []
107
+
108
  self._current_question = question_row["question"]
109
  self._state = DeceitState(
110
  episode_id=episode_id or str(uuid.uuid4()),
111
  step_count=0,
112
+ level=level,
113
  ground_truth=question_row["ground_truth"],
114
  current_question_id=question_row["id"],
115
  episode_rewards=[],
 
118
  )
119
  return DeceitObservation(
120
  question=self._current_question,
121
+ context=context,
122
  turn_index=0,
123
  max_turns=MAX_TURNS,
124
+ level=level,
125
  )
126
 
127
  def step(
 
223
  if not rows:
224
  raise ValueError(f"Dataset at {path} is empty.")
225
  return rows
226
+
227
+ def _get_level2_dataset(self) -> list[dict]:
228
+ if self._level2_dataset is None:
229
+ self._level2_dataset = self._load_level2_dataset(self._level2_dataset_path)
230
+ return self._level2_dataset
231
+
232
+ @staticmethod
233
+ def _load_level2_dataset(path: pathlib.Path) -> list[dict]:
234
+ if not path.exists():
235
+ raise FileNotFoundError(
236
+ f"Level 2 dataset not found at {path}. "
237
+ "Run scripts/generate_distractors.py first."
238
+ )
239
+ rows = []
240
+ with open(path, encoding="utf-8") as f:
241
+ for line in f:
242
+ line = line.strip()
243
+ if line:
244
+ rows.append(json.loads(line))
245
+ if not rows:
246
+ raise ValueError(f"Level 2 dataset at {path} is empty.")
247
+ return rows