import logging
from uuid import uuid4
from collections import deque
from typing import Dict, Any, List
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from .models import AutomathreasonerAction, AutomathreasonerObservation
from .generator import TaskGenerationEngine
from .verifier import VerifierSystem
from .rewards import RewardSystem
except ImportError:
from env.models import AutomathreasonerAction, AutomathreasonerObservation
from env.generator import TaskGenerationEngine
from env.verifier import VerifierSystem
from env.rewards import RewardSystem
logger = logging.getLogger(__name__)
class AutomathreasonerEnvironment(Environment):
    """Math-reasoning environment with a simple adaptive curriculum.

    Each episode presents one generated problem. The agent submits reasoning
    plus a final answer, which the verifier scores; the reward system combines
    correctness with process/reflection signals. A rolling window of episode
    outcomes (1 = solved, 0 = failed) drives the difficulty of subsequently
    generated problems up or down.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)

        # Collaborating subsystems.
        self.generator = TaskGenerationEngine()
        self.verifier = VerifierSystem()
        self.reward_system = RewardSystem(max_len=2000)

        # Curriculum tracking: start at difficulty 2.0 and remember the
        # outcomes of up to the last 20 finished episodes.
        self.difficulty_level = 2.0
        self.rolling_results = deque(maxlen=20)

        # Per-episode problem state.
        self.current_problem = ""
        self.current_solution = ""
        self.current_sympy_f = None  # ground-truth sympy expression for integration checks
        self.times_seen_problem = 0
        self.history: List[Dict[str, Any]] = []
        self.max_steps = 3

    def _update_curriculum(self):
        """Adjust difficulty from the rolling accuracy of recent episodes."""
        if len(self.rolling_results) < 5:
            return  # not enough signal to adapt yet
        accuracy = sum(self.rolling_results) / len(self.rolling_results)
        if accuracy > 0.7:
            self.difficulty_level += 0.5
        elif accuracy < 0.6:
            self.difficulty_level = max(1.0, self.difficulty_level - 0.5)
        logger.info(f"Curriculum Updated: Accuracy={accuracy:.2f}, New Difficulty={self.difficulty_level}")

    def reset(self) -> AutomathreasonerObservation:
        """Start a fresh episode: adapt the curriculum, then generate a new problem."""
        self._update_curriculum()
        self._state = State(episode_id=str(uuid4()), step_count=0)

        task = self.generator.generate_task(target_difficulty_band=self.difficulty_level)
        self.current_problem, self.current_solution = task['problem'], task['solution']
        self.current_sympy_f = task.get('sympy_f')
        # The generator also reports its own continuous difficulty score; the
        # observation exposes the target difficulty band instead.
        self.times_seen_problem = 0
        self.history = []

        return AutomathreasonerObservation(
            problem_text=self.current_problem,
            difficulty_level=self.difficulty_level,
            history=[],
            reward=0.0,
            done=False,
        )

    def step(self, action: AutomathreasonerAction) -> AutomathreasonerObservation:  # type: ignore[override]
        """Score one attempt at the current problem and return the next observation."""
        self._state.step_count += 1

        # Verification of reasoning and final answer against the ground truth.
        correctness, quality, process_sup, reflection = self.verifier.verify(
            action.reasoning,
            action.final_answer,
            self.current_solution,
            sympy_f=self.current_sympy_f,
        )

        # Reward is computed BEFORE the attempt is appended to history and
        # before the seen-counter is bumped — it scores the attempt against
        # the state that preceded it.
        combined_text = f"{action.reasoning} \n {action.final_answer}"
        total_reward, components = self.reward_system.compute_reward(
            correctness=correctness,
            reasoning_quality=quality,
            process_supervision=process_sup,
            reflection_score=reflection,
            action_str=combined_text,
            final_answer=action.final_answer,
            history=self.history,
            times_seen_problem=self.times_seen_problem,
        )
        self.times_seen_problem += 1

        self.history.append({"prediction": action.final_answer, "correctness": correctness})

        solved = correctness == 1.0
        finished = solved or self._state.step_count >= self.max_steps
        if finished:
            # Episode outcome feeds the curriculum window.
            self.rolling_results.append(int(solved))

        return AutomathreasonerObservation(
            problem_text=self.current_problem,
            difficulty_level=self.difficulty_level,
            history=self.history[-3:],  # expose at most the last three attempts
            reward=total_reward,
            done=finished,
            metadata={
                "reward_components": components,
                # Ground truth stays hidden until the episode ends.
                "ground_truth": self.current_solution if finished else "HIDDEN",
                "is_correct": solved,
            },
        )

    @property
    def state(self) -> State:
        """Current episode bookkeeping (episode id and step counter)."""
        return self._state