"""
Research Gateway
A2A client for calling the Research Service (Researcher-Agent) via Google A2A protocol.
Acts as the gateway between the main SWOT Agent and the external Research Service.
Supports real-time partial metrics streaming during task execution.
"""
import asyncio
import logging
import os
from typing import Optional, Callable, Set
import httpx
# Module-wide logger for the gateway.
logger = logging.getLogger("research-gateway")

# Research Service endpoint; falls back to the HuggingFace Spaces deployment
# when A2A_RESEARCHER_URL is not set in the environment.
A2A_RESEARCHER_URL = os.environ.get(
    "A2A_RESEARCHER_URL", "https://vn6295337-researcher-agent.hf.space"
)

# Default number of seconds to wait for a task to finish (raised for remote calls).
A2A_TIMEOUT = float(os.environ.get("A2A_TIMEOUT", "120"))

# Pause, in seconds, between consecutive task-status polls.
A2A_POLL_INTERVAL = float(os.environ.get("A2A_POLL_INTERVAL", "1"))
class ResearchGatewayError(Exception):
    """Error communicating with the Research Service.

    Raised by this module for connection failures, JSON-RPC error responses,
    failed or canceled tasks, polling timeouts, and responses that lack the
    expected task ID or data artifact.
    """
async def send_message(message_text: str) -> dict:
    """
    Send a JSON-RPC ``message/send`` request to start a research task.

    Args:
        message_text: Text message like "Research Tesla"

    Returns:
        The ``result`` member of the JSON-RPC response (task info dict with
        task ID), or an empty dict when the response has no result.

    Raises:
        ResearchGatewayError: On a connection failure, a response body that
            is not a JSON object, or a JSON-RPC ``error`` in the response.
    """
    request = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "message/send",
        "params": {
            "message": {
                "parts": [{"type": "text", "text": message_text}]
            }
        }
    }
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                A2A_RESEARCHER_URL,
                json=request,
                timeout=30
            )
        except httpx.RequestError as e:
            raise ResearchGatewayError(
                f"Connection error to {A2A_RESEARCHER_URL}: {e}"
            ) from e

        # A non-JSON body (e.g. an HTML error page) must surface as a gateway
        # error rather than a bare ValueError that callers do not catch.
        try:
            data = response.json()
        except ValueError as e:
            raise ResearchGatewayError(
                f"Invalid JSON response from {A2A_RESEARCHER_URL} "
                f"(HTTP {response.status_code})"
            ) from e

    if not isinstance(data, dict):
        raise ResearchGatewayError(
            f"Unexpected response type from {A2A_RESEARCHER_URL}: "
            f"{type(data).__name__}"
        )
    if "error" in data:
        raise ResearchGatewayError(f"A2A error: {data['error']}")
    # `or {}` also covers an explicit null result.
    return data.get("result") or {}
async def get_task_status(task_id: str) -> dict:
    """
    Get task status via a JSON-RPC ``tasks/get`` request.

    Args:
        task_id: Task ID from message/send response

    Returns:
        The ``result.task`` dict of the response, including partial_metrics
        (if WORKING) or artifacts (if COMPLETED). An empty dict is returned
        when the response carries no task object.

    Raises:
        ResearchGatewayError: On a connection failure, a response body that
            is not a JSON object, or a JSON-RPC ``error`` in the response.
    """
    request = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tasks/get",
        "params": {"taskId": task_id}
    }
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                A2A_RESEARCHER_URL,
                json=request,
                timeout=30
            )
        except httpx.RequestError as e:
            raise ResearchGatewayError(f"Connection error: {e}") from e

        # Wrap JSON decoding failures so callers only ever have to handle
        # ResearchGatewayError.
        try:
            data = response.json()
        except ValueError as e:
            raise ResearchGatewayError(
                f"Invalid JSON response (HTTP {response.status_code})"
            ) from e

    if not isinstance(data, dict):
        raise ResearchGatewayError(
            f"Unexpected response type: {type(data).__name__}"
        )
    if "error" in data:
        raise ResearchGatewayError(f"A2A error: {data['error']}")

    # Tolerate null / non-dict "result" and "task" members instead of
    # crashing with AttributeError or handing None to the poller.
    result = data.get("result")
    task = result.get("task") if isinstance(result, dict) else None
    return task if isinstance(task, dict) else {}
async def wait_for_completion(
    task_id: str,
    timeout: Optional[float] = None,
    progress_callback: Optional[Callable] = None,
    add_log: Optional[Callable] = None
) -> dict:
    """
    Poll task status until the task is completed, failed, or canceled.

    Emits partial_metrics via progress_callback while the task reports
    "working", and once more on "completed" so metrics from the final
    sources are not missed. Each distinct source/metric/value combination
    is emitted at most once.

    Args:
        task_id: Task ID to poll
        timeout: Max seconds to wait, measured in wall-clock time including
            the status requests themselves (default: A2A_TIMEOUT)
        progress_callback: Optional callback receiving one payload dict per
            new metric (keys: source, metric, value, end_date, fiscal_year, form)
        add_log: Optional callback for activity logging (step, message)

    Returns:
        Completed task dict with artifacts

    Raises:
        ResearchGatewayError: If the task fails, is canceled, or does not
            complete within ``timeout`` seconds.
    """
    if timeout is None:
        timeout = A2A_TIMEOUT

    log_every = 5.0  # seconds between "still polling" activity-log entries
    now = asyncio.get_running_loop().time  # monotonic event-loop clock
    started = now()
    elapsed = 0.0
    next_log_at = log_every
    emitted_metrics: Set[str] = set()  # keys of metrics already emitted

    while elapsed < timeout:
        task = await get_task_status(task_id)
        status = task.get("status")

        if status in ("working", "completed") and progress_callback:
            for metric in task.get("partial_metrics") or []:
                # Skip None or otherwise malformed entries.
                if not isinstance(metric, dict):
                    continue
                source = metric.get("source")
                metric_name = metric.get("metric")
                value = metric.get("value")
                if not source or not metric_name:
                    continue
                # Unique key so the same reading is never emitted twice.
                metric_key = f"{source}:{metric_name}:{value}"
                if metric_key in emitted_metrics:
                    continue
                # Structured payload dict (matches Researcher-Agent emit_metric)
                progress_callback({
                    "source": source,
                    "metric": metric_name,
                    "value": value,
                    "end_date": metric.get("end_date"),
                    "fiscal_year": metric.get("fiscal_year"),
                    "form": metric.get("form"),
                })
                emitted_metrics.add(metric_key)

        if status == "completed":
            if add_log:
                sources = _count_available_sources(task)
                add_log("researcher", f"Research completed: {sources} sources aggregated")
            return task

        if status == "failed":
            # "error" may be an object with a message, a plain string, or null.
            error_info = task.get("error")
            if isinstance(error_info, dict):
                error = error_info.get("message", "Unknown error")
            else:
                error = error_info or "Unknown error"
            if add_log:
                add_log("researcher", f"Research failed: {error}")
            raise ResearchGatewayError(f"Task failed: {error}")

        if status == "canceled":
            if add_log:
                add_log("researcher", "Research task was canceled")
            raise ResearchGatewayError("Task was canceled")

        # Log polling status periodically. A threshold comparison is used
        # because an exact modulo test never fires for most fractional
        # intervals.
        elapsed = now() - started
        if add_log and elapsed >= next_log_at:
            add_log("researcher", f"Polling Research Service... ({int(elapsed)}s elapsed)")
            next_log_at = (elapsed // log_every + 1) * log_every

        await asyncio.sleep(A2A_POLL_INTERVAL)
        # Measure real elapsed time so slow status requests count against
        # the timeout, not just the sleeps.
        elapsed = now() - started

    raise ResearchGatewayError(f"Task timed out after {timeout} seconds")


def _count_available_sources(task: dict) -> int:
    """Return how many ``sources_available`` the task's first artifact lists (0 if absent)."""
    artifacts = task.get("artifacts")
    first = artifacts[0] if isinstance(artifacts, list) and artifacts else {}
    data = first.get("data") if isinstance(first, dict) else None
    sources = data.get("sources_available") if isinstance(data, dict) else None
    return len(sources) if isinstance(sources, (list, tuple)) else 0
async def call_research_service(
    company: str,
    ticker: str = "",
    progress_callback: Optional[Callable] = None,
    add_log: Optional[Callable] = None
) -> dict:
    """
    High-level function to call the Research Service and get results.

    Runs a health check (advisory only: the request is attempted even if it
    fails), submits the research task, polls it to completion while
    streaming partial metrics, and returns the payload of the first
    artifact whose type is "data".

    Args:
        company: Company name to research
        ticker: Optional ticker symbol
        progress_callback: Optional callback for granular metrics
        add_log: Optional callback for activity logging (step, message)

    Returns:
        Research data dict from the Research Service

    Raises:
        ResearchGatewayError: If the request fails, no task ID is returned,
            the task fails / is canceled / times out, or the completed task
            has no data artifact.
    """
    message = f"Research {ticker} {company}" if ticker else f"Research {company}"
    logger.info("Calling Research Service at %s: %s", A2A_RESEARCHER_URL, message)

    if add_log:
        add_log("researcher", "Connecting to Research Service...")
        add_log("researcher", f"A2A URL: {A2A_RESEARCHER_URL}")

    # Health check is advisory. Only report a successful handshake when it
    # actually succeeded.
    if await check_service_health():
        if add_log:
            add_log("researcher", "A2A handshake successful")
    else:
        if add_log:
            add_log("researcher", "WARNING: Research Service health check failed, attempting anyway...")
        logger.warning("Research Service health check failed")

    if add_log:
        # Avoid printing empty parentheses when no ticker was supplied.
        target = f"{company} ({ticker})" if ticker else company
        add_log("researcher", f"Submitting research task for {target}...")

    try:
        result = await send_message(message)
    except ResearchGatewayError as e:
        if add_log:
            add_log("researcher", f"A2A request failed: {str(e)}")
        raise

    # Tolerate a null / non-dict "task" member instead of raising AttributeError.
    task_info = result.get("task") if isinstance(result, dict) else None
    task_id = task_info.get("id") if isinstance(task_info, dict) else None
    if not task_id:
        raise ResearchGatewayError("No task ID returned from message/send")

    logger.info("Task created: %s", task_id)
    if add_log:
        # str() keeps the preview working for non-string IDs.
        add_log("researcher", f"Task submitted: {str(task_id)[:8]}...")
        add_log("researcher", "Fetching data from 6 MCP servers in parallel...")

    # Wait for completion with partial metrics streaming
    task = await wait_for_completion(
        task_id,
        progress_callback=progress_callback,
        add_log=add_log
    )

    artifacts = task.get("artifacts") or []
    if not artifacts:
        raise ResearchGatewayError("No artifacts in completed task")

    for artifact in artifacts:
        if not isinstance(artifact, dict) or artifact.get("type") != "data":
            continue
        data = artifact.get("data") or {}
        if add_log:
            # `or []` and str() guard against null lists and non-string entries.
            sources = data.get("sources_available") or []
            failed = data.get("sources_failed") or []
            add_log("researcher", f"Sources available: {', '.join(map(str, sources))}")
            if failed:
                add_log("researcher", f"Sources failed: {', '.join(map(str, failed))}")
        return data

    raise ResearchGatewayError("No data artifact found in response")
async def check_service_health() -> bool:
    """
    Probe the Research Service's /health endpoint.

    Any failure (connection problem, non-JSON body, unexpected payload) is
    logged as a warning and reported as unhealthy rather than raised.

    Returns:
        True only if the endpoint answers with JSON whose "status" is
        "healthy"; False otherwise.
    """
    try:
        async with httpx.AsyncClient() as http:
            reply = await http.get(f"{A2A_RESEARCHER_URL}/health", timeout=10)
            payload = reply.json()
            is_healthy = payload.get("status") == "healthy"
            return is_healthy
    except Exception as exc:
        logger.warning(f"Health check failed: {exc}")
    return False
async def get_agent_card() -> Optional[dict]:
    """
    Fetch the agent card from the Research Service.

    Requests ``/.well-known/agent.json`` below A2A_RESEARCHER_URL. This is a
    best-effort lookup: failures are logged as warnings and never raised.

    Returns:
        Agent card dict, or None if the card could not be fetched, could not
        be decoded, or is not a JSON object.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{A2A_RESEARCHER_URL}/.well-known/agent.json",
                timeout=10
            )
            card = response.json()
    except Exception as e:
        # Stay best-effort, but leave a trace instead of failing silently.
        logger.warning(f"Agent card fetch failed: {e}")
        return None
    # Honour the Optional[dict] contract even for unexpected JSON shapes.
    return card if isinstance(card, dict) else None
# Synchronous wrapper for LangGraph node
def call_research_service_sync(
    company: str,
    ticker: str = "",
    progress_callback: Optional[Callable] = None,
    add_log: Optional[Callable] = None
) -> dict:
    """
    Synchronous wrapper for call_research_service.
    Use this in LangGraph nodes that don't support async.

    When the calling thread has no running event loop, the coroutine is run
    with ``asyncio.run`` directly. When a loop is already running in this
    thread (where ``asyncio.run`` would raise RuntimeError), the coroutine is
    run on a fresh loop in a short-lived worker thread. In that case the
    calling thread blocks until the result is ready, and progress_callback /
    add_log are invoked from the worker thread.

    Args:
        company: Company name to research
        ticker: Optional ticker symbol
        progress_callback: Optional callback for granular metrics
        add_log: Optional callback for activity logging

    Returns:
        Research data dict from the Research Service
    """
    import concurrent.futures

    def _run() -> dict:
        # Create the coroutine in the thread that executes it, so it is never
        # left un-awaited if asyncio.run cannot start.
        return asyncio.run(
            call_research_service(company, ticker, progress_callback, add_log)
        )

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread: the simple path is safe.
        return _run()

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(_run).result()
# Legacy names retained for backward compatibility; each one is simply
# another binding for the current object.
check_researcher_health = check_service_health
call_researcher_sync = call_research_service_sync
call_researcher_a2a = call_research_service
A2AClientError = ResearchGatewayError