Spaces:
Sleeping
Sleeping
File size: 6,818 Bytes
"""
Research Gateway Node
Fetches data from the Research Service via A2A protocol.
The Research Service internally calls all 6 MCP servers using TRUE MCP protocol.
This node acts as the gateway between the main SWOT Agent and the external Research Service.
"""
import asyncio
import json
from typing import Optional

from langsmith import traceable

from src.utils.ticker_lookup import get_ticker, normalize_company_name
async def _fetch_via_research_gateway(company: str, ticker: Optional[str] = None, progress_callback=None, add_log=None) -> dict:
    """Fetch company research data via the Research Gateway (A2A protocol).

    Args:
        company: Company name as supplied by the caller.
        ticker: Optional stock ticker. When omitted, it is looked up from the
            company name; if the lookup also fails, a truncated uppercase form
            of the company name is used as a best-effort fallback.
        progress_callback: Optional callable forwarded to the Research Service
            for real-time metric streaming.
        add_log: Optional callable forwarded for activity logging.

    Returns:
        The result dict returned by the Research Service.
    """
    # Imported inside the function, presumably to avoid a circular import at
    # module load time — TODO confirm.
    from src.nodes.research_gateway import call_research_service

    # Use the provided ticker, or fall back to a lookup from the company name.
    if not ticker:
        ticker = get_ticker(company)
        if not ticker:
            print(f"Could not determine ticker for '{company}', using company name as ticker")
            # Best-effort fallback: uppercase, strip spaces, cap at 5 chars.
            ticker = company.upper().replace(" ", "")[:5]

    # Normalize company name for display
    company_name = normalize_company_name(company)
    print(f"Calling Research Service for {company_name} ({ticker})...")

    # Call Research Service with callbacks for real-time streaming
    result = await call_research_service(
        company_name,
        ticker,
        progress_callback=progress_callback,
        add_log=add_log
    )
    return result
@traceable(name="Researcher")
def researcher_node(state, workflow_id=None, progress_store=None):
    """
    Research Gateway node that fetches company data via the A2A protocol.

    Calls the external Research Service, which internally fetches from 6 MCP
    servers: Fundamentals, Volatility, Macro, Valuation, News, Sentiment.

    Args:
        state: Workflow state dict; must contain "company_name" and may contain
            "ticker", "workflow_id", "progress_store", "revision_count", "score".
        workflow_id: Optional workflow id for progress tracking; falls back to
            state["workflow_id"] when None.
        progress_store: Optional mutable mapping of per-workflow progress;
            falls back to state["progress_store"] when None.

    Returns:
        The (mutated) state dict with "data_source", "raw_data", and
        "sources_failed" populated.

    Raises:
        RuntimeError: If the Research Service returns invalid data, fewer than
            2 of the 3 core sources are available, or all sources fail.
    """
    company = state["company_name"]
    ticker = state.get("ticker")  # Use ticker from stock search if available

    # The graph invokes nodes with state only, so tracking params may arrive
    # embedded in state rather than as keyword arguments.
    if workflow_id is None:
        workflow_id = state.get("workflow_id")
    if progress_store is None:
        progress_store = state.get("progress_store")
    print(f"[DEBUG] researcher_node: workflow_id={workflow_id}, progress_store={'yes' if progress_store else 'no'}")

    # Update progress if tracking is enabled
    if workflow_id and progress_store:
        progress_store[workflow_id].update({
            "current_step": "researcher",
            "revision_count": state.get("revision_count", 0),
            "score": state.get("score", 0)
        })

    def add_log(step: str, message: str):
        """Append an activity-log entry when progress tracking is enabled."""
        if workflow_id and progress_store:
            from src.services.workflow_store import add_activity_log
            add_activity_log(workflow_id, step, message)

    def progress_callback(*args, **kwargs):
        """Record one granular metric event in the workflow store.

        Supports both a structured dict payload (new format) and positional
        args (legacy format): source, metric, value, end_date, fiscal_year, form.
        """
        if args and isinstance(args[0], dict):
            # New structured payload format
            p = args[0]
            src = p.get("source")
            metric = p.get("metric")
            value = p.get("value")
            end_date = p.get("end_date")
            fiscal_year = p.get("fiscal_year")
            form = p.get("form")
        else:
            # Legacy positional args format
            src = args[0] if len(args) > 0 else kwargs.get("source")
            metric = args[1] if len(args) > 1 else kwargs.get("metric")
            value = args[2] if len(args) > 2 else kwargs.get("value")
            end_date = args[3] if len(args) > 3 else kwargs.get("end_date")
            fiscal_year = args[4] if len(args) > 4 else kwargs.get("fiscal_year")
            form = args[5] if len(args) > 5 else kwargs.get("form")
        if workflow_id and progress_store and src and metric:
            from src.services.workflow_store import add_metric
            add_metric(workflow_id, src, metric, value,
                       end_date=end_date, fiscal_year=fiscal_year, form=form)

    try:
        # Set all MCP servers to "executing" state before research starts
        if workflow_id and progress_store:
            from src.services.workflow_store import set_mcp_executing
            set_mcp_executing(workflow_id)

        # Fetch via Research Gateway (A2A protocol)
        print("[Research Gateway] Calling Research Service via A2A...")
        result = asyncio.run(_fetch_via_research_gateway(
            company,
            ticker,
            progress_callback=progress_callback,
            add_log=add_log
        ))

        # Validate result
        if not result or not isinstance(result, dict):
            raise RuntimeError(f"Research Service returned invalid data for {company}")

        state["data_source"] = "a2a"
        # Note: Metrics are streamed via partial_metrics during A2A polling

        # Tiered MCP source-availability check:
        #   - Core sources (need at least 2 of 3): fundamentals, valuation, volatility
        #   - Supplementary sources (non-blocking): macro, news, sentiment
        CORE_SOURCES = {"fundamentals", "valuation", "volatility"}
        SUPPLEMENTARY_SOURCES = {"macro", "news", "sentiment"}
        sources_available = set(result.get("sources_available", []))
        sources_failed = result.get("sources_failed", [])
        core_available = sources_available & CORE_SOURCES
        core_failed = CORE_SOURCES - core_available
        supplementary_failed = set(sources_failed) & SUPPLEMENTARY_SOURCES

        # Log supplementary failures as non-critical
        for source in supplementary_failed:
            add_log("researcher", f"{source.capitalize()} unavailable (non-critical)")
        # Log core failures as critical
        for source in core_failed:
            add_log("researcher", f"{source.capitalize()} unavailable (critical)")

        # Abort if 2+ core sources failed (need at least 2 of 3)
        if len(core_available) < 2:
            failed_list = ", ".join(sorted(core_failed))
            raise RuntimeError(f"Insufficient core data: {failed_list} unavailable. Need at least 2 of: Fundamentals, Valuation, Volatility.")

        if sources_available:
            # default=str keeps non-JSON values (dates, Decimals) serializable.
            state["raw_data"] = json.dumps(result, indent=2, default=str)
            state["sources_failed"] = sources_failed
            print(f" - Sources available: {sorted(sources_available)}")
            if sources_failed:
                print(f" - Sources failed: {sources_failed}")
        else:
            # All MCPs failed - raise error
            raise RuntimeError(f"All MCP servers failed for {company}. Check API configurations.")
    except Exception as e:
        error_msg = str(e)
        print(f"Research failed: {error_msg}")
        add_log("researcher", f"ERROR: {error_msg}")
        # Chain the original exception so the full traceback is preserved.
        raise RuntimeError(f"Research failed for {company}: {error_msg}") from e
    return state