Spaces:
Sleeping
Sleeping
| """ | |
| Research Gateway | |
| A2A client for calling the Research Service (Researcher-Agent) via Google A2A protocol. | |
| Acts as the gateway between the main SWOT Agent and the external Research Service. | |
| Supports real-time partial metrics streaming during task execution. | |
| """ | |
| import asyncio | |
| import logging | |
| import os | |
| from typing import Optional, Callable, Set | |
| import httpx | |
| logger = logging.getLogger("research-gateway") | |
| # Research Service configuration - defaults to HuggingFace Spaces deployment | |
| A2A_RESEARCHER_URL = os.getenv( | |
| "A2A_RESEARCHER_URL", | |
| "https://vn6295337-researcher-agent.hf.space" | |
| ) | |
| A2A_TIMEOUT = float(os.getenv("A2A_TIMEOUT", "120")) # seconds (increased for remote calls) | |
| A2A_POLL_INTERVAL = float(os.getenv("A2A_POLL_INTERVAL", "1")) # seconds | |
| class ResearchGatewayError(Exception): | |
| """Error communicating with Research Service.""" | |
| pass | |
| async def send_message(message_text: str) -> dict: | |
| """ | |
| Send message/send request to start a research task. | |
| Args: | |
| message_text: Text message like "Research Tesla" | |
| Returns: | |
| Task info dict with task ID | |
| """ | |
| async with httpx.AsyncClient() as client: | |
| request = { | |
| "jsonrpc": "2.0", | |
| "id": 1, | |
| "method": "message/send", | |
| "params": { | |
| "message": { | |
| "parts": [{"type": "text", "text": message_text}] | |
| } | |
| } | |
| } | |
| try: | |
| response = await client.post( | |
| A2A_RESEARCHER_URL, | |
| json=request, | |
| timeout=30 | |
| ) | |
| data = response.json() | |
| if "error" in data: | |
| raise ResearchGatewayError(f"A2A error: {data['error']}") | |
| return data.get("result", {}) | |
| except httpx.RequestError as e: | |
| raise ResearchGatewayError(f"Connection error to {A2A_RESEARCHER_URL}: {e}") | |
| async def get_task_status(task_id: str) -> dict: | |
| """ | |
| Get task status via tasks/get request. | |
| Args: | |
| task_id: Task ID from message/send response | |
| Returns: | |
| Task status dict including partial_metrics (if WORKING) or artifacts (if COMPLETED) | |
| """ | |
| async with httpx.AsyncClient() as client: | |
| request = { | |
| "jsonrpc": "2.0", | |
| "id": 1, | |
| "method": "tasks/get", | |
| "params": {"taskId": task_id} | |
| } | |
| try: | |
| response = await client.post( | |
| A2A_RESEARCHER_URL, | |
| json=request, | |
| timeout=30 | |
| ) | |
| data = response.json() | |
| if "error" in data: | |
| raise ResearchGatewayError(f"A2A error: {data['error']}") | |
| return data.get("result", {}).get("task", {}) | |
| except httpx.RequestError as e: | |
| raise ResearchGatewayError(f"Connection error: {e}") | |
| async def wait_for_completion( | |
| task_id: str, | |
| timeout: float = None, | |
| progress_callback: Optional[Callable] = None, | |
| add_log: Optional[Callable] = None | |
| ) -> dict: | |
| """ | |
| Poll task status until completed or failed. | |
| Emits partial_metrics via progress_callback during WORKING status. | |
| Args: | |
| task_id: Task ID to poll | |
| timeout: Max seconds to wait (default: A2A_TIMEOUT) | |
| progress_callback: Optional callback for granular metrics (source, metric, value) | |
| add_log: Optional callback for activity logging (step, message) | |
| Returns: | |
| Completed task dict with artifacts | |
| """ | |
| if timeout is None: | |
| timeout = A2A_TIMEOUT | |
| elapsed = 0 | |
| emitted_metrics: Set[str] = set() # Track which metrics we've already emitted | |
| while elapsed < timeout: | |
| task = await get_task_status(task_id) | |
| status = task.get("status") | |
| # Emit partial metrics during WORKING or COMPLETED status | |
| # Important: Also process on COMPLETED to catch metrics from final sources | |
| if (status in ("working", "completed")) and progress_callback: | |
| partial_metrics = task.get("partial_metrics", []) or [] | |
| for metric in partial_metrics: | |
| # Skip None or invalid metrics | |
| if not metric or not isinstance(metric, dict): | |
| continue | |
| # Create unique key to avoid duplicate emissions | |
| source = metric.get("source") | |
| metric_name = metric.get("metric") | |
| value = metric.get("value") | |
| if not source or not metric_name: | |
| continue | |
| metric_key = f"{source}:{metric_name}:{value}" | |
| if metric_key not in emitted_metrics: | |
| # Pass structured payload dict (matches Researcher-Agent emit_metric) | |
| payload = { | |
| "source": source, | |
| "metric": metric_name, | |
| "value": value, | |
| "end_date": metric.get("end_date"), | |
| "fiscal_year": metric.get("fiscal_year"), | |
| "form": metric.get("form"), | |
| } | |
| progress_callback(payload) | |
| emitted_metrics.add(metric_key) | |
| if status == "completed": | |
| if add_log: | |
| sources = len(task.get("artifacts", [{}])[0].get("data", {}).get("sources_available", [])) | |
| add_log("researcher", f"Research completed: {sources} sources aggregated") | |
| return task | |
| elif status == "failed": | |
| error = task.get("error", {}).get("message", "Unknown error") | |
| if add_log: | |
| add_log("researcher", f"Research failed: {error}") | |
| raise ResearchGatewayError(f"Task failed: {error}") | |
| elif status == "canceled": | |
| if add_log: | |
| add_log("researcher", "Research task was canceled") | |
| raise ResearchGatewayError("Task was canceled") | |
| # Log polling status periodically | |
| if add_log and elapsed > 0 and elapsed % 5 == 0: | |
| add_log("researcher", f"Polling Research Service... ({int(elapsed)}s elapsed)") | |
| await asyncio.sleep(A2A_POLL_INTERVAL) | |
| elapsed += A2A_POLL_INTERVAL | |
| raise ResearchGatewayError(f"Task timed out after {timeout} seconds") | |
| async def call_research_service( | |
| company: str, | |
| ticker: str = "", | |
| progress_callback: Optional[Callable] = None, | |
| add_log: Optional[Callable] = None | |
| ) -> dict: | |
| """ | |
| High-level function to call Research Service and get results. | |
| Supports real-time partial metrics streaming. | |
| Args: | |
| company: Company name to research | |
| ticker: Optional ticker symbol | |
| progress_callback: Optional callback for granular metrics | |
| add_log: Optional callback for activity logging | |
| Returns: | |
| Research data dict from the Research Service | |
| """ | |
| # Format message | |
| if ticker: | |
| message = f"Research {ticker} {company}" | |
| else: | |
| message = f"Research {company}" | |
| logger.info(f"Calling Research Service at {A2A_RESEARCHER_URL}: {message}") | |
| # Log connection | |
| if add_log: | |
| add_log("researcher", f"Connecting to Research Service...") | |
| add_log("researcher", f"A2A URL: {A2A_RESEARCHER_URL}") | |
| # Check health first | |
| healthy = await check_service_health() | |
| if not healthy: | |
| if add_log: | |
| add_log("researcher", "WARNING: Research Service health check failed, attempting anyway...") | |
| logger.warning("Research Service health check failed") | |
| if add_log: | |
| add_log("researcher", f"A2A handshake successful") | |
| # Send message to start task | |
| if add_log: | |
| add_log("researcher", f"Submitting research task for {company} ({ticker})...") | |
| try: | |
| result = await send_message(message) | |
| except ResearchGatewayError as e: | |
| if add_log: | |
| add_log("researcher", f"A2A request failed: {str(e)}") | |
| raise | |
| task_id = result.get("task", {}).get("id") | |
| if not task_id: | |
| raise ResearchGatewayError("No task ID returned from message/send") | |
| logger.info(f"Task created: {task_id}") | |
| if add_log: | |
| add_log("researcher", f"Task submitted: {task_id[:8]}...") | |
| add_log("researcher", "Fetching data from 6 MCP servers in parallel...") | |
| # Wait for completion with partial metrics streaming | |
| task = await wait_for_completion( | |
| task_id, | |
| progress_callback=progress_callback, | |
| add_log=add_log | |
| ) | |
| # Extract data from artifacts | |
| artifacts = task.get("artifacts", []) | |
| if not artifacts: | |
| raise ResearchGatewayError("No artifacts in completed task") | |
| # Find data artifact | |
| for artifact in artifacts: | |
| if artifact.get("type") == "data": | |
| data = artifact.get("data", {}) | |
| # Log sources | |
| if add_log: | |
| sources = data.get("sources_available", []) | |
| failed = data.get("sources_failed", []) | |
| add_log("researcher", f"Sources available: {', '.join(sources)}") | |
| if failed: | |
| add_log("researcher", f"Sources failed: {', '.join(failed)}") | |
| return data | |
| raise ResearchGatewayError("No data artifact found in response") | |
| async def check_service_health() -> bool: | |
| """ | |
| Check if Research Service is healthy. | |
| Returns: | |
| True if server is healthy, False otherwise | |
| """ | |
| try: | |
| async with httpx.AsyncClient() as client: | |
| response = await client.get( | |
| f"{A2A_RESEARCHER_URL}/health", | |
| timeout=10 | |
| ) | |
| data = response.json() | |
| return data.get("status") == "healthy" | |
| except Exception as e: | |
| logger.warning(f"Health check failed: {e}") | |
| return False | |
| async def get_agent_card() -> Optional[dict]: | |
| """ | |
| Fetch the agent card from the Research Service. | |
| Returns: | |
| Agent card dict or None if unavailable | |
| """ | |
| try: | |
| async with httpx.AsyncClient() as client: | |
| response = await client.get( | |
| f"{A2A_RESEARCHER_URL}/.well-known/agent.json", | |
| timeout=10 | |
| ) | |
| return response.json() | |
| except Exception: | |
| return None | |
| # Synchronous wrapper for LangGraph node | |
| def call_research_service_sync( | |
| company: str, | |
| ticker: str = "", | |
| progress_callback: Optional[Callable] = None, | |
| add_log: Optional[Callable] = None | |
| ) -> dict: | |
| """ | |
| Synchronous wrapper for call_research_service. | |
| Use this in LangGraph nodes that don't support async. | |
| """ | |
| return asyncio.run(call_research_service(company, ticker, progress_callback, add_log)) | |
| # Backward compatibility aliases | |
| A2AClientError = ResearchGatewayError | |
| call_researcher_a2a = call_research_service | |
| call_researcher_sync = call_research_service_sync | |
| check_researcher_health = check_service_health | |