File size: 3,939 Bytes
3eec386
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Regression tests for `_patch_dangling_tool_calls`.

Reproduces the failure mode behind observatory sessions 8dd2ce30 and
59c9e678 (2026-04-25): a tool call cancelled mid-execution leaves an
orphan ``tool_use`` in history; the user types a follow-up; Bedrock
rejects the next request with HTTP 400 ``messages.N: tool_use ids were
found without tool_result blocks immediately after``.
"""

from litellm import ChatCompletionMessageToolCall, Message

from agent.context_manager.manager import ContextManager


def _tool_call(call_id: str, name: str = "research") -> ChatCompletionMessageToolCall:
    """Return a ``function``-type tool call named *name* with empty JSON arguments.

    Args:
        call_id: The tool-call id that later ``tool`` messages reference
            through ``tool_call_id``.
        name: The function name recorded on the call.
    """
    fn_spec = {"name": name, "arguments": "{}"}
    return ChatCompletionMessageToolCall(id=call_id, type="function", function=fn_spec)


def _make_cm() -> ContextManager:
    """Return a ``ContextManager`` built without running its ``__init__``.

    The instance is allocated with ``__new__``, and only the attributes these
    tests rely on are set by hand. Its history starts with a single system
    message. Presumably this skips whatever setup ``__init__`` does; confirm
    against ``ContextManager.__init__``.
    """
    manager = ContextManager.__new__(ContextManager)
    # Dict insertion order keeps the assignment order identical to setting
    # each attribute by hand.
    initial_state = {
        "system_prompt": "system",
        "model_max_tokens": 100_000,
        "compact_size": 1_000,
        "running_context_usage": 0,
        "untouched_messages": 5,
        "items": [Message(role="system", content="system")],
    }
    for attr_name, attr_value in initial_state.items():
        setattr(manager, attr_name, attr_value)
    return manager


def test_orphan_tool_use_followed_by_user_message_is_patched():
    """An orphaned tool_use followed by a user message receives one stub result."""
    cm = _make_cm()
    cm.items.extend([
        Message(role="user", content="Research X"),
        Message(
            role="assistant",
            content=None,
            tool_calls=[_tool_call("call_abc", "research")],
        ),
        Message(role="user", content="??"),
    ])
    msgs = cm.get_messages()
    tool_msgs = [m for m in msgs if getattr(m, "role", None) == "tool"]
    assert len(tool_msgs) == 1
    stub = tool_msgs[0]
    assert stub.tool_call_id == "call_abc"
    # The exact stub wording is an implementation detail, so accept either
    # phrasing. Normalise once and report the text if neither phrase matches.
    stub_text = (stub.content or "").lower()
    assert "interrupted" in stub_text or "not executed" in stub_text, stub_text


def test_no_orphan_means_no_stub():
    """A tool_use that already has its tool_result is left alone, with no stub added."""
    cm = _make_cm()
    request = Message(role="user", content="Research X")
    call = Message(
        role="assistant",
        content=None,
        tool_calls=[_tool_call("call_abc", "research")],
    )
    real_result = Message(
        role="tool", content="ok", tool_call_id="call_abc", name="research"
    )
    cm.items += [request, call, real_result]

    cm.get_messages()

    tool_msgs = list(filter(lambda m: getattr(m, "role", None) == "tool", cm.items))
    assert len(tool_msgs) == 1
    assert tool_msgs[0].content == "ok"


def test_multiple_dangling_tool_calls_in_one_assistant_message_are_all_patched():
    """Each orphaned call in one assistant message gets exactly one stub result."""
    cm = _make_cm()
    cm.items.extend([
        Message(role="user", content="do two things"),
        Message(
            role="assistant",
            content=None,
            tool_calls=[
                _tool_call("call_1", "research"),
                _tool_call("call_2", "bash"),
            ],
        ),
        Message(role="user", content="follow up"),
    ])
    cm.get_messages()
    # Collect into a list, not a set, so a duplicated stub for one id fails
    # the test instead of being silently collapsed.
    tool_ids = [
        getattr(m, "tool_call_id", None)
        for m in cm.items
        if getattr(m, "role", None) == "tool"
    ]
    # key=str keeps sorting safe if a tool message has no tool_call_id (None).
    assert sorted(tool_ids, key=str) == ["call_1", "call_2"]


def test_orphan_in_earlier_turn_still_gets_patched():
    """Two-turn history where the FIRST turn was interrupted.

    The old patcher stopped at the first user message it met while scanning
    backwards. This case was therefore never fixed, and Bedrock rejected the
    request.
    """
    cm = _make_cm()
    cm.items.extend([
        Message(role="user", content="turn 1"),
        Message(
            role="assistant",
            content=None,
            tool_calls=[_tool_call("call_old", "research")],
        ),
        Message(role="user", content="turn 2 — please retry"),
        Message(
            role="assistant",
            content=None,
            tool_calls=[_tool_call("call_new", "bash")],
        ),
        Message(role="tool", content="ok", tool_call_id="call_new", name="bash"),
    ])
    cm.get_messages()
    tool_msgs = [m for m in cm.items if getattr(m, "role", None) == "tool"]
    # Exactly one result per tool_use id: a stub for the interrupted call_old,
    # and no extra stub for call_new, which already has a real result.
    tool_ids = [getattr(m, "tool_call_id", None) for m in tool_msgs]
    assert sorted(tool_ids, key=str) == ["call_new", "call_old"]
    # The genuine result for the completed call must survive patching unchanged.
    new_results = [
        m for m in tool_msgs if getattr(m, "tool_call_id", None) == "call_new"
    ]
    assert new_results[0].content == "ok"