Jayant-Kernel Claude Sonnet 4.6 committed on
feat: extend reset() to support level=2 with distractor context
Browse files
src/deceit_env/server/environment.py
CHANGED
|
@@ -27,6 +27,9 @@ from deceit_env.server.grader import Grader
|
|
| 27 |
_DEFAULT_DATASET = (
|
| 28 |
pathlib.Path(__file__).parent.parent / "data" / "level1.jsonl"
|
| 29 |
)
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
STEP_PENALTY = -0.05
|
| 32 |
MAX_TURNS = 3
|
|
@@ -62,11 +65,14 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
|
|
| 62 |
def __init__(
|
| 63 |
self,
|
| 64 |
dataset_path: str | pathlib.Path = _DEFAULT_DATASET,
|
|
|
|
| 65 |
grader: Optional[Grader] = None,
|
| 66 |
seed: Optional[int] = None,
|
| 67 |
) -> None:
|
| 68 |
super().__init__()
|
| 69 |
self._dataset = self._load_dataset(pathlib.Path(dataset_path))
|
|
|
|
|
|
|
| 70 |
self._grader = grader or Grader(
|
| 71 |
openai_api_key=os.environ.get("OPENAI_API_KEY")
|
| 72 |
)
|
|
@@ -82,18 +88,28 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
|
|
| 82 |
self,
|
| 83 |
seed: Optional[int] = None,
|
| 84 |
episode_id: Optional[str] = None,
|
|
|
|
| 85 |
**kwargs,
|
| 86 |
) -> DeceitObservation:
|
| 87 |
"""Pick a random question and initialize a new episode."""
|
| 88 |
if seed is not None:
|
| 89 |
self._rng = random.Random(seed)
|
| 90 |
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
self._current_question = question_row["question"]
|
| 93 |
self._state = DeceitState(
|
| 94 |
episode_id=episode_id or str(uuid.uuid4()),
|
| 95 |
step_count=0,
|
| 96 |
-
level=
|
| 97 |
ground_truth=question_row["ground_truth"],
|
| 98 |
current_question_id=question_row["id"],
|
| 99 |
episode_rewards=[],
|
|
@@ -102,10 +118,10 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
|
|
| 102 |
)
|
| 103 |
return DeceitObservation(
|
| 104 |
question=self._current_question,
|
| 105 |
-
context=
|
| 106 |
turn_index=0,
|
| 107 |
max_turns=MAX_TURNS,
|
| 108 |
-
level=
|
| 109 |
)
|
| 110 |
|
| 111 |
def step(
|
|
@@ -207,3 +223,25 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
|
|
| 207 |
if not rows:
|
| 208 |
raise ValueError(f"Dataset at {path} is empty.")
|
| 209 |
return rows
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# Default level-1 question file: "data/level1.jsonl" under the directory that
# contains this module's own directory.
_DEFAULT_DATASET = pathlib.Path(__file__).parent.parent.joinpath(
    "data", "level1.jsonl"
)
# Default level-2 question file, stored alongside the level-1 file. reset()
# reads an optional "distractors" list from each of its rows.
_DEFAULT_LEVEL2_DATASET = pathlib.Path(__file__).parent.parent.joinpath(
    "data", "level2.jsonl"
)

# NOTE(review): presumably the reward adjustment applied per step; the code
# that consumes it is not visible in this hunk -- confirm against step().
STEP_PENALTY = -0.05
# Reported to the agent as ``max_turns`` in the observation built by reset().
MAX_TURNS = 3
|
|
|
|
| 65 |
def __init__(
|
| 66 |
self,
|
| 67 |
dataset_path: str | pathlib.Path = _DEFAULT_DATASET,
|
| 68 |
+
level2_dataset_path: str | pathlib.Path = _DEFAULT_LEVEL2_DATASET,
|
| 69 |
grader: Optional[Grader] = None,
|
| 70 |
seed: Optional[int] = None,
|
| 71 |
) -> None:
|
| 72 |
super().__init__()
|
| 73 |
self._dataset = self._load_dataset(pathlib.Path(dataset_path))
|
| 74 |
+
self._level2_dataset_path = pathlib.Path(level2_dataset_path)
|
| 75 |
+
self._level2_dataset: list[dict] | None = None
|
| 76 |
self._grader = grader or Grader(
|
| 77 |
openai_api_key=os.environ.get("OPENAI_API_KEY")
|
| 78 |
)
|
|
|
|
| 88 |
self,
|
| 89 |
seed: Optional[int] = None,
|
| 90 |
episode_id: Optional[str] = None,
|
| 91 |
+
level: int = 1,
|
| 92 |
**kwargs,
|
| 93 |
) -> DeceitObservation:
|
| 94 |
"""Pick a random question and initialize a new episode."""
|
| 95 |
if seed is not None:
|
| 96 |
self._rng = random.Random(seed)
|
| 97 |
|
| 98 |
+
if level == 2:
|
| 99 |
+
dataset = self._get_level2_dataset()
|
| 100 |
+
question_row = self._rng.choice(dataset)
|
| 101 |
+
distractors: list[str] = list(question_row.get("distractors", []))
|
| 102 |
+
self._rng.shuffle(distractors)
|
| 103 |
+
context = distractors
|
| 104 |
+
else:
|
| 105 |
+
question_row = self._rng.choice(self._dataset)
|
| 106 |
+
context = []
|
| 107 |
+
|
| 108 |
self._current_question = question_row["question"]
|
| 109 |
self._state = DeceitState(
|
| 110 |
episode_id=episode_id or str(uuid.uuid4()),
|
| 111 |
step_count=0,
|
| 112 |
+
level=level,
|
| 113 |
ground_truth=question_row["ground_truth"],
|
| 114 |
current_question_id=question_row["id"],
|
| 115 |
episode_rewards=[],
|
|
|
|
| 118 |
)
|
| 119 |
return DeceitObservation(
|
| 120 |
question=self._current_question,
|
| 121 |
+
context=context,
|
| 122 |
turn_index=0,
|
| 123 |
max_turns=MAX_TURNS,
|
| 124 |
+
level=level,
|
| 125 |
)
|
| 126 |
|
| 127 |
def step(
|
|
|
|
| 223 |
if not rows:
|
| 224 |
raise ValueError(f"Dataset at {path} is empty.")
|
| 225 |
return rows
|
| 226 |
+
|
| 227 |
+
def _get_level2_dataset(self) -> list[dict]:
    """Return the level-2 rows, loading them from disk on first use.

    The parsed rows are stored on ``self._level2_dataset``; once that
    attribute is populated, later calls return the stored list without
    reading ``self._level2_dataset_path`` again.

    Returns:
        The list of rows produced by ``_load_level2_dataset``.
    """
    cached = self._level2_dataset
    if cached is not None:
        return cached
    # First request: parse the file and remember the result.
    cached = self._load_level2_dataset(self._level2_dataset_path)
    self._level2_dataset = cached
    return cached
|
| 231 |
+
|
| 232 |
+
@staticmethod
def _load_level2_dataset(path: pathlib.Path) -> list[dict]:
    """Parse the level-2 JSONL file at *path* into a list of rows.

    Blank (whitespace-only) lines are skipped; every other line is decoded
    as one JSON document and appended to the result in file order.

    Args:
        path: Location of the level-2 JSONL file.

    Returns:
        The decoded rows, one entry per non-blank line.

    Raises:
        FileNotFoundError: *path* does not exist.
        json.JSONDecodeError: A non-blank line is not valid JSON. The
            message is prefixed with ``<path>:<1-based line number>``.
        ValueError: The file contains no non-blank lines.
    """
    if not path.exists():
        raise FileNotFoundError(
            f"Level 2 dataset not found at {path}. "
            "Run scripts/generate_distractors.py first."
        )
    rows: list[dict] = []
    with open(path, encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            text = raw.strip()
            if not text:
                continue
            try:
                rows.append(json.loads(text))
            except json.JSONDecodeError as err:
                # Lines are decoded one at a time, so the decoder's own
                # position is always "line 1". Re-raise the same exception
                # type with the file and real line number so a bad row can
                # be located.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    if not rows:
        raise ValueError(f"Level 2 dataset at {path} is empty.")
    return rows
|