"""Single-step tool-selection environment.

Each episode presents one query. The agent answers with one action
(``use_calculator``, ``use_search`` or ``answer_directly``); the environment
runs the matching (deliberately noisy) tool, scores the outcome and ends the
episode.
"""

import ast
import operator
import random
import uuid

from openenv.core.env_server import Environment

from tool_use_env.grader import compute_grade
from tool_use_env.models import ToolUseAction, ToolUseObservation, ToolUseState

# Operators the calculator tool is allowed to evaluate.
_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}
# Upper bound on |exponent| so ``**`` cannot be used to build huge numbers.
_MAX_EXPONENT = 100


def _eval_node(node):
    """Recursively evaluate one node of a parsed arithmetic expression.

    Raises:
        ValueError: if the node is anything other than a numeric literal or
            a whitelisted unary/binary arithmetic operation.
    """
    if isinstance(node, ast.Constant):
        value = node.value
        # bool is a subclass of int; reject it so only real numbers pass.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError("only numeric literals are allowed")
        return value

    if isinstance(node, ast.BinOp):
        func = _BINARY_OPS.get(type(node.op))
        if func is None:
            raise ValueError("unsupported binary operator")
        left = _eval_node(node.left)
        right = _eval_node(node.right)
        if func is operator.pow and abs(right) > _MAX_EXPONENT:
            raise ValueError("exponent too large")
        return func(left, right)

    if isinstance(node, ast.UnaryOp):
        func = _UNARY_OPS.get(type(node.op))
        if func is None:
            raise ValueError("unsupported unary operator")
        return func(_eval_node(node.operand))

    raise ValueError("unsupported expression")


def _eval_arithmetic(expression):
    """Evaluate a plain arithmetic expression string without ``eval``.

    Only numeric literals combined with ``+ - * / // % **`` and unary
    ``+``/``-`` are accepted; names, calls, attribute access and every other
    construct raise ``ValueError`` (or ``SyntaxError`` if unparsable).
    """
    tree = ast.parse(expression, mode="eval")
    return _eval_node(tree.body)


class ToolUseEnvironment(Environment):
    """Environment that rewards picking the right tool for a query.

    An episode is a single step: ``reset`` samples a task, ``step`` executes
    the chosen action and always returns ``done=True``.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Base tool-failure probability per difficulty; any difficulty not
    # listed falls back to the "hard" default.
    _CALC_FAIL_PROB = {"easy": 0.06, "medium": 0.12}
    _CALC_FAIL_DEFAULT = 0.18
    _CALC_FAIL_CAP = 0.25

    _SEARCH_FAIL_PROB = {"easy": 0.07, "medium": 0.15}
    _SEARCH_FAIL_DEFAULT = 0.22
    _SEARCH_FAIL_CAP = 0.30

    # Queries longer than this many characters get an extra failure chance.
    _LONG_QUERY_CHARS = 20
    _LONG_QUERY_PENALTY = 0.05

    def __init__(self):
        self._state = ToolUseState()
        self._tasks = self._load_tasks()
        # Per-instance RNG: concurrent sessions do not share (or reseed)
        # the module-global random state.
        self._rng = random.Random()

    def _load_tasks(self):
        """Return the built-in task list.

        Each task is a dict with ``query``, ``answer`` (string),
        ``correct_action`` and ``difficulty``.
        """
        return [
            {
                "query": "What is 5 + 7?",
                "answer": "12",
                "correct_action": "use_calculator",
                "difficulty": "easy",
            },
            {
                "query": "Capital of France?",
                "answer": "Paris",
                "correct_action": "use_search",
                "difficulty": "easy",
            },
            {
                "query": "What is 123 * 456?",
                "answer": "56088",
                "correct_action": "use_calculator",
                "difficulty": "hard",
            },
            {
                "query": "What is 25 * 4?",
                "answer": "100",
                "correct_action": "use_calculator",
                "difficulty": "medium",
            },
            {
                "query": "Who is the CEO of Tesla?",
                "answer": "Elon Musk",
                "correct_action": "use_search",
                "difficulty": "medium",
            },
        ]

    def reset(self, seed=None, episode_id=None, **kwargs) -> ToolUseObservation:
        """Start a new episode with a randomly chosen task.

        Args:
            seed: Optional seed. When given, this instance's RNG is reseeded
                so task selection and tool noise become reproducible.
            episode_id: Optional episode identifier; a fresh UUID4 string is
                generated when omitted or falsy.
            **kwargs: Accepted and ignored.

        Returns:
            The initial observation carrying the task query, with
            ``done=False`` and ``reward=None``.
        """
        if seed is not None:
            self._rng.seed(seed)

        task = self._rng.choice(self._tasks)

        self._state = ToolUseState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            current_query=task["query"],
            correct_action=task["correct_action"],
            correct_answer=task["answer"],
            difficulty=task["difficulty"],
        )

        return ToolUseObservation(
            done=False,
            reward=None,
            query=task["query"],
            tool_output=None,
            message="Choose an action",
        )

    def _failure_probability(self, query, table, default, cap):
        """Return the capped tool-failure probability for the current task."""
        fail_prob = table.get(self._state.difficulty, default)
        # Complexity-based failure: longer queries fail slightly more often.
        if len(query) > self._LONG_QUERY_CHARS:
            fail_prob += self._LONG_QUERY_PENALTY
        return min(fail_prob, cap)

    def _calculator(self, query):
        """Calculator tool with controlled noise.

        Strips the ``"what is"`` prefix and ``"?"`` from the lower-cased
        query and evaluates the remainder as arithmetic. With a
        difficulty-dependent probability the result is perturbed (note the
        perturbation may be zero, so a "failed" call can still be right).

        Returns:
            The (possibly noisy) result as a string, or ``"error"`` if the
            query cannot be evaluated.
        """
        try:
            expr = query.lower().replace("what is", "").replace("?", "").strip()
            # SECURITY: this used to be eval(expr). The query is now parsed
            # and only numeric arithmetic is evaluated, so arbitrary code in
            # a query can never be executed.
            correct = _eval_arithmetic(expr)

            fail_prob = self._failure_probability(
                query,
                self._CALC_FAIL_PROB,
                self._CALC_FAIL_DEFAULT,
                self._CALC_FAIL_CAP,
            )

            if self._rng.random() < fail_prob:
                # Scale the noise with the magnitude of the true result.
                if abs(correct) < 50:
                    noise = self._rng.randint(-2, 2)
                else:
                    noise = int(correct * self._rng.uniform(-0.05, 0.05))
                return str(correct + noise)

            return str(correct)
        except (
            SyntaxError,
            ValueError,
            TypeError,
            ArithmeticError,
            AttributeError,
            RecursionError,
        ):
            # Unparsable/unsupported expression, division by zero, overflow,
            # or a non-string query.
            return "error"

    def _search(self, query):
        """Search tool with controlled noise.

        Looks for a knowledge-base key contained (case-insensitively) in the
        query. With a difficulty-dependent probability a matching lookup
        returns a canned failure string instead of the stored answer.

        Returns:
            The stored answer, a failure string, or ``"not found"`` when no
            key matches (or the query is not a string).
        """
        kb = {
            "Capital of France": "Paris",
            "CEO of Tesla": "Elon Musk",
        }

        if not isinstance(query, str):
            # Mirrors the calculator's graceful degradation, e.g. when
            # step() is called before reset().
            return "not found"

        lowered = query.lower()
        match = next((key for key in kb if key.lower() in lowered), None)
        if match is None:
            return "not found"

        fail_prob = self._failure_probability(
            query,
            self._SEARCH_FAIL_PROB,
            self._SEARCH_FAIL_DEFAULT,
            self._SEARCH_FAIL_CAP,
        )

        if self._rng.random() < fail_prob:
            return self._rng.choice(["Unknown", "Not sure", "No results found"])

        return kb[match]

    def _compute_reward(self, action_type, correct_action, answer_correct, difficulty):
        """Combine the reward components into a value clamped to [0, 1]."""
        action_matches = action_type == correct_action

        # 1. Action correctness.
        action_score = 0.4 if action_matches else 0.1

        # 2. Answer correctness.
        answer_score = 0.5 if answer_correct else 0.0

        # 3. Tool cost (small penalty).
        if action_type == "use_calculator":
            tool_penalty = 0.05
        elif action_type == "use_search":
            tool_penalty = 0.08
        else:
            tool_penalty = 0.0

        # 4. Failure bonus: right tool chosen, but the tool itself failed.
        failure_bonus = 0.1 if (action_matches and not answer_correct) else 0.0

        # 5. Combine.
        reward = action_score + answer_score + failure_bonus - tool_penalty

        # 6. Light difficulty scaling.
        if difficulty == "medium":
            reward *= 1.02
        elif difficulty == "hard":
            reward *= 0.9

        # 7. Clamp so the reward always stays in [0, 1].
        return max(0.0, min(1.0, reward))

    def step(self, action: ToolUseAction, timeout_s=None, **kwargs) -> ToolUseObservation:
        """Execute the chosen action and finish the episode.

        Args:
            action: Action whose ``action_type`` selects the tool.
                ``"answer_directly"`` yields the output ``"unknown"``; any
                unrecognised type yields ``"invalid action"``.
            timeout_s: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            A terminal observation (``done=True``) with the tool output, the
            reward and a summary message that also includes the grade from
            ``compute_grade`` (the grade is reported only; it does not
            affect the reward).
        """
        self._state.step_count += 1

        query = self._state.current_query
        correct_action = self._state.correct_action
        correct_answer = self._state.correct_answer
        difficulty = self._state.difficulty

        action_type = action.action_type

        # --- Execute tool ---
        if action_type == "use_calculator":
            output = self._calculator(query)
        elif action_type == "use_search":
            output = self._search(query)
        elif action_type == "answer_directly":
            output = "unknown"
        else:
            output = "invalid action"

        # --- Check correctness ---
        answer_correct = output == correct_answer

        reward = self._compute_reward(
            action_type, correct_action, answer_correct, difficulty
        )

        # --- Grade (for reporting only) ---
        grade = compute_grade(
            action_taken=action_type,
            correct_action=correct_action,
            output=output,
            correct_answer=correct_answer,
        )

        return ToolUseObservation(
            done=True,
            reward=reward,
            query=query,
            tool_output=output,
            message=(
                f"Action: {action_type}, "
                f"Output: {output}, "
                f"Correct: {answer_correct}, "
                f"Reward: {reward:.2f}, "
                f"Grade: {grade:.2f}"
            ),
        )

    @property
    def state(self) -> ToolUseState:
        """Current episode state."""
        return self._state