Spaces:
Sleeping
Sleeping
import ast
import operator
import random
import uuid

from openenv.core.env_server import Environment

from tool_use_env.models import ToolUseAction, ToolUseObservation, ToolUseState
from tool_use_env.grader import compute_grade


class ToolUseEnvironment(Environment):
    """Single-step environment that rewards choosing the right tool for a query.

    ``reset`` samples one task (query, expected answer, expected tool,
    difficulty). The agent then takes exactly one ``step``: the chosen action
    is executed (calculator and search are simulated and fail at random with a
    difficulty-dependent probability), a shaped reward clamped to [0, 1] is
    computed, and the episode ends (``done=True``).

    Randomness comes from ``self._rng``. By default this is the module-level
    ``random`` generator. Once ``reset`` receives a ``seed``, a private
    ``random.Random(seed)`` is used instead, so episodes are reproducible and
    do not share RNG state with other sessions.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Base tool-failure probability per difficulty. Any difficulty not listed
    # falls back to the *_DEFAULT value (the "hard" level in the task pool).
    _CALC_FAIL_PROB = {"easy": 0.06, "medium": 0.12}
    _CALC_FAIL_DEFAULT = 0.18
    _CALC_FAIL_CAP = 0.25
    _SEARCH_FAIL_PROB = {"easy": 0.07, "medium": 0.15}
    _SEARCH_FAIL_DEFAULT = 0.22
    _SEARCH_FAIL_CAP = 0.30

    # Queries longer than this many characters get an extra failure chance.
    _LONG_QUERY_CHARS = 20
    _LONG_QUERY_PENALTY = 0.05

    # Simulated search index: a hit is any key contained (case-insensitively)
    # in the query; the first matching key in this order wins.
    _SEARCH_KB = {
        "Capital of France": "Paris",
        "CEO of Tesla": "Elon Musk",
    }
    _SEARCH_FAILURES = ("Unknown", "Not sure", "No results found")

    # Per-call cost subtracted from the reward for each tool.
    _TOOL_PENALTY = {"use_calculator": 0.05, "use_search": 0.08}
    # Reward multiplier for these difficulties; others are left unscaled.
    _DIFFICULTY_SCALE = {"medium": 1.02, "hard": 0.9}

    # Arithmetic the calculator may evaluate. Used instead of ``eval`` so the
    # query text can never execute names, calls or attribute access.
    _BIN_OPS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    _UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    # Upper bound on ``**`` exponents so a query cannot force a huge computation.
    _MAX_POW_EXPONENT = 100

    def __init__(self):
        self._state = ToolUseState()
        self._tasks = self._load_tasks()
        # Module-level ``random`` by default (unseeded behaviour); replaced by a
        # private ``random.Random(seed)`` when ``reset`` is given a seed.
        self._rng = random

    def _load_tasks(self):
        """Return the fixed task pool.

        Each task is a dict with keys ``query``, ``answer``,
        ``correct_action`` and ``difficulty``.
        """
        return [
            {
                "query": "What is 5 + 7?",
                "answer": "12",
                "correct_action": "use_calculator",
                "difficulty": "easy",
            },
            {
                "query": "Capital of France?",
                "answer": "Paris",
                "correct_action": "use_search",
                "difficulty": "easy",
            },
            {
                "query": "What is 123 * 456?",
                "answer": "56088",
                "correct_action": "use_calculator",
                "difficulty": "hard",
            },
            {
                "query": "What is 25 * 4?",
                "answer": "100",
                "correct_action": "use_calculator",
                "difficulty": "medium",
            },
            {
                "query": "Who is the CEO of Tesla?",
                "answer": "Elon Musk",
                "correct_action": "use_search",
                "difficulty": "medium",
            },
        ]

    def reset(self, seed=None, episode_id=None, **kwargs) -> ToolUseObservation:
        """Start a new single-step episode with a randomly chosen task.

        Args:
            seed: Optional RNG seed. When given, task sampling and tool noise
                for this and subsequent episodes come from a dedicated
                ``random.Random(seed)``. When omitted, the current generator
                keeps running, so an earlier seed stays in effect.
            episode_id: Optional episode id; a random UUID4 string otherwise.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The initial observation: the query, no reward, no tool output.
        """
        if seed is not None:
            self._rng = random.Random(seed)
        task = self._rng.choice(self._tasks)
        self._state = ToolUseState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            current_query=task["query"],
            correct_action=task["correct_action"],
            correct_answer=task["answer"],
            difficulty=task["difficulty"],
        )
        return ToolUseObservation(
            done=False,
            reward=None,
            query=task["query"],
            tool_output=None,
            message="Choose an action",
        )

    def _failure_probability(self, query, base_by_difficulty, default, cap):
        """Return the chance that a simulated tool call fails for ``query``.

        The base rate depends on the current episode's difficulty. Long
        queries add ``_LONG_QUERY_PENALTY``, and the total is capped at ``cap``.
        """
        prob = base_by_difficulty.get(self._state.difficulty, default)
        if len(query) > self._LONG_QUERY_CHARS:
            prob += self._LONG_QUERY_PENALTY
        return min(prob, cap)

    @classmethod
    def _safe_arith(cls, expr):
        """Evaluate a purely arithmetic expression without ``eval``.

        Only int/float literals combined with the operators in ``_BIN_OPS`` and
        ``_UNARY_OPS`` are accepted.

        Raises:
            SyntaxError: ``expr`` is not a valid Python expression.
            ValueError: it contains anything other than supported arithmetic,
                or a ``**`` exponent larger than ``_MAX_POW_EXPONENT``.
            ArithmeticError: e.g. division by zero or float overflow.
        """
        def _eval(node):
            if isinstance(node, ast.Expression):
                return _eval(node.body)
            # type() rather than isinstance(): rejects bool (a subclass of int).
            if isinstance(node, ast.Constant) and type(node.value) in (int, float):
                return node.value
            if isinstance(node, ast.BinOp) and type(node.op) in cls._BIN_OPS:
                left = _eval(node.left)
                right = _eval(node.right)
                if isinstance(node.op, ast.Pow) and abs(right) > cls._MAX_POW_EXPONENT:
                    raise ValueError("exponent too large")
                return cls._BIN_OPS[type(node.op)](left, right)
            if isinstance(node, ast.UnaryOp) and type(node.op) in cls._UNARY_OPS:
                return cls._UNARY_OPS[type(node.op)](_eval(node.operand))
            raise ValueError(f"unsupported expression node: {type(node).__name__}")

        return _eval(ast.parse(expr, mode="eval"))

    def _nonzero_noise(self, value):
        """Return a non-zero integer offset used to corrupt a calculator result.

        For ``|value| < 50`` the offset is one of -2, -1, 1, 2. Otherwise it
        is ``int(value * u)`` for ``u`` uniform in [-0.05, 0.05], falling back
        to ±1 when truncation yields 0. That way a "failed" call never returns
        the correct answer.
        """
        if abs(value) < 50:
            return self._rng.choice((-2, -1, 1, 2))
        noise = int(value * self._rng.uniform(-0.05, 0.05))
        return noise if noise != 0 else self._rng.choice((-1, 1))

    # Calculator tool (controlled noise)
    def _calculator(self, query):
        """Simulated calculator for queries like "What is 5 + 7?".

        Lower-cases the query, strips "what is" and "?", and evaluates the rest
        as arithmetic. With the difficulty-dependent failure probability, the
        result is shifted by a non-zero offset.

        Returns:
            The (possibly noisy) result as a string, or "error" when the query
            is not a supported arithmetic expression.
        """
        if not isinstance(query, str):
            return "error"
        expr = query.lower().replace("what is", "").replace("?", "").strip()
        try:
            correct = self._safe_arith(expr)
            fail_prob = self._failure_probability(
                query,
                self._CALC_FAIL_PROB,
                self._CALC_FAIL_DEFAULT,
                self._CALC_FAIL_CAP,
            )
            if self._rng.random() < fail_prob:
                return str(correct + self._nonzero_noise(correct))
            return str(correct)
        except (SyntaxError, ValueError, TypeError, ArithmeticError, RecursionError):
            # Unparseable / unsupported / numerically invalid input.
            return "error"

    # Search tool (controlled noise)
    def _search(self, query):
        """Simulated search over ``_SEARCH_KB``.

        Returns:
            The answer for the first knowledge-base key found in the query.
            With the difficulty-dependent failure probability, one of
            ``_SEARCH_FAILURES`` is returned instead. Returns "not found" when
            no key matches or the query is not a string.
        """
        if not isinstance(query, str):
            return "not found"
        lowered = query.lower()
        for key, answer in self._SEARCH_KB.items():
            if key.lower() not in lowered:
                continue
            fail_prob = self._failure_probability(
                query,
                self._SEARCH_FAIL_PROB,
                self._SEARCH_FAIL_DEFAULT,
                self._SEARCH_FAIL_CAP,
            )
            if self._rng.random() < fail_prob:
                return self._rng.choice(self._SEARCH_FAILURES)
            return answer
        return "not found"

    def _run_action(self, action_type, query):
        """Execute ``action_type`` against ``query`` and return the output string."""
        if action_type == "use_calculator":
            return self._calculator(query)
        if action_type == "use_search":
            return self._search(query)
        if action_type == "answer_directly":
            # No tool is run; this placeholder matches no answer in the task pool.
            return "unknown"
        return "invalid action"

    def _compute_reward(self, action_type, correct_action, answer_correct, difficulty):
        """Combine the shaped reward terms, scale by difficulty, clamp to [0, 1].

        The reward is the sum of:
          * action score: 0.4 for the expected tool, 0.1 otherwise;
          * answer score: 0.5 when the output equals the expected answer;
          * failure bonus: 0.1 when the expected tool was chosen but its
            output was wrong;
          * minus the tool's call cost.
        """
        picked_right_tool = action_type == correct_action
        action_score = 0.4 if picked_right_tool else 0.1
        answer_score = 0.5 if answer_correct else 0.0
        tool_penalty = self._TOOL_PENALTY.get(action_type, 0.0)
        failure_bonus = 0.1 if (not answer_correct and picked_right_tool) else 0.0
        reward = action_score + answer_score + failure_bonus - tool_penalty
        scale = self._DIFFICULTY_SCALE.get(difficulty)
        if scale is not None:
            reward *= scale
        return max(0.0, min(1.0, reward))

    def step(self, action: ToolUseAction, timeout_s=None, **kwargs) -> ToolUseObservation:
        """Run the chosen action once, score it, and end the episode.

        Args:
            action: Its ``action_type`` is compared against "use_calculator",
                "use_search" and "answer_directly"; anything else is treated
                as an invalid action.
            timeout_s: Accepted for interface compatibility; unused.
            **kwargs: Accepted and ignored.

        Returns:
            A terminal observation (``done=True``) carrying the tool output,
            the reward, and a summary message that includes the grader score.
        """
        self._state.step_count += 1
        query = self._state.current_query
        correct_action = self._state.correct_action
        correct_answer = self._state.correct_answer
        action_type = action.action_type

        output = self._run_action(action_type, query)
        answer_correct = output == correct_answer
        reward = self._compute_reward(
            action_type, correct_action, answer_correct, self._state.difficulty
        )

        # Grade is only reported in the message; it does not affect the reward.
        grade = compute_grade(
            action_taken=action_type,
            correct_action=correct_action,
            output=output,
            correct_answer=correct_answer,
        )
        return ToolUseObservation(
            done=True,
            reward=reward,
            query=query,
            tool_output=output,
            message=(
                f"Action: {action_type}, "
                f"Output: {output}, "
                f"Correct: {answer_correct}, "
                f"Reward: {reward:.2f}, "
                f"Grade: {grade:.2f}"
            ),
        )

    def state(self) -> ToolUseState:
        """Return the current episode state (the live object, not a copy)."""
        # NOTE(review): the openenv ``Environment`` base may declare ``state``
        # as a property; if so, this plain method shadows it. Confirm against
        # the base class before relying on either call style.
        return self._state