File size: 4,366 Bytes
9dff1b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
Shape assertion for the compare-intent SSE endpoint.

Validates that /api/agent/stream emits exactly one `final` event for a compare
query, with the structure that the SvelteKit compare layout depends on:

  final.intent  == "compare"
  final.targets == [{label, address}, {label, address}]   (exactly 2)
  final.paragraph contains "## PLACE A" and "## PLACE B" sections separated
                 by a markdown "---" divider
  each half of the paragraph contains at least one [citation] bracket

Run against a live local server:
    .venv/bin/uvicorn web.main:app --host 127.0.0.1 --port 7860 &
    .venv/bin/python -m pytest tests/test_compare_shape.py -v
"""
from __future__ import annotations

import json
import re

import httpx
import pytest

BASE = "http://127.0.0.1:7860"
COMPARE_QUERY = "Compare 80 Pioneer Street Brooklyn to 100 Gold Street Manhattan"


def _parse_sse_stream(raw: str) -> list[dict]:
    """Parse raw SSE text into a list of {type, data} dicts.

    Events are delimited by blank lines. An ``event:`` line sets the
    ``type`` key. ``data:`` lines are collected per event and, as the SSE
    wire format prescribes, joined with a newline when an event carries
    more than one; the joined payload is decoded as JSON when possible and
    kept as a plain string otherwise. Any other line (comments, ``id:``,
    ``retry:``) is ignored.

    A key is only present if the corresponding field appeared in the event,
    so an event without ``data:`` lines has no ``data`` key. A trailing
    event that is not terminated by a blank line is still returned.

    Args:
        raw: The complete, already-decoded response body.

    Returns:
        The events in stream order.
    """
    events: list[dict] = []
    current: dict = {}
    data_lines: list[str] = []

    # Split only on the terminators SSE defines (CRLF, LF, CR).
    # str.splitlines() would also break on characters such as U+2028, which
    # may appear unescaped inside a JSON string and would cut the payload.
    lines = re.split(r"\r\n|\n|\r", raw)
    # Sentinel blank line: flushes an unterminated trailing event through
    # the same code path as every other event.
    lines.append("")

    for line in lines:
        if line.startswith("event:"):
            current["type"] = line.split(":", 1)[1].strip()
        elif line.startswith("data:"):
            # Surrounding whitespace is trimmed per line; irrelevant for
            # JSON payloads and keeps plain-text payloads tidy.
            data_lines.append(line.split(":", 1)[1].strip())
        elif line == "" and (current or data_lines):
            if data_lines:
                payload = "\n".join(data_lines)
                try:
                    current["data"] = json.loads(payload)
                except json.JSONDecodeError:
                    current["data"] = payload
            events.append(current)
            current = {}
            data_lines = []
    return events


@pytest.fixture(scope="module")
def compare_events() -> list[dict]:
    """Issue the compare query once per module and return its parsed SSE events.

    The whole response body is read before parsing. If the server at BASE
    refuses the connection, the dependent tests are skipped rather than
    failed.
    """
    endpoint = f"{BASE}/api/agent/stream"
    try:
        with httpx.stream(
            "GET",
            endpoint,
            params={"q": COMPARE_QUERY},
            timeout=600.0,
        ) as response:
            response.raise_for_status()
            body = response.read().decode()
    except httpx.ConnectError:
        # pytest.skip raises, so the return below is never reached here.
        pytest.skip(f"Server not reachable at {BASE}")
    return _parse_sse_stream(body)


def test_has_plan_with_compare_intent(compare_events):
    """The stream carries exactly one `plan` event and its intent is compare."""
    plan_events = list(filter(lambda ev: ev.get("type") == "plan", compare_events))
    count = len(plan_events)
    assert count == 1, f"Expected 1 plan event, got {count}"

    (only_plan,) = plan_events
    intent = only_plan["data"].get("intent")
    assert intent == "compare", f"plan.intent={intent!r}"


def test_exactly_one_final_event(compare_events):
    """The stream must contain one, and only one, `final` event."""
    final_count = sum(1 for ev in compare_events if ev.get("type") == "final")
    assert final_count == 1, f"Expected 1 final event, got {final_count}"


def test_final_intent_is_compare(compare_events):
    """The first `final` event must report intent == "compare"."""
    final = next((e for e in compare_events if e.get("type") == "final"), None)
    # Without the default, a missing final event surfaces as a bare
    # StopIteration, which says nothing about what went wrong.
    assert final is not None, "No 'final' event in stream"
    data = final["data"]
    assert data.get("intent") == "compare", f"final.intent={data.get('intent')!r}"


def test_final_targets_has_two_addresses(compare_events):
    """The first `final` event must list exactly two targets, each with a non-blank address."""
    final = next((e for e in compare_events if e.get("type") == "final"), None)
    # Without the default, a missing final event surfaces as a bare
    # StopIteration, which says nothing about what went wrong.
    assert final is not None, "No 'final' event in stream"

    # A null `targets` is treated like a missing one so the length assertion
    # reports it instead of len() raising TypeError.
    targets = final["data"].get("targets") or []
    assert len(targets) == 2, f"Expected 2 targets, got {len(targets)}: {targets}"
    for i, t in enumerate(targets):
        # Guard first: on a string target, `"address" in t` would be a
        # substring test rather than a key lookup.
        assert isinstance(t, dict), f"target[{i}] is not an object: {t!r}"
        assert "address" in t, f"target[{i}] missing 'address' key: {t}"
        assert isinstance(t["address"], str) and t["address"].strip(), \
            f"target[{i}].address is blank: {t}"


def test_final_paragraph_has_both_place_sections(compare_events):
    """The paragraph must contain PLACE A, then a '---' divider, then PLACE B."""
    final = next((e for e in compare_events if e.get("type") == "final"), None)
    # Without the default, a missing final event surfaces as a bare
    # StopIteration, which says nothing about what went wrong.
    assert final is not None, "No 'final' event in stream"

    # A null paragraph is treated like a missing one.
    para = final["data"].get("paragraph") or ""
    a_idx = para.find("## PLACE A")
    b_idx = para.find("## PLACE B")
    assert a_idx != -1, "paragraph missing '## PLACE A' header"
    assert b_idx != -1, "paragraph missing '## PLACE B' header"
    assert a_idx < b_idx, "'## PLACE A' header must come before '## PLACE B'"
    # Look for the divider only in the span between the two headers; a '---'
    # elsewhere in the text does not separate the sections.
    assert "---" in para[a_idx:b_idx], "paragraph missing '---' divider between places"


def test_each_place_section_has_at_least_one_citation(compare_events):
    """Each half of the paragraph (split at the divider) must contain a [citation]."""
    final = next((e for e in compare_events if e.get("type") == "final"), None)
    # Without the default, a missing final event surfaces as a bare
    # StopIteration, which says nothing about what went wrong.
    assert final is not None, "No 'final' event in stream"

    # A null paragraph is treated like a missing one so re.split gets a str.
    para = final["data"].get("paragraph") or ""

    # Split at the first markdown divider that stands on its own line; the
    # text before it is taken as PLACE A and the text after it as PLACE B.
    parts = re.split(r"\n\s*---\s*\n", para, maxsplit=1)
    assert len(parts) == 2, f"Could not split paragraph at '---': got {len(parts)} parts"

    # A citation is a bracketed identifier: a letter followed by letters,
    # digits or underscores, matched case-insensitively.
    cite_re = re.compile(r"\[[a-z][a-z0-9_]*\]", re.IGNORECASE)
    for label, part in zip(("PLACE A", "PLACE B"), parts):
        assert cite_re.search(part) is not None, f"{label} section has no [citation] brackets"