File size: 3,383 Bytes
53dbcc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from __future__ import annotations

from types import SimpleNamespace

import pytest

from flatmate_rl.inference import (
    ModelConfigurationError,
    build_user_prompt,
    get_model_action,
    malformed_action_observation,
    parse_action,
)
from flatmate_rl.server.flatmate_rl_environment import FlatmateRlEnvironment
from flatmate_rl.models import FlatmateRlAction


def test_strict_parse_rejects_tool_name_in_action_type() -> None:
    parsed = parse_action('{"action_type":"store_user_details","tool_arguments":{}}', strict=True)

    assert parsed.action is None
    assert parsed.error is not None
    assert "schema_validation_failed" in parsed.error
    assert "action_type must be" in parsed.error


def test_legacy_parse_can_coerce_tool_name_in_action_type() -> None:
    parsed = parse_action('{"action_type":"store_user_details","tool_arguments":{}}', strict=False)

    assert parsed.action is not None
    assert parsed.action.action_type == "tool_call"
    assert parsed.action.tool_name == "store_user_details"
    assert parsed.warning is not None
    assert "coerced invalid action_type" in parsed.warning


def test_strict_parse_reports_json_error() -> None:
    parsed = parse_action('{"action_type":"tool_call"', strict=True)

    assert parsed.action is None
    assert parsed.error is not None
    assert parsed.error.startswith("json_parse_failed")


def test_malformed_action_feedback_is_recoverable() -> None:
    env = FlatmateRlEnvironment()
    obs = env.reset(scenario_id="task_visit_single")

    feedback_obs = malformed_action_observation(obs, "schema_validation_failed: bad action")

    assert feedback_obs.done is False
    assert feedback_obs.step_reward == pytest.approx(-0.05)
    assert feedback_obs.total_reward == pytest.approx(-0.05)
    assert feedback_obs.last_tool_result["error"] == "schema_validation_failed"
    assert "expected_schema" in feedback_obs.last_tool_result


def test_user_prompt_renders_prerequisites_and_recent_tools() -> None:
    env = FlatmateRlEnvironment()
    obs = env.reset(scenario_id="task_visit_single")
    obs = env.step(
        FlatmateRlAction(
            action_type="assistant_message",
            assistant_message="Please share your dietary preference and visit availability.",
        )
    )
    obs = env.step(
        FlatmateRlAction(
            action_type="tool_call",
            tool_name="store_user_details",
            tool_arguments={},
        )
    )

    prompt = build_user_prompt(step=2, observation=obs)

    assert "Prerequisites satisfied:" in prompt
    assert '"details_stored": true' in prompt
    assert "Recent tool calls:" in prompt
    assert "store_user_details" in prompt


def test_model_call_error_does_not_fallback_to_heuristic() -> None:
    class FailingCompletions:
        def create(self, **kwargs):
            raise RuntimeError("requested model is not supported")

    client = SimpleNamespace(chat=SimpleNamespace(completions=FailingCompletions()))
    env = FlatmateRlEnvironment()
    obs = env.reset(scenario_id="task_visit_single")

    with pytest.raises(ModelConfigurationError, match="MODEL_NAME is invalid or unsupported"):
        get_model_action(
            client=client,
            task_id="task_visit_single",
            step=1,
            observation=obs,
            explain=False,
            strict_parsing=True,
        )