File size: 10,881 Bytes
0c591a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3dd83d
 
 
5336338
0c591a7
5336338
 
 
0c591a7
5336338
 
 
 
 
 
0c591a7
176e3c1
 
 
 
 
 
 
 
 
 
0c591a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be3e4c5
 
 
 
 
 
 
0c591a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
"""
Research Gateway

A2A client for calling the Research Service (Researcher-Agent) via Google A2A protocol.
Acts as the gateway between the main SWOT Agent and the external Research Service.

Supports real-time partial metrics streaming during task execution.
"""

import asyncio
import logging
import os
from typing import Optional, Callable, Set

import httpx

logger = logging.getLogger("research-gateway")

# Research Service settings, each overridable through an environment variable
# of the same name. The URL falls back to the HuggingFace Spaces deployment.
_env = os.environ.get
A2A_RESEARCHER_URL = _env("A2A_RESEARCHER_URL", "https://vn6295337-researcher-agent.hf.space")
# Overall wait budget for a research task, in seconds (increased for remote calls).
A2A_TIMEOUT = float(_env("A2A_TIMEOUT", "120"))
# Delay between successive task-status polls, in seconds.
A2A_POLL_INTERVAL = float(_env("A2A_POLL_INTERVAL", "1"))


class ResearchGatewayError(Exception):
    """Raised by this module when a Research Service call fails.

    Covers connection errors, JSON-RPC error responses, failed or canceled
    tasks, timeouts and responses missing the expected task/artifact data.
    """


async def send_message(message_text: str) -> dict:
    """
    Send a JSON-RPC ``message/send`` request to start a research task.

    Args:
        message_text: Text message like "Research Tesla"

    Returns:
        The ``result`` member of the JSON-RPC response, or an empty dict when
        the response carries no result.

    Raises:
        ResearchGatewayError: If the request cannot be delivered, the response
            body is not a JSON object, or the response contains an ``error``
            member.
    """
    request = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "message/send",
        "params": {
            "message": {
                "parts": [{"type": "text", "text": message_text}]
            }
        }
    }

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                A2A_RESEARCHER_URL,
                json=request,
                timeout=30
            )
        except httpx.RequestError as e:
            raise ResearchGatewayError(
                f"Connection error to {A2A_RESEARCHER_URL}: {e}"
            ) from e

        # A non-JSON body (for example an HTML error page) makes .json() raise
        # a ValueError subclass; surface it as a gateway error so callers only
        # need to handle ResearchGatewayError.
        try:
            data = response.json()
        except ValueError as e:
            raise ResearchGatewayError(
                f"Invalid JSON response from {A2A_RESEARCHER_URL} "
                f"(HTTP {response.status_code})"
            ) from e

    if not isinstance(data, dict):
        raise ResearchGatewayError(
            f"Unexpected response from {A2A_RESEARCHER_URL}: expected a JSON object"
        )

    if "error" in data:
        raise ResearchGatewayError(f"A2A error: {data['error']}")

    # Treat an explicit null result the same as a missing one.
    return data.get("result") or {}


async def get_task_status(task_id: str) -> dict:
    """
    Get task status via a JSON-RPC ``tasks/get`` request.

    Args:
        task_id: Task ID from message/send response

    Returns:
        The ``result.task`` dict from the response (empty dict when absent).
        Callers in this module read its ``status``, ``partial_metrics``,
        ``artifacts`` and ``error`` keys.

    Raises:
        ResearchGatewayError: If the request cannot be delivered, the response
            body is not a JSON object, or the response contains an ``error``
            member.
    """
    request = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tasks/get",
        "params": {"taskId": task_id}
    }

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                A2A_RESEARCHER_URL,
                json=request,
                timeout=30
            )
        except httpx.RequestError as e:
            raise ResearchGatewayError(f"Connection error: {e}") from e

        # A non-JSON body (for example an HTML error page) makes .json() raise
        # a ValueError subclass; convert it so pollers see one error type.
        try:
            data = response.json()
        except ValueError as e:
            raise ResearchGatewayError(
                f"Invalid JSON response (HTTP {response.status_code})"
            ) from e

    if not isinstance(data, dict):
        raise ResearchGatewayError("Unexpected response: expected a JSON object")

    if "error" in data:
        raise ResearchGatewayError(f"A2A error: {data['error']}")

    # Explicit nulls for "result" or "task" are treated like missing keys so
    # the caller always receives a dict.
    result = data.get("result") or {}
    return result.get("task") or {}


# Minimum number of seconds between "still polling" activity-log lines.
_POLL_LOG_INTERVAL = 5.0


def _emit_new_metrics(task: dict, progress_callback: Callable, emitted: Set[str]) -> None:
    """Pass each not-yet-seen partial metric of ``task`` to ``progress_callback``."""
    for metric in task.get("partial_metrics") or []:
        # Skip None or otherwise malformed entries.
        if not isinstance(metric, dict):
            continue
        source = metric.get("source")
        metric_name = metric.get("metric")
        if not source or not metric_name:
            continue
        # Structured payload handed to the callback.
        payload = {
            "source": source,
            "metric": metric_name,
            "value": metric.get("value"),
            "end_date": metric.get("end_date"),
            "fiscal_year": metric.get("fiscal_year"),
            "form": metric.get("form"),
        }
        # Key on every payload field so two records that share a value but
        # belong to different reporting periods are both emitted.
        key = ":".join(str(part) for part in payload.values())
        if key in emitted:
            continue
        progress_callback(payload)
        emitted.add(key)


def _count_available_sources(task: dict) -> int:
    """Return the number of ``sources_available`` in the task's first artifact."""
    artifacts = task.get("artifacts") or []
    first = artifacts[0] if artifacts else None
    if not isinstance(first, dict):
        return 0
    data = first.get("data")
    if not isinstance(data, dict):
        return 0
    return len(data.get("sources_available") or [])


def _describe_task_error(task: dict) -> str:
    """Extract a readable error message from a failed task."""
    error = task.get("error")
    if isinstance(error, dict):
        return error.get("message", "Unknown error")
    return str(error) if error else "Unknown error"


async def wait_for_completion(
    task_id: str,
    timeout: Optional[float] = None,
    progress_callback: Optional[Callable] = None,
    add_log: Optional[Callable] = None
) -> dict:
    """
    Poll task status until completed or failed.
    Emits partial_metrics via progress_callback during WORKING status.

    Args:
        task_id: Task ID to poll
        timeout: Max seconds to wait (default: A2A_TIMEOUT)
        progress_callback: Optional callback invoked with one payload dict per
            new metric (keys: source, metric, value, end_date, fiscal_year, form)
        add_log: Optional callback for activity logging (step, message)

    Returns:
        Completed task dict with artifacts

    Raises:
        ResearchGatewayError: If the task fails, is canceled, or does not
            complete within ``timeout`` seconds. Errors raised by
            get_task_status propagate unchanged.
    """
    if timeout is None:
        timeout = A2A_TIMEOUT

    # Measure real elapsed time on the loop's monotonic clock so that slow
    # status requests count against the timeout, not only the sleep intervals.
    loop = asyncio.get_running_loop()
    started = loop.time()
    elapsed = 0.0
    next_log_at = _POLL_LOG_INTERVAL
    emitted_metrics: Set[str] = set()  # Keys of metrics already forwarded

    while elapsed < timeout:
        task = await get_task_status(task_id)
        status = task.get("status")

        # Also process on "completed" to catch metrics from the final sources.
        if progress_callback and status in ("working", "completed"):
            _emit_new_metrics(task, progress_callback, emitted_metrics)

        if status == "completed":
            if add_log:
                sources = _count_available_sources(task)
                add_log("researcher", f"Research completed: {sources} sources aggregated")
            return task
        if status == "failed":
            error = _describe_task_error(task)
            if add_log:
                add_log("researcher", f"Research failed: {error}")
            raise ResearchGatewayError(f"Task failed: {error}")
        if status == "canceled":
            if add_log:
                add_log("researcher", "Research task was canceled")
            raise ResearchGatewayError("Task was canceled")

        # Log polling status periodically; threshold-based so it also works
        # when the poll interval is fractional.
        elapsed = loop.time() - started
        if add_log and elapsed >= next_log_at:
            add_log("researcher", f"Polling Research Service... ({int(elapsed)}s elapsed)")
            next_log_at = elapsed + _POLL_LOG_INTERVAL

        await asyncio.sleep(A2A_POLL_INTERVAL)
        elapsed = loop.time() - started

    raise ResearchGatewayError(f"Task timed out after {timeout} seconds")


async def call_research_service(
    company: str,
    ticker: str = "",
    progress_callback: Optional[Callable] = None,
    add_log: Optional[Callable] = None
) -> dict:
    """
    High-level function to call Research Service and get results.
    Supports real-time partial metrics streaming.

    Runs a health check (advisory only: a failure is logged and the call
    proceeds), submits the research task, polls it to completion and returns
    the payload of the first artifact whose type is "data".

    Args:
        company: Company name to research
        ticker: Optional ticker symbol
        progress_callback: Optional callback for granular metrics
        add_log: Optional callback for activity logging (step, message)

    Returns:
        Research data dict from the Research Service

    Raises:
        ResearchGatewayError: If task submission fails, no task ID is
            returned, the task fails/is canceled/times out, or the completed
            task has no "data" artifact.
    """
    # Format message
    message = f"Research {ticker} {company}" if ticker else f"Research {company}"

    logger.info(f"Calling Research Service at {A2A_RESEARCHER_URL}: {message}")

    # Log connection
    if add_log:
        add_log("researcher", "Connecting to Research Service...")
        add_log("researcher", f"A2A URL: {A2A_RESEARCHER_URL}")

    # Check health first; only report a successful handshake when it passed.
    healthy = await check_service_health()
    if healthy:
        if add_log:
            add_log("researcher", "A2A handshake successful")
    else:
        if add_log:
            add_log("researcher", "WARNING: Research Service health check failed, attempting anyway...")
        logger.warning("Research Service health check failed")

    # Send message to start task
    if add_log:
        # Omit the parenthesised ticker when none was given (avoids "Name ()").
        target = f"{company} ({ticker})" if ticker else company
        add_log("researcher", f"Submitting research task for {target}...")

    try:
        result = await send_message(message)
    except ResearchGatewayError as e:
        if add_log:
            add_log("researcher", f"A2A request failed: {str(e)}")
        raise

    # Tolerate a null result or null task rather than raising AttributeError.
    task_info = (result or {}).get("task") or {}
    task_id = task_info.get("id")

    if not task_id:
        raise ResearchGatewayError("No task ID returned from message/send")

    logger.info(f"Task created: {task_id}")
    if add_log:
        # str() so a non-string ID cannot break the activity log.
        add_log("researcher", f"Task submitted: {str(task_id)[:8]}...")
        add_log("researcher", "Fetching data from 6 MCP servers in parallel...")

    # Wait for completion with partial metrics streaming
    task = await wait_for_completion(
        task_id,
        progress_callback=progress_callback,
        add_log=add_log
    )

    # Extract data from artifacts
    artifacts = task.get("artifacts") or []
    if not artifacts:
        raise ResearchGatewayError("No artifacts in completed task")

    # Find data artifact
    for artifact in artifacts:
        if not isinstance(artifact, dict) or artifact.get("type") != "data":
            continue
        data = artifact.get("data") or {}
        # Log sources
        if add_log:
            sources = data.get("sources_available") or []
            failed = data.get("sources_failed") or []
            # map(str, ...) keeps logging from failing on non-string entries.
            add_log("researcher", f"Sources available: {', '.join(map(str, sources))}")
            if failed:
                add_log("researcher", f"Sources failed: {', '.join(map(str, failed))}")
        return data

    raise ResearchGatewayError("No data artifact found in response")


async def check_service_health() -> bool:
    """
    Probe the Research Service's /health endpoint.

    Best-effort: any exception (network, decoding, unexpected body shape) is
    logged as a warning and reported as unhealthy rather than raised.

    Returns:
        True only when the JSON response has "status" equal to "healthy",
        False otherwise
    """
    health_url = f"{A2A_RESEARCHER_URL}/health"
    try:
        async with httpx.AsyncClient() as http:
            reply = await http.get(health_url, timeout=10)
            reported_status = reply.json().get("status")
            return reported_status == "healthy"
    except Exception as exc:
        logger.warning(f"Health check failed: {exc}")
        return False


async def get_agent_card() -> Optional[dict]:
    """
    Fetch the agent card from the Research Service.

    Requests ``/.well-known/agent.json`` under A2A_RESEARCHER_URL. Best-effort:
    any failure is logged as a warning and reported as None, never raised.

    Returns:
        Agent card dict, or None if the request fails, the server answers with
        a non-success status, or the body is not a JSON object
    """
    card_url = f"{A2A_RESEARCHER_URL}/.well-known/agent.json"
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(card_url, timeout=10)
            # Do not hand back an error body (e.g. a JSON 404 payload) as if
            # it were the agent card.
            response.raise_for_status()
            card = response.json()
    except Exception as e:
        # Deliberately broad: this lookup must never break the caller, but
        # the reason is logged instead of being silently discarded.
        logger.warning(f"Agent card fetch failed: {e}")
        return None

    return card if isinstance(card, dict) else None


# Synchronous wrapper for LangGraph node
def call_research_service_sync(
    company: str,
    ticker: str = "",
    progress_callback: Optional[Callable] = None,
    add_log: Optional[Callable] = None
) -> dict:
    """
    Blocking entry point that runs call_research_service to completion.

    Intended for LangGraph nodes that don't support async. The coroutine is
    driven with asyncio.run, and its result is returned unchanged.
    """
    research_coro = call_research_service(company, ticker, progress_callback, add_log)
    return asyncio.run(research_coro)


# Backward compatibility aliases: alternative module-level names bound to the
# same objects as the definitions above, so either name can be imported.
A2AClientError = ResearchGatewayError  # exception class
call_researcher_a2a = call_research_service  # async entry point
call_researcher_sync = call_research_service_sync  # blocking wrapper
check_researcher_health = check_service_health  # async health probe