File size: 593 Bytes
d9175ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from server.tool_use_env_environment import ToolUseEnvironment

# Module-wide environment instance shared by the direct tool-method tests
# (``_calculator`` / ``_search``); it is constructed once at import time.
env: ToolUseEnvironment = ToolUseEnvironment()

def test_calculator_correct():
    """Ask the calculator tool for 2 + 2 and accept the exact or a noisy answer.

    The tool's reply may be noisy, so "3" and "5" are accepted alongside
    the exact "4".
    """
    accepted_answers = ("3", "4", "5")  # noise tolerance around the true value
    answer = env._calculator("What is 2 + 2?")
    assert answer in accepted_answers

def test_search():
    """Ask the search tool for France's capital; "Paris" or "Unknown" passes."""
    outcome = env._search("Capital of France?")
    assert outcome in ("Paris", "Unknown")

def test_step_output():
    """Take one calculator step and check the reward range and the observed query.

    Verifies that ``env.step`` returns a result whose ``reward`` lies in
    [-1, 1] and whose ``observation`` carries a non-None ``query``.
    """
    # A local instance (shadowing the module-level ``env``) keeps this step
    # isolated from the tool-method tests that share the module instance.
    env = ToolUseEnvironment()

    action = {"action_type": "use_calculator"}
    result = env.step(action)
    obs = result.observation

    # Guard against a missing reward so a failure reports clearly instead of
    # raising TypeError from the chained comparison below.
    assert result.reward is not None, "step() returned no reward"
    assert -1 <= result.reward <= 1, f"reward out of [-1, 1]: {result.reward!r}"
    # The query is read from the observation (``obs.query``), so assert on
    # that attribute rather than on the step result itself.
    assert obs.query is not None, "observation is missing its query"