# Source path: tool-use-openenv/tool_use_env/server/tool_use_env_environment.py
# Hosting-page residue (not part of the module): "Clove25's picture" / "Upload 41 files" / "d9175ae verified"
import ast
import operator
import random
import uuid

from openenv.core.env_server import Environment

from tool_use_env.grader import compute_grade
from tool_use_env.models import ToolUseAction, ToolUseObservation, ToolUseState
class ToolUseEnvironment(Environment):
    """Single-step environment that scores an agent's choice of tool.

    ``reset`` picks one task from a small built-in list and exposes its query.
    ``step`` runs the tool named by the action (``"use_calculator"``,
    ``"use_search"`` or ``"answer_directly"``), compares the tool output with
    the task's expected answer, computes a reward clamped to ``[0, 1]`` and
    ends the episode (``done=True``).

    Both tools are deliberately unreliable: with a probability that depends on
    the task difficulty and the query length they return a wrong result.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Facts known to the simulated search tool. A fact matches when its key
    # occurs (case-insensitively) anywhere in the query.
    _KNOWLEDGE_BASE = {
        "Capital of France": "Paris",
        "CEO of Tesla": "Elon Musk",
    }

    # Replies the search tool gives when a simulated failure is triggered.
    _SEARCH_FAILURE_REPLIES = ("Unknown", "Not sure", "No results found")

    # Arithmetic operators the calculator tool is willing to evaluate.
    _BINARY_OPERATORS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    _UNARY_OPERATORS = {
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
    }

    def __init__(self):
        """Create the environment with an empty state and the built-in tasks."""
        super().__init__()
        # Source of randomness for task choice and tool noise. It is the shared
        # ``random`` module until ``reset`` is called with a seed, after which
        # it is a private ``random.Random`` so episodes are reproducible.
        self._rng = random
        self._state = ToolUseState()
        self._tasks = self._load_tasks()

    def _load_tasks(self):
        """Return the built-in task list.

        Each task is a dict with the keys ``query``, ``answer`` (expected tool
        output as a string), ``correct_action`` and ``difficulty``.
        """
        return [
            {
                "query": "What is 5 + 7?",
                "answer": "12",
                "correct_action": "use_calculator",
                "difficulty": "easy",
            },
            {
                "query": "Capital of France?",
                "answer": "Paris",
                "correct_action": "use_search",
                "difficulty": "easy",
            },
            {
                "query": "What is 123 * 456?",
                "answer": "56088",
                "correct_action": "use_calculator",
                "difficulty": "hard",
            },
            {
                "query": "What is 25 * 4?",
                "answer": "100",
                "correct_action": "use_calculator",
                "difficulty": "medium",
            },
            {
                "query": "Who is the CEO of Tesla?",
                "answer": "Elon Musk",
                "correct_action": "use_search",
                "difficulty": "medium",
            },
        ]

    def reset(self, seed=None, episode_id=None, **kwargs) -> ToolUseObservation:
        """Start a new episode on a randomly chosen task.

        Args:
            seed: Optional seed. When given, task selection and all later tool
                noise come from a fresh ``random.Random(seed)``. When ``None``
                the generator already in use is kept.
            episode_id: Optional identifier for the episode; a random UUID4
                string is generated when it is missing or empty.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The initial observation carrying the task query, with no reward
            and ``done=False``.
        """
        if seed is not None:
            self._rng = random.Random(seed)

        task = self._rng.choice(self._tasks)
        self._state = ToolUseState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            current_query=task["query"],
            correct_action=task["correct_action"],
            correct_answer=task["answer"],
            difficulty=task["difficulty"],
        )
        return ToolUseObservation(
            done=False,
            reward=None,
            query=task["query"],
            tool_output=None,
            message="Choose an action",
        )

    def _failure_probability(self, query, easy, medium, hard, cap):
        """Return the chance that a tool call on ``query`` is corrupted.

        The base value is chosen by the current episode's difficulty (any
        difficulty other than ``"easy"``/``"medium"`` uses ``hard``), queries
        longer than 20 characters add 0.05, and the result never exceeds
        ``cap``.
        """
        difficulty = self._state.difficulty
        if difficulty == "easy":
            probability = easy
        elif difficulty == "medium":
            probability = medium
        else:
            probability = hard

        if len(query) > 20:
            probability += 0.05
        return min(probability, cap)

    @classmethod
    def _evaluate_arithmetic(cls, expression):
        """Evaluate a plain arithmetic expression without using ``eval``.

        Only int/float literals, the binary operators ``+ - * / // % **`` and
        unary ``+``/``-`` are accepted.

        Raises:
            SyntaxError: If ``expression`` is not a valid Python expression.
            ValueError: If it contains anything other than the constructs
                listed above (names, calls, attribute access, ...).
        """
        tree = ast.parse(expression.strip(), mode="eval")
        return cls._evaluate_node(tree.body)

    @classmethod
    def _evaluate_node(cls, node):
        """Recursively compute the value of one whitelisted AST node."""
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is a subclass of int; reject it along with non-numbers.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValueError("only numeric literals are allowed")
            return value
        if isinstance(node, ast.BinOp):
            apply = cls._BINARY_OPERATORS.get(type(node.op))
            if apply is not None:
                left = cls._evaluate_node(node.left)
                right = cls._evaluate_node(node.right)
                return apply(left, right)
        elif isinstance(node, ast.UnaryOp):
            apply = cls._UNARY_OPERATORS.get(type(node.op))
            if apply is not None:
                return apply(cls._evaluate_node(node.operand))
        raise ValueError("unsupported expression")

    def _nonzero_noise(self, correct):
        """Return a non-zero offset used to corrupt a calculator result.

        Results with magnitude below 50 are shifted by 1 or 2 in either
        direction; larger results are shifted by up to 5% of their magnitude
        (at least 1). The offset is never 0, so a simulated failure always
        yields a wrong answer.
        """
        if abs(correct) < 50:
            return self._rng.choice((-2, -1, 1, 2))
        magnitude = max(1, int(abs(correct) * self._rng.uniform(0.0, 0.05)))
        return magnitude if self._rng.random() < 0.5 else -magnitude

    def _calculator(self, query):
        """Simulated calculator tool.

        Strips ``"what is"`` and ``"?"`` from the lower-cased query, evaluates
        what remains as arithmetic and returns the result as a string. With a
        probability of 0.06 / 0.12 / 0.18 (easy / medium / other), plus 0.05
        for queries longer than 20 characters and capped at 0.25, the result
        is perturbed by a non-zero amount instead.

        Returns:
            The (possibly corrupted) result as a string, or ``"error"`` when
            the query cannot be evaluated.
        """
        try:
            expression = query.lower().replace("what is", "").replace("?", "").strip()
            correct = self._evaluate_arithmetic(expression)
            fail_prob = self._failure_probability(query, 0.06, 0.12, 0.18, cap=0.25)
            if self._rng.random() < fail_prob:
                return str(correct + self._nonzero_noise(correct))
            return str(correct)
        except Exception:
            # Deliberately best-effort: a query that is not arithmetic (or that
            # fails to evaluate, e.g. division by zero) is reported to the
            # agent as a tool error instead of aborting the episode.
            return "error"

    def _search(self, query):
        """Simulated search tool.

        Returns the stored fact for the first knowledge-base key contained in
        the query (case-insensitive). With a probability of 0.07 / 0.15 / 0.22
        (easy / medium / other), plus 0.05 for queries longer than 20
        characters and capped at 0.30, one of the canned failure replies is
        returned instead. Queries matching no key yield ``"not found"``.
        """
        lowered_query = query.lower()
        for phrase, fact in self._KNOWLEDGE_BASE.items():
            if phrase.lower() not in lowered_query:
                continue
            fail_prob = self._failure_probability(query, 0.07, 0.15, 0.22, cap=0.30)
            if self._rng.random() < fail_prob:
                return self._rng.choice(self._SEARCH_FAILURE_REPLIES)
            return fact
        return "not found"

    @staticmethod
    def _compute_reward(action_type, correct_action, answer_correct, difficulty):
        """Combine the reward components into a value in ``[0, 1]``.

        Components:
            * action score: 0.4 for the expected action, otherwise 0.1;
            * answer score: 0.5 when the tool output equals the expected answer;
            * tool cost: 0.05 for the calculator, 0.08 for search, else 0;
            * failure bonus: 0.1 when the expected action was chosen but the
              output was still wrong (the tool failed, not the agent).

        The sum is scaled by 1.02 for ``"medium"`` and 0.9 for ``"hard"``
        tasks, then clamped to ``[0, 1]``.
        """
        chose_expected_action = action_type == correct_action

        action_score = 0.4 if chose_expected_action else 0.1
        answer_score = 0.5 if answer_correct else 0.0

        if action_type == "use_calculator":
            tool_penalty = 0.05
        elif action_type == "use_search":
            tool_penalty = 0.08
        else:
            tool_penalty = 0.0

        failure_bonus = 0.1 if (chose_expected_action and not answer_correct) else 0.0

        reward = action_score + answer_score + failure_bonus - tool_penalty

        if difficulty == "medium":
            reward *= 1.02
        elif difficulty == "hard":
            reward *= 0.9

        return max(0.0, min(1.0, reward))

    def step(self, action: ToolUseAction, timeout_s=None, **kwargs) -> ToolUseObservation:
        """Execute the chosen action, score it and finish the episode.

        Args:
            action: The agent's action; only ``action.action_type`` is read.
            timeout_s: Accepted for interface compatibility and ignored.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A terminal observation (``done=True``) holding the reward, the
            tool output and a human-readable summary that also includes the
            grade reported by ``compute_grade``.
        """
        self._state.step_count += 1

        query = self._state.current_query
        correct_action = self._state.correct_action
        correct_answer = self._state.correct_answer
        difficulty = self._state.difficulty
        action_type = action.action_type

        # Run the selected tool.
        if action_type == "use_calculator":
            output = self._calculator(query)
        elif action_type == "use_search":
            output = self._search(query)
        elif action_type == "answer_directly":
            output = "unknown"
        else:
            output = "invalid action"

        answer_correct = output == correct_answer
        reward = self._compute_reward(
            action_type, correct_action, answer_correct, difficulty
        )

        # The grade is only reported in the message; it does not affect reward.
        grade = compute_grade(
            action_taken=action_type,
            correct_action=correct_action,
            output=output,
            correct_answer=correct_answer,
        )

        return ToolUseObservation(
            done=True,
            reward=reward,
            query=query,
            tool_output=output,
            message=(
                f"Action: {action_type}, "
                f"Output: {output}, "
                f"Correct: {answer_correct}, "
                f"Reward: {reward:.2f}, "
                f"Grade: {grade:.2f}"
            ),
        )

    @property
    def state(self) -> ToolUseState:
        """Current episode state (replaced by ``reset``, updated by ``step``)."""
        return self._state