Phase 2.5: multi-turn episodes, bug fixes, dataset cleanup
Browse files
Critical fixes:
- Multi-turn episodes (max_turns=3): is_final flag on DeceitAction,
step penalty -0.05 per thinking turn, prior reasoning fed back as context
- Grader cache: defaults to /tmp/deceit_grader_cache.json (HF Spaces safe),
overridable via DECEIT_GRADER_CACHE env var
- Dataset: replaced 7 ambiguous questions (q051,q058,q066,q072,q085,q091,q100)
Other improvements:
- DeceitState adds prior_reasoning: list[str] and max_turns: int
- DeceitAction adds is_final: bool = False
- __init__.py: package-level exports for DeceitEnvironment, models, Grader
- app.py: factory pattern limitation documented
- test_models.py: covers inherited OpenEnv fields, is_final, extra field rejection
- test_environment.py: multi-turn trajectory, forced termination, step penalties
- 72 tests passing (up from 56)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- scripts/generate_level1_dataset.py +7 -7
- src/deceit_env/__init__.py +12 -0
- src/deceit_env/data/level1.jsonl +7 -7
- src/deceit_env/models.py +8 -1
- src/deceit_env/server/app.py +3 -0
- src/deceit_env/server/environment.py +61 -11
- src/deceit_env/server/grader.py +10 -3
- tests/test_environment.py +100 -21
- tests/test_models.py +41 -3
|
@@ -64,14 +64,14 @@ QUESTIONS = [
|
|
| 64 |
{"id": "q048", "question": "What gas do plants absorb during photosynthesis?", "ground_truth": "carbon dioxide", "category": "science"},
|
| 65 |
{"id": "q049", "question": "What is the most abundant gas in Earth's atmosphere?", "ground_truth": "nitrogen", "category": "science"},
|
| 66 |
{"id": "q050", "question": "What is the chemical formula for water?", "ground_truth": "H2O", "category": "science"},
|
| 67 |
-
{"id": "q051", "question": "What planet
|
| 68 |
{"id": "q052", "question": "What is the largest organ in the human body?", "ground_truth": "skin", "category": "science"},
|
| 69 |
{"id": "q053", "question": "What is the chemical symbol for silver?", "ground_truth": "Ag", "category": "science"},
|
| 70 |
{"id": "q054", "question": "What is the atomic number of oxygen?", "ground_truth": "8", "category": "science"},
|
| 71 |
{"id": "q055", "question": "What is the chemical formula for table salt?", "ground_truth": "NaCl", "category": "science"},
|
| 72 |
{"id": "q056", "question": "What is the hardest natural substance on Earth?", "ground_truth": "diamond", "category": "science"},
|
| 73 |
{"id": "q057", "question": "What force keeps planets in orbit around the Sun?", "ground_truth": "gravity", "category": "science"},
|
| 74 |
-
{"id": "q058", "question": "What
|
| 75 |
{"id": "q059", "question": "What is the boiling point of water in Celsius?", "ground_truth": "100", "category": "science"},
|
| 76 |
{"id": "q060", "question": "What is the freezing point of water in Celsius?", "ground_truth": "0", "category": "science"},
|
| 77 |
{"id": "q061", "question": "How many chromosomes does a normal human cell have?", "ground_truth": "46", "category": "science"},
|
|
@@ -81,13 +81,13 @@ QUESTIONS = [
|
|
| 81 |
{"id": "q065", "question": "What particle has a negative charge in an atom?", "ground_truth": "electron", "category": "science"},
|
| 82 |
|
| 83 |
# --- Math (15) ---
|
| 84 |
-
{"id": "q066", "question": "What
|
| 85 |
{"id": "q067", "question": "What is the square root of 144?", "ground_truth": "12", "category": "math"},
|
| 86 |
{"id": "q068", "question": "What is 15 percent of 200?", "ground_truth": "30", "category": "math"},
|
| 87 |
{"id": "q069", "question": "What is the sum of angles in a triangle in degrees?", "ground_truth": "180", "category": "math"},
|
| 88 |
{"id": "q070", "question": "What is 2 to the power of 10?", "ground_truth": "1024", "category": "math"},
|
| 89 |
{"id": "q071", "question": "What is the square root of 256?", "ground_truth": "16", "category": "math"},
|
| 90 |
-
{"id": "q072", "question": "What
|
| 91 |
{"id": "q073", "question": "How many sides does a heptagon have?", "ground_truth": "7", "category": "math"},
|
| 92 |
{"id": "q074", "question": "What is the factorial of 5?", "ground_truth": "120", "category": "math"},
|
| 93 |
{"id": "q075", "question": "What is the area of a circle with radius 1?", "ground_truth": "pi", "category": "math"},
|
|
@@ -102,13 +102,13 @@ QUESTIONS = [
|
|
| 102 |
{"id": "q082", "question": "What is the currency of the United Kingdom?", "ground_truth": "pound", "category": "general"},
|
| 103 |
{"id": "q083", "question": "How many players are on a standard soccer team?", "ground_truth": "11", "category": "general"},
|
| 104 |
{"id": "q084", "question": "How many strings does a standard guitar have?", "ground_truth": "6", "category": "general"},
|
| 105 |
-
{"id": "q085", "question": "What is the
|
| 106 |
{"id": "q086", "question": "What language has the most native speakers in the world?", "ground_truth": "Mandarin", "category": "general"},
|
| 107 |
{"id": "q087", "question": "How many hours are in a week?", "ground_truth": "168", "category": "general"},
|
| 108 |
{"id": "q088", "question": "What is the national animal of Australia?", "ground_truth": "kangaroo", "category": "general"},
|
| 109 |
{"id": "q089", "question": "How many keys does a standard piano have?", "ground_truth": "88", "category": "general"},
|
| 110 |
{"id": "q090", "question": "What is the currency of India?", "ground_truth": "rupee", "category": "general"},
|
| 111 |
-
{"id": "q091", "question": "
|
| 112 |
{"id": "q092", "question": "What is the fastest land animal?", "ground_truth": "cheetah", "category": "general"},
|
| 113 |
{"id": "q093", "question": "How many teeth does an adult human have?", "ground_truth": "32", "category": "general"},
|
| 114 |
{"id": "q094", "question": "What is the chemical symbol for lead?", "ground_truth": "Pb", "category": "general"},
|
|
@@ -117,7 +117,7 @@ QUESTIONS = [
|
|
| 117 |
{"id": "q097", "question": "How many planets are in our solar system?", "ground_truth": "8", "category": "general"},
|
| 118 |
{"id": "q098", "question": "What is the currency of China?", "ground_truth": "yuan", "category": "general"},
|
| 119 |
{"id": "q099", "question": "How many sides does an octagon have?", "ground_truth": "8", "category": "general"},
|
| 120 |
-
{"id": "q100", "question": "What is the
|
| 121 |
]
|
| 122 |
|
| 123 |
|
|
|
|
| 64 |
{"id": "q048", "question": "What gas do plants absorb during photosynthesis?", "ground_truth": "carbon dioxide", "category": "science"},
|
| 65 |
{"id": "q049", "question": "What is the most abundant gas in Earth's atmosphere?", "ground_truth": "nitrogen", "category": "science"},
|
| 66 |
{"id": "q050", "question": "What is the chemical formula for water?", "ground_truth": "H2O", "category": "science"},
|
| 67 |
+
{"id": "q051", "question": "What is the largest planet in our solar system?", "ground_truth": "Jupiter", "category": "science"},
|
| 68 |
{"id": "q052", "question": "What is the largest organ in the human body?", "ground_truth": "skin", "category": "science"},
|
| 69 |
{"id": "q053", "question": "What is the chemical symbol for silver?", "ground_truth": "Ag", "category": "science"},
|
| 70 |
{"id": "q054", "question": "What is the atomic number of oxygen?", "ground_truth": "8", "category": "science"},
|
| 71 |
{"id": "q055", "question": "What is the chemical formula for table salt?", "ground_truth": "NaCl", "category": "science"},
|
| 72 |
{"id": "q056", "question": "What is the hardest natural substance on Earth?", "ground_truth": "diamond", "category": "science"},
|
| 73 |
{"id": "q057", "question": "What force keeps planets in orbit around the Sun?", "ground_truth": "gravity", "category": "science"},
|
| 74 |
+
{"id": "q058", "question": "What star does Earth orbit?", "ground_truth": "Sun", "category": "science"},
|
| 75 |
{"id": "q059", "question": "What is the boiling point of water in Celsius?", "ground_truth": "100", "category": "science"},
|
| 76 |
{"id": "q060", "question": "What is the freezing point of water in Celsius?", "ground_truth": "0", "category": "science"},
|
| 77 |
{"id": "q061", "question": "How many chromosomes does a normal human cell have?", "ground_truth": "46", "category": "science"},
|
|
|
|
| 81 |
{"id": "q065", "question": "What particle has a negative charge in an atom?", "ground_truth": "electron", "category": "science"},
|
| 82 |
|
| 83 |
# --- Math (15) ---
|
| 84 |
+
{"id": "q066", "question": "What are the first three digits of pi after the decimal point?", "ground_truth": "141", "category": "math"},
|
| 85 |
{"id": "q067", "question": "What is the square root of 144?", "ground_truth": "12", "category": "math"},
|
| 86 |
{"id": "q068", "question": "What is 15 percent of 200?", "ground_truth": "30", "category": "math"},
|
| 87 |
{"id": "q069", "question": "What is the sum of angles in a triangle in degrees?", "ground_truth": "180", "category": "math"},
|
| 88 |
{"id": "q070", "question": "What is 2 to the power of 10?", "ground_truth": "1024", "category": "math"},
|
| 89 |
{"id": "q071", "question": "What is the square root of 256?", "ground_truth": "16", "category": "math"},
|
| 90 |
+
{"id": "q072", "question": "What are the first three digits of Euler's number e after the decimal point?", "ground_truth": "718", "category": "math"},
|
| 91 |
{"id": "q073", "question": "How many sides does a heptagon have?", "ground_truth": "7", "category": "math"},
|
| 92 |
{"id": "q074", "question": "What is the factorial of 5?", "ground_truth": "120", "category": "math"},
|
| 93 |
{"id": "q075", "question": "What is the area of a circle with radius 1?", "ground_truth": "pi", "category": "math"},
|
|
|
|
| 102 |
{"id": "q082", "question": "What is the currency of the United Kingdom?", "ground_truth": "pound", "category": "general"},
|
| 103 |
{"id": "q083", "question": "How many players are on a standard soccer team?", "ground_truth": "11", "category": "general"},
|
| 104 |
{"id": "q084", "question": "How many strings does a standard guitar have?", "ground_truth": "6", "category": "general"},
|
| 105 |
+
{"id": "q085", "question": "What is the currency of Brazil?", "ground_truth": "real", "category": "general"},
|
| 106 |
{"id": "q086", "question": "What language has the most native speakers in the world?", "ground_truth": "Mandarin", "category": "general"},
|
| 107 |
{"id": "q087", "question": "How many hours are in a week?", "ground_truth": "168", "category": "general"},
|
| 108 |
{"id": "q088", "question": "What is the national animal of Australia?", "ground_truth": "kangaroo", "category": "general"},
|
| 109 |
{"id": "q089", "question": "How many keys does a standard piano have?", "ground_truth": "88", "category": "general"},
|
| 110 |
{"id": "q090", "question": "What is the currency of India?", "ground_truth": "rupee", "category": "general"},
|
| 111 |
+
{"id": "q091", "question": "On which continent is the Amazon rainforest located?", "ground_truth": "South America", "category": "general"},
|
| 112 |
{"id": "q092", "question": "What is the fastest land animal?", "ground_truth": "cheetah", "category": "general"},
|
| 113 |
{"id": "q093", "question": "How many teeth does an adult human have?", "ground_truth": "32", "category": "general"},
|
| 114 |
{"id": "q094", "question": "What is the chemical symbol for lead?", "ground_truth": "Pb", "category": "general"},
|
|
|
|
| 117 |
{"id": "q097", "question": "How many planets are in our solar system?", "ground_truth": "8", "category": "general"},
|
| 118 |
{"id": "q098", "question": "What is the currency of China?", "ground_truth": "yuan", "category": "general"},
|
| 119 |
{"id": "q099", "question": "How many sides does an octagon have?", "ground_truth": "8", "category": "general"},
|
| 120 |
+
{"id": "q100", "question": "What is the official language of Brazil?", "ground_truth": "Portuguese", "category": "general"},
|
| 121 |
]
|
| 122 |
|
| 123 |
|
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from deceit_env.models import DeceitAction, DeceitObservation, DeceitState
|
| 2 |
+
from deceit_env.server.environment import DeceitEnvironment
|
| 3 |
+
from deceit_env.server.grader import Grader, GraderResult
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"DeceitAction",
|
| 7 |
+
"DeceitObservation",
|
| 8 |
+
"DeceitState",
|
| 9 |
+
"DeceitEnvironment",
|
| 10 |
+
"Grader",
|
| 11 |
+
"GraderResult",
|
| 12 |
+
]
|
|
@@ -48,14 +48,14 @@
|
|
| 48 |
{"id": "q048", "question": "What gas do plants absorb during photosynthesis?", "ground_truth": "carbon dioxide", "category": "science"}
|
| 49 |
{"id": "q049", "question": "What is the most abundant gas in Earth's atmosphere?", "ground_truth": "nitrogen", "category": "science"}
|
| 50 |
{"id": "q050", "question": "What is the chemical formula for water?", "ground_truth": "H2O", "category": "science"}
|
| 51 |
-
{"id": "q051", "question": "What planet
|
| 52 |
{"id": "q052", "question": "What is the largest organ in the human body?", "ground_truth": "skin", "category": "science"}
|
| 53 |
{"id": "q053", "question": "What is the chemical symbol for silver?", "ground_truth": "Ag", "category": "science"}
|
| 54 |
{"id": "q054", "question": "What is the atomic number of oxygen?", "ground_truth": "8", "category": "science"}
|
| 55 |
{"id": "q055", "question": "What is the chemical formula for table salt?", "ground_truth": "NaCl", "category": "science"}
|
| 56 |
{"id": "q056", "question": "What is the hardest natural substance on Earth?", "ground_truth": "diamond", "category": "science"}
|
| 57 |
{"id": "q057", "question": "What force keeps planets in orbit around the Sun?", "ground_truth": "gravity", "category": "science"}
|
| 58 |
-
{"id": "q058", "question": "What
|
| 59 |
{"id": "q059", "question": "What is the boiling point of water in Celsius?", "ground_truth": "100", "category": "science"}
|
| 60 |
{"id": "q060", "question": "What is the freezing point of water in Celsius?", "ground_truth": "0", "category": "science"}
|
| 61 |
{"id": "q061", "question": "How many chromosomes does a normal human cell have?", "ground_truth": "46", "category": "science"}
|
|
@@ -63,13 +63,13 @@
|
|
| 63 |
{"id": "q063", "question": "What is the chemical symbol for sodium?", "ground_truth": "Na", "category": "science"}
|
| 64 |
{"id": "q064", "question": "What is the unit of electrical resistance?", "ground_truth": "ohm", "category": "science"}
|
| 65 |
{"id": "q065", "question": "What particle has a negative charge in an atom?", "ground_truth": "electron", "category": "science"}
|
| 66 |
-
{"id": "q066", "question": "What
|
| 67 |
{"id": "q067", "question": "What is the square root of 144?", "ground_truth": "12", "category": "math"}
|
| 68 |
{"id": "q068", "question": "What is 15 percent of 200?", "ground_truth": "30", "category": "math"}
|
| 69 |
{"id": "q069", "question": "What is the sum of angles in a triangle in degrees?", "ground_truth": "180", "category": "math"}
|
| 70 |
{"id": "q070", "question": "What is 2 to the power of 10?", "ground_truth": "1024", "category": "math"}
|
| 71 |
{"id": "q071", "question": "What is the square root of 256?", "ground_truth": "16", "category": "math"}
|
| 72 |
-
{"id": "q072", "question": "What
|
| 73 |
{"id": "q073", "question": "How many sides does a heptagon have?", "ground_truth": "7", "category": "math"}
|
| 74 |
{"id": "q074", "question": "What is the factorial of 5?", "ground_truth": "120", "category": "math"}
|
| 75 |
{"id": "q075", "question": "What is the area of a circle with radius 1?", "ground_truth": "pi", "category": "math"}
|
|
@@ -82,13 +82,13 @@
|
|
| 82 |
{"id": "q082", "question": "What is the currency of the United Kingdom?", "ground_truth": "pound", "category": "general"}
|
| 83 |
{"id": "q083", "question": "How many players are on a standard soccer team?", "ground_truth": "11", "category": "general"}
|
| 84 |
{"id": "q084", "question": "How many strings does a standard guitar have?", "ground_truth": "6", "category": "general"}
|
| 85 |
-
{"id": "q085", "question": "What is the
|
| 86 |
{"id": "q086", "question": "What language has the most native speakers in the world?", "ground_truth": "Mandarin", "category": "general"}
|
| 87 |
{"id": "q087", "question": "How many hours are in a week?", "ground_truth": "168", "category": "general"}
|
| 88 |
{"id": "q088", "question": "What is the national animal of Australia?", "ground_truth": "kangaroo", "category": "general"}
|
| 89 |
{"id": "q089", "question": "How many keys does a standard piano have?", "ground_truth": "88", "category": "general"}
|
| 90 |
{"id": "q090", "question": "What is the currency of India?", "ground_truth": "rupee", "category": "general"}
|
| 91 |
-
{"id": "q091", "question": "
|
| 92 |
{"id": "q092", "question": "What is the fastest land animal?", "ground_truth": "cheetah", "category": "general"}
|
| 93 |
{"id": "q093", "question": "How many teeth does an adult human have?", "ground_truth": "32", "category": "general"}
|
| 94 |
{"id": "q094", "question": "What is the chemical symbol for lead?", "ground_truth": "Pb", "category": "general"}
|
|
@@ -97,4 +97,4 @@
|
|
| 97 |
{"id": "q097", "question": "How many planets are in our solar system?", "ground_truth": "8", "category": "general"}
|
| 98 |
{"id": "q098", "question": "What is the currency of China?", "ground_truth": "yuan", "category": "general"}
|
| 99 |
{"id": "q099", "question": "How many sides does an octagon have?", "ground_truth": "8", "category": "general"}
|
| 100 |
-
{"id": "q100", "question": "What is the
|
|
|
|
| 48 |
{"id": "q048", "question": "What gas do plants absorb during photosynthesis?", "ground_truth": "carbon dioxide", "category": "science"}
|
| 49 |
{"id": "q049", "question": "What is the most abundant gas in Earth's atmosphere?", "ground_truth": "nitrogen", "category": "science"}
|
| 50 |
{"id": "q050", "question": "What is the chemical formula for water?", "ground_truth": "H2O", "category": "science"}
|
| 51 |
+
{"id": "q051", "question": "What is the largest planet in our solar system?", "ground_truth": "Jupiter", "category": "science"}
|
| 52 |
{"id": "q052", "question": "What is the largest organ in the human body?", "ground_truth": "skin", "category": "science"}
|
| 53 |
{"id": "q053", "question": "What is the chemical symbol for silver?", "ground_truth": "Ag", "category": "science"}
|
| 54 |
{"id": "q054", "question": "What is the atomic number of oxygen?", "ground_truth": "8", "category": "science"}
|
| 55 |
{"id": "q055", "question": "What is the chemical formula for table salt?", "ground_truth": "NaCl", "category": "science"}
|
| 56 |
{"id": "q056", "question": "What is the hardest natural substance on Earth?", "ground_truth": "diamond", "category": "science"}
|
| 57 |
{"id": "q057", "question": "What force keeps planets in orbit around the Sun?", "ground_truth": "gravity", "category": "science"}
|
| 58 |
+
{"id": "q058", "question": "What star does Earth orbit?", "ground_truth": "Sun", "category": "science"}
|
| 59 |
{"id": "q059", "question": "What is the boiling point of water in Celsius?", "ground_truth": "100", "category": "science"}
|
| 60 |
{"id": "q060", "question": "What is the freezing point of water in Celsius?", "ground_truth": "0", "category": "science"}
|
| 61 |
{"id": "q061", "question": "How many chromosomes does a normal human cell have?", "ground_truth": "46", "category": "science"}
|
|
|
|
| 63 |
{"id": "q063", "question": "What is the chemical symbol for sodium?", "ground_truth": "Na", "category": "science"}
|
| 64 |
{"id": "q064", "question": "What is the unit of electrical resistance?", "ground_truth": "ohm", "category": "science"}
|
| 65 |
{"id": "q065", "question": "What particle has a negative charge in an atom?", "ground_truth": "electron", "category": "science"}
|
| 66 |
+
{"id": "q066", "question": "What are the first three digits of pi after the decimal point?", "ground_truth": "141", "category": "math"}
|
| 67 |
{"id": "q067", "question": "What is the square root of 144?", "ground_truth": "12", "category": "math"}
|
| 68 |
{"id": "q068", "question": "What is 15 percent of 200?", "ground_truth": "30", "category": "math"}
|
| 69 |
{"id": "q069", "question": "What is the sum of angles in a triangle in degrees?", "ground_truth": "180", "category": "math"}
|
| 70 |
{"id": "q070", "question": "What is 2 to the power of 10?", "ground_truth": "1024", "category": "math"}
|
| 71 |
{"id": "q071", "question": "What is the square root of 256?", "ground_truth": "16", "category": "math"}
|
| 72 |
+
{"id": "q072", "question": "What are the first three digits of Euler's number e after the decimal point?", "ground_truth": "718", "category": "math"}
|
| 73 |
{"id": "q073", "question": "How many sides does a heptagon have?", "ground_truth": "7", "category": "math"}
|
| 74 |
{"id": "q074", "question": "What is the factorial of 5?", "ground_truth": "120", "category": "math"}
|
| 75 |
{"id": "q075", "question": "What is the area of a circle with radius 1?", "ground_truth": "pi", "category": "math"}
|
|
|
|
| 82 |
{"id": "q082", "question": "What is the currency of the United Kingdom?", "ground_truth": "pound", "category": "general"}
|
| 83 |
{"id": "q083", "question": "How many players are on a standard soccer team?", "ground_truth": "11", "category": "general"}
|
| 84 |
{"id": "q084", "question": "How many strings does a standard guitar have?", "ground_truth": "6", "category": "general"}
|
| 85 |
+
{"id": "q085", "question": "What is the currency of Brazil?", "ground_truth": "real", "category": "general"}
|
| 86 |
{"id": "q086", "question": "What language has the most native speakers in the world?", "ground_truth": "Mandarin", "category": "general"}
|
| 87 |
{"id": "q087", "question": "How many hours are in a week?", "ground_truth": "168", "category": "general"}
|
| 88 |
{"id": "q088", "question": "What is the national animal of Australia?", "ground_truth": "kangaroo", "category": "general"}
|
| 89 |
{"id": "q089", "question": "How many keys does a standard piano have?", "ground_truth": "88", "category": "general"}
|
| 90 |
{"id": "q090", "question": "What is the currency of India?", "ground_truth": "rupee", "category": "general"}
|
| 91 |
+
{"id": "q091", "question": "On which continent is the Amazon rainforest located?", "ground_truth": "South America", "category": "general"}
|
| 92 |
{"id": "q092", "question": "What is the fastest land animal?", "ground_truth": "cheetah", "category": "general"}
|
| 93 |
{"id": "q093", "question": "How many teeth does an adult human have?", "ground_truth": "32", "category": "general"}
|
| 94 |
{"id": "q094", "question": "What is the chemical symbol for lead?", "ground_truth": "Pb", "category": "general"}
|
|
|
|
| 97 |
{"id": "q097", "question": "How many planets are in our solar system?", "ground_truth": "8", "category": "general"}
|
| 98 |
{"id": "q098", "question": "What is the currency of China?", "ground_truth": "yuan", "category": "general"}
|
| 99 |
{"id": "q099", "question": "How many sides does an octagon have?", "ground_truth": "8", "category": "general"}
|
| 100 |
+
{"id": "q100", "question": "What is the official language of Brazil?", "ground_truth": "Portuguese", "category": "general"}
|
|
@@ -15,12 +15,17 @@ class DeceitObservation(Observation):
|
|
| 15 |
|
| 16 |
|
| 17 |
class DeceitAction(Action):
|
| 18 |
-
"""What the agent produces each step.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
reasoning: str
|
| 21 |
answer: str = ""
|
| 22 |
confidence: float = 0.5
|
| 23 |
abstain: bool = False
|
|
|
|
| 24 |
|
| 25 |
@field_validator("confidence")
|
| 26 |
@classmethod
|
|
@@ -37,3 +42,5 @@ class DeceitState(State):
|
|
| 37 |
ground_truth: str = ""
|
| 38 |
current_question_id: str = ""
|
| 39 |
episode_rewards: list[float] = []
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class DeceitAction(Action):
|
| 18 |
+
"""What the agent produces each step.
|
| 19 |
+
|
| 20 |
+
Set is_final=True to commit an answer and end the episode.
|
| 21 |
+
Set is_final=False to think for another turn (costs a -0.05 step penalty).
|
| 22 |
+
"""
|
| 23 |
|
| 24 |
reasoning: str
|
| 25 |
answer: str = ""
|
| 26 |
confidence: float = 0.5
|
| 27 |
abstain: bool = False
|
| 28 |
+
is_final: bool = False
|
| 29 |
|
| 30 |
@field_validator("confidence")
|
| 31 |
@classmethod
|
|
|
|
| 42 |
ground_truth: str = ""
|
| 43 |
current_question_id: str = ""
|
| 44 |
episode_rewards: list[float] = []
|
| 45 |
+
prior_reasoning: list[str] = []
|
| 46 |
+
max_turns: int = 3
|
|
@@ -5,6 +5,9 @@ from openenv.core.env_server import create_fastapi_app
|
|
| 5 |
from deceit_env.models import DeceitAction, DeceitObservation
|
| 6 |
from deceit_env.server.environment import DeceitEnvironment
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
app = create_fastapi_app(
|
| 9 |
env=DeceitEnvironment,
|
| 10 |
action_cls=DeceitAction,
|
|
|
|
| 5 |
from deceit_env.models import DeceitAction, DeceitObservation
|
| 6 |
from deceit_env.server.environment import DeceitEnvironment
|
| 7 |
|
| 8 |
+
# Note: create_fastapi_app expects a callable factory (no args).
|
| 9 |
+
# For default env construction (dataset from package data dir), passing the
|
| 10 |
+
# class directly works. For custom config use: lambda: DeceitEnvironment(dataset_path=...)
|
| 11 |
app = create_fastapi_app(
|
| 12 |
env=DeceitEnvironment,
|
| 13 |
action_cls=DeceitAction,
|
|
@@ -1,4 +1,14 @@
|
|
| 1 |
-
"""Level 1 Deceit environment — factual QA,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
@@ -18,6 +28,9 @@ _DEFAULT_DATASET = (
|
|
| 18 |
pathlib.Path(__file__).parent.parent / "data" / "level1.jsonl"
|
| 19 |
)
|
| 20 |
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def compute_reward(
|
| 23 |
correct: bool,
|
|
@@ -39,10 +52,11 @@ def compute_reward(
|
|
| 39 |
|
| 40 |
|
| 41 |
class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState]):
|
| 42 |
-
"""OpenEnv-compliant
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
"""
|
| 47 |
|
| 48 |
def __init__(
|
|
@@ -58,6 +72,7 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
|
|
| 58 |
)
|
| 59 |
self._rng = random.Random(seed)
|
| 60 |
self._state: DeceitState = DeceitState()
|
|
|
|
| 61 |
|
| 62 |
# ------------------------------------------------------------------
|
| 63 |
# OpenEnv interface
|
|
@@ -74,6 +89,7 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
|
|
| 74 |
self._rng = random.Random(seed)
|
| 75 |
|
| 76 |
question_row = self._rng.choice(self._dataset)
|
|
|
|
| 77 |
self._state = DeceitState(
|
| 78 |
episode_id=episode_id or str(uuid.uuid4()),
|
| 79 |
step_count=0,
|
|
@@ -81,12 +97,14 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
|
|
| 81 |
ground_truth=question_row["ground_truth"],
|
| 82 |
current_question_id=question_row["id"],
|
| 83 |
episode_rewards=[],
|
|
|
|
|
|
|
| 84 |
)
|
| 85 |
return DeceitObservation(
|
| 86 |
-
question=
|
| 87 |
context=[],
|
| 88 |
turn_index=0,
|
| 89 |
-
max_turns=
|
| 90 |
level=1,
|
| 91 |
)
|
| 92 |
|
|
@@ -96,9 +114,35 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
|
|
| 96 |
timeout_s: Optional[float] = None,
|
| 97 |
**kwargs,
|
| 98 |
) -> DeceitObservation:
|
| 99 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
self._state.step_count += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
|
|
|
| 102 |
if action.abstain:
|
| 103 |
correctness_r, calibration_r = 0.0, 0.0
|
| 104 |
grader_method = "abstain"
|
|
@@ -113,15 +157,19 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
|
|
| 113 |
)
|
| 114 |
grader_method = result.method
|
| 115 |
|
|
|
|
| 116 |
total_reward = correctness_r + calibration_r
|
| 117 |
self._state.episode_rewards.append(total_reward)
|
| 118 |
|
| 119 |
return DeceitObservation(
|
| 120 |
-
question=
|
| 121 |
-
context=[
|
|
|
|
|
|
|
|
|
|
| 122 |
turn_index=self._state.step_count,
|
| 123 |
-
max_turns=
|
| 124 |
-
level=
|
| 125 |
done=True,
|
| 126 |
reward=total_reward,
|
| 127 |
metadata={
|
|
@@ -129,6 +177,8 @@ class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState
|
|
| 129 |
"calibration_reward": calibration_r,
|
| 130 |
"grader_method": grader_method,
|
| 131 |
"correct": correct,
|
|
|
|
|
|
|
| 132 |
},
|
| 133 |
)
|
| 134 |
|
|
|
|
| 1 |
+
"""Level 1 Deceit environment — factual QA, multi-turn, no adversary.
|
| 2 |
+
|
| 3 |
+
Episode structure (max_turns=3):
|
| 4 |
+
- Each step where is_final=False: agent pays a -0.05 step penalty and gets
|
| 5 |
+
their own reasoning appended to the next observation's context.
|
| 6 |
+
- When is_final=True OR step_count >= max_turns: episode ends, full reward
|
| 7 |
+
(correctness + calibration) is returned.
|
| 8 |
+
|
| 9 |
+
This multi-turn design gives GRPO meaningful trajectory length and teaches the
|
| 10 |
+
model to "think more when uncertain" — the core Deceit behavior.
|
| 11 |
+
"""
|
| 12 |
|
| 13 |
from __future__ import annotations
|
| 14 |
|
|
|
|
| 28 |
pathlib.Path(__file__).parent.parent / "data" / "level1.jsonl"
|
| 29 |
)
|
| 30 |
|
| 31 |
+
STEP_PENALTY = -0.05
|
| 32 |
+
MAX_TURNS = 3
|
| 33 |
+
|
| 34 |
|
| 35 |
def compute_reward(
|
| 36 |
correct: bool,
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
class DeceitEnvironment(Environment[DeceitAction, DeceitObservation, DeceitState]):
|
| 55 |
+
"""OpenEnv-compliant multi-turn environment for the Deceit project.
|
| 56 |
|
| 57 |
+
Level 1: factual QA with no distractors or adversary.
|
| 58 |
+
Up to max_turns=3 steps per episode. Each non-final step costs a small
|
| 59 |
+
step penalty and feeds the agent's reasoning back as context.
|
| 60 |
"""
|
| 61 |
|
| 62 |
def __init__(
|
|
|
|
| 72 |
)
|
| 73 |
self._rng = random.Random(seed)
|
| 74 |
self._state: DeceitState = DeceitState()
|
| 75 |
+
self._current_question: str = ""
|
| 76 |
|
| 77 |
# ------------------------------------------------------------------
|
| 78 |
# OpenEnv interface
|
|
|
|
| 89 |
self._rng = random.Random(seed)
|
| 90 |
|
| 91 |
question_row = self._rng.choice(self._dataset)
|
| 92 |
+
self._current_question = question_row["question"]
|
| 93 |
self._state = DeceitState(
|
| 94 |
episode_id=episode_id or str(uuid.uuid4()),
|
| 95 |
step_count=0,
|
|
|
|
| 97 |
ground_truth=question_row["ground_truth"],
|
| 98 |
current_question_id=question_row["id"],
|
| 99 |
episode_rewards=[],
|
| 100 |
+
prior_reasoning=[],
|
| 101 |
+
max_turns=MAX_TURNS,
|
| 102 |
)
|
| 103 |
return DeceitObservation(
|
| 104 |
+
question=self._current_question,
|
| 105 |
context=[],
|
| 106 |
turn_index=0,
|
| 107 |
+
max_turns=MAX_TURNS,
|
| 108 |
level=1,
|
| 109 |
)
|
| 110 |
|
|
|
|
| 114 |
timeout_s: Optional[float] = None,
|
| 115 |
**kwargs,
|
| 116 |
) -> DeceitObservation:
|
| 117 |
+
"""Process one agent turn.
|
| 118 |
+
|
| 119 |
+
Non-final step: pay step penalty, append reasoning to context, continue.
|
| 120 |
+
Final step (is_final=True or turn limit reached): compute full reward.
|
| 121 |
+
"""
|
| 122 |
self._state.step_count += 1
|
| 123 |
+
forced_final = self._state.step_count >= self._state.max_turns
|
| 124 |
+
is_terminal = action.is_final or forced_final
|
| 125 |
+
|
| 126 |
+
if not is_terminal:
|
| 127 |
+
# Thinking turn: no grading, just step penalty
|
| 128 |
+
self._state.prior_reasoning.append(action.reasoning)
|
| 129 |
+
self._state.episode_rewards.append(STEP_PENALTY)
|
| 130 |
+
context = [
|
| 131 |
+
f"Your previous reasoning (turn {i + 1}): {r}"
|
| 132 |
+
for i, r in enumerate(self._state.prior_reasoning)
|
| 133 |
+
]
|
| 134 |
+
return DeceitObservation(
|
| 135 |
+
question=self._current_question,
|
| 136 |
+
context=context,
|
| 137 |
+
turn_index=self._state.step_count,
|
| 138 |
+
max_turns=self._state.max_turns,
|
| 139 |
+
level=self._state.level,
|
| 140 |
+
done=False,
|
| 141 |
+
reward=STEP_PENALTY,
|
| 142 |
+
metadata={"step_penalty": STEP_PENALTY, "is_final": False},
|
| 143 |
+
)
|
| 144 |
|
| 145 |
+
# Terminal turn: grade and compute full reward
|
| 146 |
if action.abstain:
|
| 147 |
correctness_r, calibration_r = 0.0, 0.0
|
| 148 |
grader_method = "abstain"
|
|
|
|
| 157 |
)
|
| 158 |
grader_method = result.method
|
| 159 |
|
| 160 |
+
# Step penalties for non-final turns were already returned on those turns; the terminal reward is correctness + calibration only
|
| 161 |
total_reward = correctness_r + calibration_r
|
| 162 |
self._state.episode_rewards.append(total_reward)
|
| 163 |
|
| 164 |
return DeceitObservation(
|
| 165 |
+
question=self._current_question,
|
| 166 |
+
context=[
|
| 167 |
+
f"Your previous reasoning (turn {i + 1}): {r}"
|
| 168 |
+
for i, r in enumerate(self._state.prior_reasoning)
|
| 169 |
+
],
|
| 170 |
turn_index=self._state.step_count,
|
| 171 |
+
max_turns=self._state.max_turns,
|
| 172 |
+
level=self._state.level,
|
| 173 |
done=True,
|
| 174 |
reward=total_reward,
|
| 175 |
metadata={
|
|
|
|
| 177 |
"calibration_reward": calibration_r,
|
| 178 |
"grader_method": grader_method,
|
| 179 |
"correct": correct,
|
| 180 |
+
"is_final": True,
|
| 181 |
+
"forced_final": forced_final,
|
| 182 |
},
|
| 183 |
)
|
| 184 |
|
|
@@ -13,12 +13,19 @@ import re
|
|
| 13 |
import pathlib
|
| 14 |
from dataclasses import dataclass
|
| 15 |
|
|
|
|
|
|
|
| 16 |
try:
|
| 17 |
from openai import OpenAI
|
| 18 |
except ImportError:
|
| 19 |
OpenAI = None # type: ignore[assignment,misc]
|
| 20 |
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
@dataclass
|
|
@@ -40,10 +47,10 @@ class Grader:
|
|
| 40 |
|
| 41 |
def __init__(
|
| 42 |
self,
|
| 43 |
-
cache_path: str | pathlib.Path =
|
| 44 |
openai_api_key: str | None = None,
|
| 45 |
) -> None:
|
| 46 |
-
self._cache_path = pathlib.Path(cache_path)
|
| 47 |
self._openai_api_key = openai_api_key
|
| 48 |
self._cache: dict[str, bool] = {}
|
| 49 |
if self._cache_path.exists():
|
|
|
|
| 13 |
import pathlib
|
| 14 |
from dataclasses import dataclass
|
| 15 |
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
try:
|
| 19 |
from openai import OpenAI
|
| 20 |
except ImportError:
|
| 21 |
OpenAI = None # type: ignore[assignment,misc]
|
| 22 |
|
| 23 |
+
def _default_cache_path() -> pathlib.Path:
    """Resolve the grader cache location used when no explicit path is given.

    A non-empty ``DECEIT_GRADER_CACHE`` environment variable takes precedence;
    otherwise a fixed file under ``/tmp`` is returned.
    """
    override = os.environ.get("DECEIT_GRADER_CACHE")
    # An unset or empty variable both fall through to the /tmp default.
    location = override if override else "/tmp/deceit_grader_cache.json"
    return pathlib.Path(location)
|
| 29 |
|
| 30 |
|
| 31 |
@dataclass
|
|
|
|
| 47 |
|
| 48 |
def __init__(
|
| 49 |
self,
|
| 50 |
+
cache_path: str | pathlib.Path | None = None,
|
| 51 |
openai_api_key: str | None = None,
|
| 52 |
) -> None:
|
| 53 |
+
self._cache_path = pathlib.Path(cache_path) if cache_path is not None else _default_cache_path()
|
| 54 |
self._openai_api_key = openai_api_key
|
| 55 |
self._cache: dict[str, bool] = {}
|
| 56 |
if self._cache_path.exists():
|
|
@@ -5,7 +5,7 @@ import pytest
|
|
| 5 |
from unittest.mock import MagicMock
|
| 6 |
|
| 7 |
from deceit_env.models import DeceitAction, DeceitObservation, DeceitState
|
| 8 |
-
from deceit_env.server.environment import DeceitEnvironment
|
| 9 |
from deceit_env.server.grader import GraderResult
|
| 10 |
|
| 11 |
|
|
@@ -50,6 +50,10 @@ class TestReset:
|
|
| 50 |
obs = env_correct.reset()
|
| 51 |
assert obs.level == 1
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
def test_reset_initializes_state(self, env_correct):
|
| 54 |
env_correct.reset()
|
| 55 |
state = env_correct.state
|
|
@@ -57,48 +61,49 @@ class TestReset:
|
|
| 57 |
assert state.episode_id is not None
|
| 58 |
assert state.step_count == 0
|
| 59 |
assert state.ground_truth != ""
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
-
class
|
| 63 |
-
def
|
| 64 |
env_correct.reset(seed=42)
|
| 65 |
-
action = DeceitAction(reasoning="
|
| 66 |
obs = env_correct.step(action)
|
| 67 |
-
assert obs.reward
|
| 68 |
|
| 69 |
-
def
|
| 70 |
env_correct.reset(seed=42)
|
| 71 |
-
action = DeceitAction(reasoning="
|
| 72 |
obs = env_correct.step(action)
|
| 73 |
-
assert obs.reward
|
| 74 |
|
| 75 |
-
def
|
| 76 |
env_wrong.reset(seed=42)
|
| 77 |
-
action = DeceitAction(reasoning="
|
| 78 |
obs = env_wrong.step(action)
|
| 79 |
-
assert obs.reward
|
| 80 |
|
| 81 |
-
def
|
| 82 |
env_wrong.reset(seed=42)
|
| 83 |
-
action = DeceitAction(reasoning="
|
| 84 |
obs = env_wrong.step(action)
|
| 85 |
-
assert obs.reward
|
| 86 |
|
| 87 |
def test_step_abstain_reward_is_zero(self, env_correct):
|
| 88 |
env_correct.reset(seed=42)
|
| 89 |
-
action = DeceitAction(reasoning="Not sure.", answer="", abstain=True, confidence=0.3)
|
| 90 |
obs = env_correct.step(action)
|
| 91 |
assert obs.reward == pytest.approx(0.0)
|
| 92 |
|
| 93 |
-
def
|
| 94 |
env_correct.reset(seed=42)
|
| 95 |
-
action = DeceitAction(reasoning="r", answer="Canberra", confidence=0.8)
|
| 96 |
obs = env_correct.step(action)
|
| 97 |
assert obs.done is True
|
| 98 |
|
| 99 |
def test_step_metadata_contains_grader_info(self, env_correct):
|
| 100 |
env_correct.reset(seed=42)
|
| 101 |
-
action = DeceitAction(reasoning="r", answer="Canberra", confidence=0.9)
|
| 102 |
obs = env_correct.step(action)
|
| 103 |
assert "grader_method" in obs.metadata
|
| 104 |
assert "correct" in obs.metadata
|
|
@@ -107,18 +112,91 @@ class TestStep:
|
|
| 107 |
|
| 108 |
def test_state_updated_after_step(self, env_correct):
|
| 109 |
env_correct.reset(seed=42)
|
| 110 |
-
action = DeceitAction(reasoning="r", answer="Canberra", confidence=0.9)
|
| 111 |
env_correct.step(action)
|
| 112 |
assert env_correct.state.step_count == 1
|
| 113 |
assert len(env_correct.state.episode_rewards) == 1
|
| 114 |
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
class TestMultipleEpisodes:
|
| 117 |
def test_reset_step_reset_step_sequence(self, env_correct):
|
| 118 |
for _ in range(3):
|
| 119 |
obs = env_correct.reset()
|
| 120 |
assert isinstance(obs, DeceitObservation)
|
| 121 |
-
action = DeceitAction(reasoning="r", answer="x", confidence=0.8)
|
| 122 |
result = env_correct.step(action)
|
| 123 |
assert result.done is True
|
| 124 |
assert env_correct.state.step_count == 1
|
|
@@ -126,10 +204,11 @@ class TestMultipleEpisodes:
|
|
| 126 |
def test_state_resets_between_episodes(self, env_correct):
|
| 127 |
env_correct.reset(seed=1)
|
| 128 |
first_id = env_correct.state.episode_id
|
| 129 |
-
env_correct.step(DeceitAction(reasoning="r", answer="x", confidence=0.8))
|
| 130 |
|
| 131 |
env_correct.reset(seed=2)
|
| 132 |
second_id = env_correct.state.episode_id
|
| 133 |
assert first_id != second_id
|
| 134 |
assert env_correct.state.step_count == 0
|
| 135 |
assert env_correct.state.episode_rewards == []
|
|
|
|
|
|
| 5 |
from unittest.mock import MagicMock
|
| 6 |
|
| 7 |
from deceit_env.models import DeceitAction, DeceitObservation, DeceitState
|
| 8 |
+
from deceit_env.server.environment import DeceitEnvironment, STEP_PENALTY
|
| 9 |
from deceit_env.server.grader import GraderResult
|
| 10 |
|
| 11 |
|
|
|
|
| 50 |
obs = env_correct.reset()
|
| 51 |
assert obs.level == 1
|
| 52 |
|
| 53 |
+
def test_reset_observation_max_turns_3(self, env_correct):
    """The observation returned by reset() reports a turn budget of 3."""
    first_observation = env_correct.reset()
    turn_budget = first_observation.max_turns
    assert turn_budget == 3
|
| 56 |
+
|
| 57 |
def test_reset_initializes_state(self, env_correct):
|
| 58 |
env_correct.reset()
|
| 59 |
state = env_correct.state
|
|
|
|
| 61 |
assert state.episode_id is not None
|
| 62 |
assert state.step_count == 0
|
| 63 |
assert state.ground_truth != ""
|
| 64 |
+
assert state.prior_reasoning == []
|
| 65 |
|
| 66 |
|
| 67 |
+
class TestSingleStepFinal:
|
| 68 |
+
def test_step_correct_confident_reward_is_1_3(self, env_correct):
    """A final answer at confidence 0.9 on the env_correct fixture earns 1.3."""
    env_correct.reset(seed=42)
    final_obs = env_correct.step(
        DeceitAction(reasoning="Sure.", answer="Canberra", confidence=0.9, is_final=True)
    )
    assert final_obs.reward == pytest.approx(1.3)
|
| 73 |
|
| 74 |
+
def test_step_correct_returns_positive_reward(self, env_correct):
    """A final answer on the env_correct fixture yields a strictly positive reward."""
    env_correct.reset(seed=42)
    committed = DeceitAction(
        reasoning="I know this.",
        answer="Canberra",
        confidence=0.9,
        is_final=True,
    )
    reward = env_correct.step(committed).reward
    assert reward > 0
|
| 79 |
|
| 80 |
+
def test_step_wrong_confident_reward_is_minus_1_3(self, env_wrong):
    """A final answer at confidence 0.9 on the env_wrong fixture earns -1.3."""
    env_wrong.reset(seed=42)
    final_obs = env_wrong.step(
        DeceitAction(reasoning="Wrong.", answer="Sydney", confidence=0.9, is_final=True)
    )
    assert final_obs.reward == pytest.approx(-1.3)
|
| 85 |
|
| 86 |
+
def test_step_wrong_returns_negative_reward(self, env_wrong):
    """A final answer on the env_wrong fixture yields a strictly negative reward."""
    env_wrong.reset(seed=42)
    committed = DeceitAction(
        reasoning="Guessing.",
        answer="Sydney",
        confidence=0.9,
        is_final=True,
    )
    reward = env_wrong.step(committed).reward
    assert reward < 0
|
| 91 |
|
| 92 |
def test_step_abstain_reward_is_zero(self, env_correct):
    """A final abstaining action is rewarded with exactly zero."""
    env_correct.reset(seed=42)
    abstention = DeceitAction(
        reasoning="Not sure.",
        answer="",
        abstain=True,
        confidence=0.3,
        is_final=True,
    )
    outcome = env_correct.step(abstention)
    assert outcome.reward == pytest.approx(0.0)
|
| 97 |
|
| 98 |
+
def test_step_final_sets_done_true(self, env_correct):
    """Submitting an action with is_final=True ends the episode."""
    env_correct.reset(seed=42)
    outcome = env_correct.step(
        DeceitAction(reasoning="r", answer="Canberra", confidence=0.8, is_final=True)
    )
    assert outcome.done is True
|
| 103 |
|
| 104 |
def test_step_metadata_contains_grader_info(self, env_correct):
|
| 105 |
env_correct.reset(seed=42)
|
| 106 |
+
action = DeceitAction(reasoning="r", answer="Canberra", confidence=0.9, is_final=True)
|
| 107 |
obs = env_correct.step(action)
|
| 108 |
assert "grader_method" in obs.metadata
|
| 109 |
assert "correct" in obs.metadata
|
|
|
|
| 112 |
|
| 113 |
def test_state_updated_after_step(self, env_correct):
    """One final step bumps step_count to 1 and records exactly one reward."""
    env_correct.reset(seed=42)
    env_correct.step(
        DeceitAction(reasoning="r", answer="Canberra", confidence=0.9, is_final=True)
    )
    steps_taken = env_correct.state.step_count
    rewards_logged = len(env_correct.state.episode_rewards)
    assert steps_taken == 1
    assert rewards_logged == 1
|
| 119 |
|
| 120 |
|
| 121 |
+
class TestMultiTurnEpisodes:
    """Multi-turn episodes: thinking turns, forced termination, reward bookkeeping.

    All tests rely on the ``env_correct`` fixture (defined elsewhere in this
    module) and reset it with ``seed=42`` before stepping.
    """

    def test_non_final_step_returns_done_false(self, env_correct):
        """A step with is_final=False leaves the episode open."""
        env_correct.reset(seed=42)
        action = DeceitAction(reasoning="thinking...", is_final=False)
        obs = env_correct.step(action)
        assert obs.done is False

    def test_non_final_step_returns_step_penalty(self, env_correct):
        """A step with is_final=False is rewarded with exactly STEP_PENALTY."""
        env_correct.reset(seed=42)
        action = DeceitAction(reasoning="thinking...", is_final=False)
        obs = env_correct.step(action)
        assert obs.reward == pytest.approx(STEP_PENALTY)

    def test_non_final_step_appends_reasoning_to_context(self, env_correct):
        """The reasoning of a thinking turn is echoed in the next observation's context."""
        env_correct.reset(seed=42)
        action = DeceitAction(reasoning="First I considered Sydney.", is_final=False)
        obs = env_correct.step(action)
        assert any("First I considered Sydney." in c for c in obs.context)

    def test_multi_turn_full_trajectory(self, env_correct):
        """reset -> think -> think -> commit: penalties on thinking turns, 1.3 at the end."""
        env_correct.reset(seed=42)

        obs1 = env_correct.step(DeceitAction(reasoning="Turn 1 thinking", is_final=False))
        assert obs1.done is False
        assert obs1.reward == pytest.approx(STEP_PENALTY)

        obs2 = env_correct.step(DeceitAction(reasoning="Turn 2 thinking", is_final=False))
        assert obs2.done is False
        assert obs2.reward == pytest.approx(STEP_PENALTY)

        obs3 = env_correct.step(
            DeceitAction(reasoning="Committed.", answer="Canberra", confidence=0.9, is_final=True)
        )
        assert obs3.done is True
        # The terminal reward is not reduced by the two earlier step penalties.
        assert obs3.reward == pytest.approx(1.3)
        assert env_correct.state.step_count == 3

    def test_forced_termination_at_max_turns(self, env_correct):
        """Three non-final steps: the third is forced terminal regardless of is_final."""
        env_correct.reset(seed=42)
        # Previously this call bound the action to a misspelled, unused walrus
        # target (``DecaitAction :=``); the action is now passed directly.
        env_correct.step(DeceitAction(reasoning="t1", is_final=False))
        env_correct.step(DeceitAction(reasoning="t2", is_final=False))
        # 3rd step hits max_turns, forced terminal
        obs = env_correct.step(
            DeceitAction(reasoning="t3", answer="Canberra", confidence=0.8, is_final=False)
        )
        assert obs.done is True
        assert obs.metadata.get("forced_final") is True

    def test_prior_reasoning_in_context_grows_each_turn(self, env_correct):
        """After two thinking turns the observation context holds two entries."""
        env_correct.reset(seed=42)
        env_correct.step(DeceitAction(reasoning="step1", is_final=False))
        obs = env_correct.step(DeceitAction(reasoning="step2", is_final=False))
        assert len(obs.context) == 2

    def test_state_prior_reasoning_accumulates(self, env_correct):
        """State keeps the raw reasoning strings of thinking turns, in order."""
        env_correct.reset(seed=42)
        env_correct.step(DeceitAction(reasoning="thinking A", is_final=False))
        env_correct.step(DeceitAction(reasoning="thinking B", is_final=False))
        assert env_correct.state.prior_reasoning == ["thinking A", "thinking B"]

    def test_episode_rewards_include_step_penalties(self, env_correct):
        """episode_rewards logs the step penalty and the terminal reward as separate entries."""
        env_correct.reset(seed=42)
        env_correct.step(DeceitAction(reasoning="t1", is_final=False))
        env_correct.step(
            DeceitAction(reasoning="commit", answer="Canberra", confidence=0.9, is_final=True)
        )
        rewards = env_correct.state.episode_rewards
        assert rewards[0] == pytest.approx(STEP_PENALTY)
        assert rewards[1] == pytest.approx(1.3)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
class TestMultipleEpisodes:
|
| 195 |
def test_reset_step_reset_step_sequence(self, env_correct):
|
| 196 |
for _ in range(3):
|
| 197 |
obs = env_correct.reset()
|
| 198 |
assert isinstance(obs, DeceitObservation)
|
| 199 |
+
action = DeceitAction(reasoning="r", answer="x", confidence=0.8, is_final=True)
|
| 200 |
result = env_correct.step(action)
|
| 201 |
assert result.done is True
|
| 202 |
assert env_correct.state.step_count == 1
|
|
|
|
| 204 |
def test_state_resets_between_episodes(self, env_correct):
|
| 205 |
env_correct.reset(seed=1)
|
| 206 |
first_id = env_correct.state.episode_id
|
| 207 |
+
env_correct.step(DeceitAction(reasoning="r", answer="x", confidence=0.8, is_final=True))
|
| 208 |
|
| 209 |
env_correct.reset(seed=2)
|
| 210 |
second_id = env_correct.state.episode_id
|
| 211 |
assert first_id != second_id
|
| 212 |
assert env_correct.state.step_count == 0
|
| 213 |
assert env_correct.state.episode_rewards == []
|
| 214 |
+
assert env_correct.state.prior_reasoning == []
|
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import json
|
| 2 |
import pytest
|
| 3 |
from pydantic import ValidationError
|
| 4 |
|
|
@@ -30,8 +29,24 @@ class TestDeceitObservation:
|
|
| 30 |
assert obs.max_turns == 5
|
| 31 |
assert obs.level == 2
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def test_json_roundtrip(self):
|
| 34 |
-
obs = DeceitObservation(question="Q", context=["ctx"], level=2)
|
| 35 |
data = obs.model_dump_json()
|
| 36 |
restored = DeceitObservation.model_validate_json(data)
|
| 37 |
assert restored == obs
|
|
@@ -47,6 +62,11 @@ class TestDeceitAction:
|
|
| 47 |
assert action.answer == ""
|
| 48 |
assert action.confidence == 0.5
|
| 49 |
assert action.abstain is False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def test_with_all_fields(self):
|
| 52 |
action = DeceitAction(
|
|
@@ -54,9 +74,11 @@ class TestDeceitAction:
|
|
| 54 |
answer="Canberra",
|
| 55 |
confidence=0.9,
|
| 56 |
abstain=False,
|
|
|
|
| 57 |
)
|
| 58 |
assert action.answer == "Canberra"
|
| 59 |
assert action.confidence == 0.9
|
|
|
|
| 60 |
|
| 61 |
def test_confidence_upper_bound_rejected(self):
|
| 62 |
with pytest.raises(ValidationError):
|
|
@@ -76,8 +98,12 @@ class TestDeceitAction:
|
|
| 76 |
action = DeceitAction(reasoning="unsure", abstain=True)
|
| 77 |
assert action.abstain is True
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
def test_json_roundtrip(self):
|
| 80 |
-
action = DeceitAction(reasoning="r", answer="Canberra", confidence=0.9)
|
| 81 |
data = action.model_dump_json()
|
| 82 |
restored = DeceitAction.model_validate_json(data)
|
| 83 |
assert restored == action
|
|
@@ -92,6 +118,8 @@ class TestDeceitState:
|
|
| 92 |
assert state.ground_truth == ""
|
| 93 |
assert state.current_question_id == ""
|
| 94 |
assert state.episode_rewards == []
|
|
|
|
|
|
|
| 95 |
|
| 96 |
def test_with_all_fields(self):
|
| 97 |
state = DeceitState(
|
|
@@ -101,10 +129,19 @@ class TestDeceitState:
|
|
| 101 |
ground_truth="Canberra",
|
| 102 |
current_question_id="q_042",
|
| 103 |
episode_rewards=[1.3, -1.1],
|
|
|
|
|
|
|
| 104 |
)
|
| 105 |
assert state.episode_id == "abc-123"
|
| 106 |
assert state.ground_truth == "Canberra"
|
| 107 |
assert state.episode_rewards == [1.3, -1.1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
def test_mutable_state_can_be_updated(self):
|
| 110 |
state = DeceitState()
|
|
@@ -118,6 +155,7 @@ class TestDeceitState:
|
|
| 118 |
episode_id="abc-123",
|
| 119 |
ground_truth="Canberra",
|
| 120 |
episode_rewards=[1.3, 0.0],
|
|
|
|
| 121 |
)
|
| 122 |
data = state.model_dump_json()
|
| 123 |
restored = DeceitState.model_validate_json(data)
|
|
|
|
|
|
|
| 1 |
import pytest
|
| 2 |
from pydantic import ValidationError
|
| 3 |
|
|
|
|
| 29 |
assert obs.max_turns == 5
|
| 30 |
assert obs.level == 2
|
| 31 |
|
| 32 |
+
def test_openenv_inherited_done_field(self):
    """DeceitObservation accepts ``done`` and exposes it unchanged."""
    assert DeceitObservation(question="Q", done=True).done is True
|
| 35 |
+
|
| 36 |
+
def test_openenv_inherited_reward_field(self):
    """DeceitObservation accepts ``reward`` and exposes it unchanged."""
    stored_reward = DeceitObservation(question="Q", reward=1.3).reward
    assert stored_reward == pytest.approx(1.3)
|
| 39 |
+
|
| 40 |
+
def test_openenv_inherited_metadata_field(self):
    """DeceitObservation accepts a ``metadata`` mapping and keeps its entries."""
    observation = DeceitObservation(question="Q", metadata={"key": "val"})
    stored_value = observation.metadata["key"]
    assert stored_value == "val"
|
| 43 |
+
|
| 44 |
+
def test_extra_fields_rejected(self):
    """Constructing a DeceitObservation with an undeclared field fails validation."""
    unknown_kwargs = {"nonexistent_field": "boom"}
    with pytest.raises(ValidationError):
        DeceitObservation(question="Q", **unknown_kwargs)
|
| 47 |
+
|
| 48 |
def test_json_roundtrip(self):
|
| 49 |
+
obs = DeceitObservation(question="Q", context=["ctx"], level=2, done=True, reward=0.5)
|
| 50 |
data = obs.model_dump_json()
|
| 51 |
restored = DeceitObservation.model_validate_json(data)
|
| 52 |
assert restored == obs
|
|
|
|
| 62 |
assert action.answer == ""
|
| 63 |
assert action.confidence == 0.5
|
| 64 |
assert action.abstain is False
|
| 65 |
+
assert action.is_final is False
|
| 66 |
+
|
| 67 |
+
def test_is_final_field(self):
    """An explicit is_final=True is stored on the action."""
    committed = DeceitAction(
        reasoning="committing now",
        answer="Canberra",
        is_final=True,
    )
    assert committed.is_final is True
|
| 70 |
|
| 71 |
def test_with_all_fields(self):
|
| 72 |
action = DeceitAction(
|
|
|
|
| 74 |
answer="Canberra",
|
| 75 |
confidence=0.9,
|
| 76 |
abstain=False,
|
| 77 |
+
is_final=True,
|
| 78 |
)
|
| 79 |
assert action.answer == "Canberra"
|
| 80 |
assert action.confidence == 0.9
|
| 81 |
+
assert action.is_final is True
|
| 82 |
|
| 83 |
def test_confidence_upper_bound_rejected(self):
|
| 84 |
with pytest.raises(ValidationError):
|
|
|
|
| 98 |
action = DeceitAction(reasoning="unsure", abstain=True)
|
| 99 |
assert action.abstain is True
|
| 100 |
|
| 101 |
+
def test_extra_fields_rejected(self):
    """Constructing a DeceitAction with an undeclared field fails validation."""
    unknown_kwargs = {"ghost_field": True}
    with pytest.raises(ValidationError):
        DeceitAction(reasoning="r", **unknown_kwargs)
|
| 104 |
+
|
| 105 |
def test_json_roundtrip(self):
|
| 106 |
+
action = DeceitAction(reasoning="r", answer="Canberra", confidence=0.9, is_final=True)
|
| 107 |
data = action.model_dump_json()
|
| 108 |
restored = DeceitAction.model_validate_json(data)
|
| 109 |
assert restored == action
|
|
|
|
| 118 |
assert state.ground_truth == ""
|
| 119 |
assert state.current_question_id == ""
|
| 120 |
assert state.episode_rewards == []
|
| 121 |
+
assert state.prior_reasoning == []
|
| 122 |
+
assert state.max_turns == 3
|
| 123 |
|
| 124 |
def test_with_all_fields(self):
|
| 125 |
state = DeceitState(
|
|
|
|
| 129 |
ground_truth="Canberra",
|
| 130 |
current_question_id="q_042",
|
| 131 |
episode_rewards=[1.3, -1.1],
|
| 132 |
+
prior_reasoning=["First I thought Sydney...", "Then reconsidered."],
|
| 133 |
+
max_turns=3,
|
| 134 |
)
|
| 135 |
assert state.episode_id == "abc-123"
|
| 136 |
assert state.ground_truth == "Canberra"
|
| 137 |
assert state.episode_rewards == [1.3, -1.1]
|
| 138 |
+
assert len(state.prior_reasoning) == 2
|
| 139 |
+
|
| 140 |
+
def test_prior_reasoning_accumulates(self):
    """Appending to a fresh state's prior_reasoning list grows it entry by entry."""
    fresh_state = DeceitState()
    for note in ("step 1 thinking", "step 2 thinking"):
        fresh_state.prior_reasoning.append(note)
    assert len(fresh_state.prior_reasoning) == 2
|
| 145 |
|
| 146 |
def test_mutable_state_can_be_updated(self):
|
| 147 |
state = DeceitState()
|
|
|
|
| 155 |
episode_id="abc-123",
|
| 156 |
ground_truth="Canberra",
|
| 157 |
episode_rewards=[1.3, 0.0],
|
| 158 |
+
prior_reasoning=["I think it is Canberra"],
|
| 159 |
)
|
| 160 |
data = state.model_dump_json()
|
| 161 |
restored = DeceitState.model_validate_json(data)
|