File size: 6,818 Bytes
0c591a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5766b78
0c591a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176e3c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c591a7
176e3c1
67b4836
0c591a7
 
dc70069
 
 
 
 
0c591a7
 
 
 
 
 
 
 
5336338
 
 
 
 
0c591a7
 
 
 
5766b78
0c591a7
5766b78
0c591a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5766b78
0c591a7
 
 
 
 
 
 
 
 
 
 
 
 
be3e4c5
 
 
 
0c591a7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
Research Gateway Node

Fetches data from the Research Service via A2A protocol.
The Research Service internally calls all 6 MCP servers using TRUE MCP protocol.

This node acts as the gateway between the main SWOT Agent and the external Research Service.
"""

import asyncio
import json
from langsmith import traceable

from src.utils.ticker_lookup import get_ticker, normalize_company_name


async def _fetch_via_research_gateway(company: str, ticker: str = None, progress_callback=None, add_log=None) -> dict:
    """Fetch research data for *company* over the A2A protocol.

    Resolves a ticker symbol when one is not supplied, then delegates to the
    Research Service client, forwarding both streaming callbacks unchanged.
    """
    from src.nodes.research_gateway import call_research_service

    # Ticker resolution order: caller-supplied > lookup table > derived from name.
    resolved = ticker or get_ticker(company)
    if not resolved:
        print(f"Could not determine ticker for '{company}', using company name as ticker")
        resolved = company.upper().replace(" ", "")[:5]

    # Normalized name is what the Research Service (and logs) display.
    display_name = normalize_company_name(company)
    print(f"Calling Research Service for {display_name} ({resolved})...")

    # Delegate to the Research Service with real-time streaming callbacks.
    return await call_research_service(
        display_name,
        resolved,
        progress_callback=progress_callback,
        add_log=add_log,
    )


@traceable(name="Researcher")
def researcher_node(state, workflow_id=None, progress_store=None):
    """
    Research Gateway node that fetches data via A2A protocol.

    Calls the external Research Service which internally fetches from 6 MCP servers:
    Fundamentals, Volatility, Macro, Valuation, News, Sentiment

    Args:
        state: Workflow state dict. Must contain "company_name"; may contain
            "ticker", "workflow_id", "progress_store", "revision_count", "score".
        workflow_id: Optional explicit workflow id; falls back to state["workflow_id"].
        progress_store: Optional progress-tracking mapping; falls back to
            state["progress_store"].

    Returns:
        The (mutated) state dict with "data_source", "raw_data", and
        "sources_failed" populated on success.

    Raises:
        RuntimeError: If the Research Service returns invalid data, if fewer
            than 2 of the 3 core sources are available, or on any other
            research failure. The original exception is chained as __cause__.
    """
    company = state["company_name"]
    ticker = state.get("ticker")  # Use ticker from stock search if available

    # Extract workflow_id and progress_store from state (graph invokes with state only)
    if workflow_id is None:
        workflow_id = state.get("workflow_id")
    if progress_store is None:
        progress_store = state.get("progress_store")

    print(f"[DEBUG] researcher_node: workflow_id={workflow_id}, progress_store={'yes' if progress_store else 'no'}")

    # Update progress if tracking is enabled.
    # NOTE(review): assumes progress_store[workflow_id] was seeded upstream — confirm,
    # otherwise this raises KeyError before any research runs.
    if workflow_id and progress_store:
        progress_store[workflow_id].update({
            "current_step": "researcher",
            "revision_count": state.get("revision_count", 0),
            "score": state.get("score", 0)
        })

    # Helper to add an activity-log entry; no-op when tracking is disabled.
    def add_log(step: str, message: str):
        if workflow_id and progress_store:
            from src.services.workflow_store import add_activity_log
            add_activity_log(workflow_id, step, message)

    # Create progress callback for granular metric events.
    # Supports both dict payload (new) and positional args (legacy).
    def progress_callback(*args, **kwargs):
        if args and isinstance(args[0], dict):
            # New structured payload format
            p = args[0]
            src = p.get("source")
            metric = p.get("metric")
            value = p.get("value")
            end_date = p.get("end_date")
            fiscal_year = p.get("fiscal_year")
            form = p.get("form")
        else:
            # Legacy positional args format: (source, metric, value, end_date,
            # fiscal_year, form), with keyword fallbacks for each slot.
            src = args[0] if len(args) > 0 else kwargs.get("source")
            metric = args[1] if len(args) > 1 else kwargs.get("metric")
            value = args[2] if len(args) > 2 else kwargs.get("value")
            end_date = args[3] if len(args) > 3 else kwargs.get("end_date")
            fiscal_year = args[4] if len(args) > 4 else kwargs.get("fiscal_year")
            form = args[5] if len(args) > 5 else kwargs.get("form")

        # Only persist well-formed metric events while tracking is enabled.
        if workflow_id and progress_store and src and metric:
            from src.services.workflow_store import add_metric
            add_metric(workflow_id, src, metric, value,
                       end_date=end_date, fiscal_year=fiscal_year, form=form)

    try:
        # Set all MCP servers to "executing" state before research starts
        if workflow_id and progress_store:
            from src.services.workflow_store import set_mcp_executing
            set_mcp_executing(workflow_id)

        # Fetch via Research Gateway (A2A protocol). asyncio.run is safe here
        # because this node runs synchronously outside any event loop.
        print("[Research Gateway] Calling Research Service via A2A...")
        result = asyncio.run(_fetch_via_research_gateway(
            company,
            ticker,
            progress_callback=progress_callback,
            add_log=add_log
        ))

        # Validate result
        if not result or not isinstance(result, dict):
            raise RuntimeError(f"Research Service returned invalid data for {company}")

        state["data_source"] = "a2a"
        # Note: Metrics are streamed via partial_metrics during A2A polling

        # Check MCP source availability with tiered logic:
        # Core sources (need at least 2 of 3): fundamentals, valuation, volatility
        # Supplementary sources (non-blocking): macro, news, sentiment
        CORE_SOURCES = {"fundamentals", "valuation", "volatility"}
        SUPPLEMENTARY_SOURCES = {"macro", "news", "sentiment"}

        sources_available = set(result.get("sources_available", []))
        sources_failed = result.get("sources_failed", [])

        core_available = sources_available & CORE_SOURCES
        core_failed = CORE_SOURCES - core_available
        supplementary_failed = set(sources_failed) & SUPPLEMENTARY_SOURCES

        # Log supplementary failures as non-critical
        for source in supplementary_failed:
            add_log("researcher", f"{source.capitalize()} unavailable (non-critical)")

        # Log core failures as critical
        for source in core_failed:
            add_log("researcher", f"{source.capitalize()} unavailable (critical)")

        # Abort if 2+ core sources failed (need at least 2 of 3)
        if len(core_available) < 2:
            failed_list = ", ".join(sorted(core_failed))
            raise RuntimeError(f"Insufficient core data: {failed_list} unavailable. Need at least 2 of: Fundamentals, Valuation, Volatility.")

        if sources_available:
            # default=str keeps non-JSON-serializable values (dates etc.) from crashing the dump.
            state["raw_data"] = json.dumps(result, indent=2, default=str)
            state["sources_failed"] = sources_failed

            print(f"  - Sources available: {result['sources_available']}")
            if sources_failed:
                print(f"  - Sources failed: {sources_failed}")
        else:
            # All MCPs failed - raise error
            raise RuntimeError(f"All MCP servers failed for {company}. Check API configurations.")

    except Exception as e:
        error_msg = str(e)
        print(f"Research failed: {error_msg}")
        add_log("researcher", f"ERROR: {error_msg}")
        # Chain the original exception so the full traceback is preserved
        # for callers/debugging instead of being silently discarded.
        raise RuntimeError(f"Research failed for {company}: {error_msg}") from e

    return state