File size: 2,038 Bytes
18feac5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from tool_use_env.grader import grade_task
from tool_use_env.models import ToolUseAction
from tool_use_env.server.tool_use_env_environment import ToolUseEnvironment
from tool_use_env.tasks import TASKS


def test_easy_task_rewards_progress_before_submission():
    """Inspecting the order on the easy task gives positive reward and keeps the episode open.

    Also checks that a fresh reset reports the requested task id and already
    lists the ticket among the collected evidence.
    """
    task_id = "damaged-mug-replacement"
    environment = ToolUseEnvironment()

    initial = environment.reset(task_id=task_id)
    assert initial.task_id == task_id
    assert "ticket" in initial.collected_evidence

    inspect_order = ToolUseAction(action_type="inspect_artifact", artifact_id="order")
    after_inspect = environment.step(inspect_order)
    # Intermediate progress is rewarded before any resolution is submitted.
    assert after_inspect.reward > 0
    assert "artifact:order" in after_inspect.collected_evidence
    assert not after_inspect.done


def test_correct_resolution_finishes_with_high_score():
    """Walking the duplicate-charge workflow end to end scores at least 0.9.

    The episode inspects the order and payment, looks up the duplicate-charge
    policy, drafts a reply and then submits ``refund_duplicate_charge``. The
    final observation must be terminal, with score and reward both >= 0.9.
    """
    environment = ToolUseEnvironment()
    environment.reset(task_id="duplicate-charge-refund")

    reply_text = (
        "We confirmed the duplicate charge and issued a refund. "
        "You should see it in 3-5 business days."
    )
    # Keyword arguments for each step taken before submission, in order.
    # Actions are built inside the loop so each one is constructed right
    # before it is stepped.
    preparatory_steps = (
        {"action_type": "inspect_artifact", "artifact_id": "order"},
        {"action_type": "inspect_artifact", "artifact_id": "payment"},
        {"action_type": "search_policy", "query": "duplicate_charge"},
        {"action_type": "draft_reply", "message": reply_text},
    )
    for fields in preparatory_steps:
        environment.step(ToolUseAction(**fields))

    submission = ToolUseAction(
        action_type="submit_resolution",
        resolution_code="refund_duplicate_charge",
    )
    outcome = environment.step(submission)

    assert outcome.done is True
    assert outcome.current_score >= 0.9
    assert outcome.reward >= 0.9


def test_grader_penalizes_wrong_resolution():
    """Resolving the account-takeover task with ``issue_refund`` scores below 0.6.

    The grader is called directly with the ticket, account and risk-log
    evidence plus a drafted reply. The final score must stay in [0.0, 0.6).
    """
    fraud_task = TASKS["account-takeover-fraud"]
    evidence = ["ticket", "artifact:account", "artifact:risk_log"]
    reply = "We locked the account and the fraud team will contact you within 24 hours."

    grading = grade_task(
        task=fraud_task,
        collected_evidence=evidence,
        drafted_reply=reply,
        resolution_code="issue_refund",
        step_count=5,
        repeat_action_count=0,
    )

    score = grading["final_score"]
    assert 0.0 <= score < 0.6