Spaces:
Sleeping
Sleeping
| from tool_use_env.grader import grade_task | |
| from tool_use_env.models import ToolUseAction | |
| from tool_use_env.server.tool_use_env_environment import ToolUseEnvironment | |
| from tool_use_env.tasks import TASKS | |
def test_easy_task_rewards_progress_before_submission():
    """Check that inspecting an artifact earns reward without ending the episode.

    Resets the environment on the "damaged-mug-replacement" task, verifies the
    reset observation, then performs a single ``inspect_artifact`` step and
    checks the reward, the collected evidence, and the ``done`` flag.
    """
    task_name = "damaged-mug-replacement"
    environment = ToolUseEnvironment()

    # The reset observation should echo the task and already list the ticket.
    initial = environment.reset(task_id=task_name)
    assert initial.task_id == task_name
    assert "ticket" in initial.collected_evidence

    # One inspection step: positive reward, new evidence entry, still running.
    inspect_order = ToolUseAction(action_type="inspect_artifact", artifact_id="order")
    after_inspect = environment.step(inspect_order)
    assert after_inspect.reward > 0
    assert "artifact:order" in after_inspect.collected_evidence
    assert not after_inspect.done
def test_correct_resolution_finishes_with_high_score():
    """Check that a full, correct workflow ends the episode with a high score.

    Runs the "duplicate-charge-refund" task through two artifact inspections,
    a policy search and a drafted reply, then submits the
    "refund_duplicate_charge" resolution and asserts that the final
    observation is done with score and reward of at least 0.9.
    """
    environment = ToolUseEnvironment()
    environment.reset(task_id="duplicate-charge-refund")

    reply_text = (
        "We confirmed the duplicate charge and issued a refund. "
        "You should see it in 3-5 business days."
    )
    # Preparatory actions, in the order they are sent to the environment.
    groundwork = (
        {"action_type": "inspect_artifact", "artifact_id": "order"},
        {"action_type": "inspect_artifact", "artifact_id": "payment"},
        {"action_type": "search_policy", "query": "duplicate_charge"},
        {"action_type": "draft_reply", "message": reply_text},
    )
    for action_fields in groundwork:
        # Build each action right before stepping so construction order
        # matches the order of the steps.
        environment.step(ToolUseAction(**action_fields))

    submission = ToolUseAction(
        action_type="submit_resolution",
        resolution_code="refund_duplicate_charge",
    )
    outcome = environment.step(submission)

    assert outcome.done is True
    assert outcome.current_score >= 0.9
    assert outcome.reward >= 0.9
def test_grader_penalizes_wrong_resolution():
    """Check that grading a mismatched resolution code yields a low score.

    Calls ``grade_task`` directly on the "account-takeover-fraud" task with
    the resolution code "issue_refund" and asserts that the reported
    ``final_score`` lies in the half-open range [0.0, 0.6).
    """
    fraud_task = TASKS["account-takeover-fraud"]
    evidence = ["ticket", "artifact:account", "artifact:risk_log"]
    reply = "We locked the account and the fraud team will contact you within 24 hours."

    graded = grade_task(
        task=fraud_task,
        collected_evidence=evidence,
        drafted_reply=reply,
        resolution_code="issue_refund",
        step_count=5,
        repeat_action_count=0,
    )

    score = graded["final_score"]
    assert 0.0 <= score < 0.6