Instant-SWOT-Agent / src /nodes /research_gateway.py
vn6295337's picture
Fix: Process partial_metrics on completed status
e3dd83d
"""
Research Gateway
A2A client for calling the Research Service (Researcher-Agent) via Google A2A protocol.
Acts as the gateway between the main SWOT Agent and the external Research Service.
Supports real-time partial metrics streaming during task execution.
"""
import asyncio
import logging
import os
from typing import Optional, Callable, Set
import httpx
logger = logging.getLogger("research-gateway")
# Research Service configuration - defaults to HuggingFace Spaces deployment
A2A_RESEARCHER_URL = os.getenv(
"A2A_RESEARCHER_URL",
"https://vn6295337-researcher-agent.hf.space"
)
A2A_TIMEOUT = float(os.getenv("A2A_TIMEOUT", "120")) # seconds (increased for remote calls)
A2A_POLL_INTERVAL = float(os.getenv("A2A_POLL_INTERVAL", "1")) # seconds
class ResearchGatewayError(Exception):
"""Error communicating with Research Service."""
pass
async def send_message(message_text: str) -> dict:
"""
Send message/send request to start a research task.
Args:
message_text: Text message like "Research Tesla"
Returns:
Task info dict with task ID
"""
async with httpx.AsyncClient() as client:
request = {
"jsonrpc": "2.0",
"id": 1,
"method": "message/send",
"params": {
"message": {
"parts": [{"type": "text", "text": message_text}]
}
}
}
try:
response = await client.post(
A2A_RESEARCHER_URL,
json=request,
timeout=30
)
data = response.json()
if "error" in data:
raise ResearchGatewayError(f"A2A error: {data['error']}")
return data.get("result", {})
except httpx.RequestError as e:
raise ResearchGatewayError(f"Connection error to {A2A_RESEARCHER_URL}: {e}")
async def get_task_status(task_id: str) -> dict:
"""
Get task status via tasks/get request.
Args:
task_id: Task ID from message/send response
Returns:
Task status dict including partial_metrics (if WORKING) or artifacts (if COMPLETED)
"""
async with httpx.AsyncClient() as client:
request = {
"jsonrpc": "2.0",
"id": 1,
"method": "tasks/get",
"params": {"taskId": task_id}
}
try:
response = await client.post(
A2A_RESEARCHER_URL,
json=request,
timeout=30
)
data = response.json()
if "error" in data:
raise ResearchGatewayError(f"A2A error: {data['error']}")
return data.get("result", {}).get("task", {})
except httpx.RequestError as e:
raise ResearchGatewayError(f"Connection error: {e}")
async def wait_for_completion(
task_id: str,
timeout: float = None,
progress_callback: Optional[Callable] = None,
add_log: Optional[Callable] = None
) -> dict:
"""
Poll task status until completed or failed.
Emits partial_metrics via progress_callback during WORKING status.
Args:
task_id: Task ID to poll
timeout: Max seconds to wait (default: A2A_TIMEOUT)
progress_callback: Optional callback for granular metrics (source, metric, value)
add_log: Optional callback for activity logging (step, message)
Returns:
Completed task dict with artifacts
"""
if timeout is None:
timeout = A2A_TIMEOUT
elapsed = 0
emitted_metrics: Set[str] = set() # Track which metrics we've already emitted
while elapsed < timeout:
task = await get_task_status(task_id)
status = task.get("status")
# Emit partial metrics during WORKING or COMPLETED status
# Important: Also process on COMPLETED to catch metrics from final sources
if (status in ("working", "completed")) and progress_callback:
partial_metrics = task.get("partial_metrics", []) or []
for metric in partial_metrics:
# Skip None or invalid metrics
if not metric or not isinstance(metric, dict):
continue
# Create unique key to avoid duplicate emissions
source = metric.get("source")
metric_name = metric.get("metric")
value = metric.get("value")
if not source or not metric_name:
continue
metric_key = f"{source}:{metric_name}:{value}"
if metric_key not in emitted_metrics:
# Pass structured payload dict (matches Researcher-Agent emit_metric)
payload = {
"source": source,
"metric": metric_name,
"value": value,
"end_date": metric.get("end_date"),
"fiscal_year": metric.get("fiscal_year"),
"form": metric.get("form"),
}
progress_callback(payload)
emitted_metrics.add(metric_key)
if status == "completed":
if add_log:
sources = len(task.get("artifacts", [{}])[0].get("data", {}).get("sources_available", []))
add_log("researcher", f"Research completed: {sources} sources aggregated")
return task
elif status == "failed":
error = task.get("error", {}).get("message", "Unknown error")
if add_log:
add_log("researcher", f"Research failed: {error}")
raise ResearchGatewayError(f"Task failed: {error}")
elif status == "canceled":
if add_log:
add_log("researcher", "Research task was canceled")
raise ResearchGatewayError("Task was canceled")
# Log polling status periodically
if add_log and elapsed > 0 and elapsed % 5 == 0:
add_log("researcher", f"Polling Research Service... ({int(elapsed)}s elapsed)")
await asyncio.sleep(A2A_POLL_INTERVAL)
elapsed += A2A_POLL_INTERVAL
raise ResearchGatewayError(f"Task timed out after {timeout} seconds")
async def call_research_service(
company: str,
ticker: str = "",
progress_callback: Optional[Callable] = None,
add_log: Optional[Callable] = None
) -> dict:
"""
High-level function to call Research Service and get results.
Supports real-time partial metrics streaming.
Args:
company: Company name to research
ticker: Optional ticker symbol
progress_callback: Optional callback for granular metrics
add_log: Optional callback for activity logging
Returns:
Research data dict from the Research Service
"""
# Format message
if ticker:
message = f"Research {ticker} {company}"
else:
message = f"Research {company}"
logger.info(f"Calling Research Service at {A2A_RESEARCHER_URL}: {message}")
# Log connection
if add_log:
add_log("researcher", f"Connecting to Research Service...")
add_log("researcher", f"A2A URL: {A2A_RESEARCHER_URL}")
# Check health first
healthy = await check_service_health()
if not healthy:
if add_log:
add_log("researcher", "WARNING: Research Service health check failed, attempting anyway...")
logger.warning("Research Service health check failed")
if add_log:
add_log("researcher", f"A2A handshake successful")
# Send message to start task
if add_log:
add_log("researcher", f"Submitting research task for {company} ({ticker})...")
try:
result = await send_message(message)
except ResearchGatewayError as e:
if add_log:
add_log("researcher", f"A2A request failed: {str(e)}")
raise
task_id = result.get("task", {}).get("id")
if not task_id:
raise ResearchGatewayError("No task ID returned from message/send")
logger.info(f"Task created: {task_id}")
if add_log:
add_log("researcher", f"Task submitted: {task_id[:8]}...")
add_log("researcher", "Fetching data from 6 MCP servers in parallel...")
# Wait for completion with partial metrics streaming
task = await wait_for_completion(
task_id,
progress_callback=progress_callback,
add_log=add_log
)
# Extract data from artifacts
artifacts = task.get("artifacts", [])
if not artifacts:
raise ResearchGatewayError("No artifacts in completed task")
# Find data artifact
for artifact in artifacts:
if artifact.get("type") == "data":
data = artifact.get("data", {})
# Log sources
if add_log:
sources = data.get("sources_available", [])
failed = data.get("sources_failed", [])
add_log("researcher", f"Sources available: {', '.join(sources)}")
if failed:
add_log("researcher", f"Sources failed: {', '.join(failed)}")
return data
raise ResearchGatewayError("No data artifact found in response")
async def check_service_health() -> bool:
"""
Check if Research Service is healthy.
Returns:
True if server is healthy, False otherwise
"""
try:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{A2A_RESEARCHER_URL}/health",
timeout=10
)
data = response.json()
return data.get("status") == "healthy"
except Exception as e:
logger.warning(f"Health check failed: {e}")
return False
async def get_agent_card() -> Optional[dict]:
"""
Fetch the agent card from the Research Service.
Returns:
Agent card dict or None if unavailable
"""
try:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{A2A_RESEARCHER_URL}/.well-known/agent.json",
timeout=10
)
return response.json()
except Exception:
return None
# Synchronous wrapper for LangGraph node
def call_research_service_sync(
company: str,
ticker: str = "",
progress_callback: Optional[Callable] = None,
add_log: Optional[Callable] = None
) -> dict:
"""
Synchronous wrapper for call_research_service.
Use this in LangGraph nodes that don't support async.
"""
return asyncio.run(call_research_service(company, ticker, progress_callback, add_log))
# Backward compatibility aliases
A2AClientError = ResearchGatewayError
call_researcher_a2a = call_research_service
call_researcher_sync = call_research_service_sync
check_researcher_health = check_service_health