Spaces:
Runtime error
Runtime error
| """Main LangGraph workflow for market intelligence.""" | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from src.workflows.types import IntelligenceState, ResearchType | |
| from src.agents.researcher import ResearchAgent | |
| from src.agents.analyst import AnalysisAgent | |
| from src.agents.writer import WriterAgent | |
| from src.utils.cost_tracker import CostTracker, BudgetExceededError | |
| from src.utils.logging import setup_logger | |
| logger = setup_logger(__name__) | |
class MarketIntelligenceWorkflow:
    """
    LangGraph workflow orchestrating research, analysis, and writing agents.

    Graph layout::

        research -> analysis -> writing -> human_review -> END
           ^                                    |
           +------------- "revise" -------------+

    Every agent node catches its own exceptions and reports them through the
    ``errors`` state key instead of raising. The router that follows each
    agent node ends the run as soon as ``errors`` is non-empty, so a failed
    or over-budget step never triggers further agent calls.

    Features:
    - Multi-agent coordination
    - State persistence with checkpointing
    - Cost tracking and budget enforcement
    - Human-in-the-loop approval
    - Error recovery
    """

    def __init__(
        self,
        checkpoint_path: str = "./checkpoints.db",
        max_budget: float = 2.0,
        model_name: str | None = None,
        max_revisions: int = 2,
    ):
        """
        Initialize workflow.

        Args:
            checkpoint_path: Path to SQLite checkpoint database. The special
                value ":memory:" selects an in-process MemorySaver instead.
            max_budget: Maximum cost per run in USD
            model_name: Name of the LLM model to use
            max_revisions: Value of ``revision_count`` at which the review
                router stops the workflow instead of looping back to research.
        """
        self.max_budget = max_budget
        self.max_revisions = max_revisions
        self.checkpoint_path = checkpoint_path
        self.model_name = model_name

        # NOTE(review): the tracker is created once here and never reset in
        # this class, so the budget check presumably sees spend accumulated
        # over every run() on the same instance - confirm against CostTracker.
        self.cost_tracker = CostTracker()

        # Initialize agents (shared cost tracker)
        self.research_agent = ResearchAgent(
            cost_tracker=self.cost_tracker, model=model_name
        )
        self.analysis_agent = AnalysisAgent(
            cost_tracker=self.cost_tracker, model=model_name
        )
        self.writer_agent = WriterAgent(
            cost_tracker=self.cost_tracker, model=model_name
        )

        # Build workflow graph blueprint; it is compiled with a checkpointer
        # on every run() call.
        self.graph_builder = self._build_graph()
        logger.info("Market Intelligence Workflow initialized")

    def _build_graph(self) -> StateGraph:
        """
        Build the (uncompiled) LangGraph workflow.

        Returns:
            A StateGraph over IntelligenceState with the four nodes wired up.
        """
        graph = StateGraph(IntelligenceState)

        # Add nodes (agent wrappers)
        graph.add_node("research", self._research_node)
        graph.add_node("analysis", self._analysis_node)
        graph.add_node("writing", self._writing_node)
        graph.add_node("human_review", self._human_review_node)

        graph.set_entry_point("research")

        # Every agent node is followed by a router so that a recorded error
        # (including an exceeded budget) stops the run instead of feeding
        # missing data to the next agent and spending more.
        graph.add_conditional_edges(
            "research",
            self._should_continue_to_analysis,
            {"analysis": "analysis", "end": END},
        )
        graph.add_conditional_edges(
            "analysis",
            self._should_continue_to_writing,
            {"writing": "writing", "end": END},
        )
        graph.add_conditional_edges(
            "writing",
            self._should_continue_to_review,
            {"human_review": "human_review", "end": END},
        )
        graph.add_conditional_edges(
            "human_review",
            self._check_approval,
            {"approved": END, "revise": "research", "max_revisions": END},
        )
        return graph

    @staticmethod
    def _has_errors(state: IntelligenceState) -> bool:
        """Return True when the state carries a non-empty ``errors`` value."""
        return bool(state.get("errors"))

    @staticmethod
    def _error_update(agent: str, message: str) -> dict:
        """Build the state update a node returns when it fails."""
        return {"errors": [message], "current_agent": agent}

    async def _research_node(self, state: IntelligenceState) -> dict:
        """
        Research agent node.

        Runs the research agent for the company in ``state`` and returns a
        state update with the research payload and a bumped ``iteration``.
        Any exception is logged and converted into an ``errors`` update.
        """
        logger.info(f"Research node: {state['company_name']}")
        try:
            research_results = await self.research_agent.run(
                company_name=state["company_name"],
                industry=state.get("industry"),
                research_depth=state.get("research_depth", "comprehensive"),
            )
            # Update state
            return {
                "current_agent": "research",
                "research_data": research_results,
                "competitors": research_results.get("competitors", ""),
                "market_trends": research_results.get("market_trends", ""),
                "raw_sources": research_results.get("raw_sources", []),
                "iteration": state.get("iteration", 0) + 1,
            }
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"Research node failed: {e}")
            return self._error_update("research", f"Research failed: {str(e)}")

    async def _analysis_node(self, state: IntelligenceState) -> dict:
        """
        Analysis agent node.

        Checks the budget, runs the analysis agent on ``research_data`` and
        returns the analysis fields as a state update. A budget violation or
        any other exception is logged and converted into an ``errors`` update.
        """
        logger.info(f"Analysis node: {state['company_name']}")
        try:
            # Check budget before expensive analysis; expected to raise
            # BudgetExceededError when the limit has been passed.
            self.cost_tracker.check_budget(self.max_budget)

            analysis_results = await self.analysis_agent.run(
                research_data=state["research_data"]
            )
            # Update state
            return {
                "current_agent": "analysis",
                "swot": analysis_results.get("swot", ""),
                "competitive_matrix": analysis_results.get("competitive_matrix", ""),
                "positioning": analysis_results.get("positioning", ""),
                "strategic_recommendations": analysis_results.get(
                    "strategic_recommendations", ""
                ),
            }
        except BudgetExceededError as e:
            # Expected control-flow condition: no traceback needed.
            logger.error(f"Budget exceeded: {e}")
            return self._error_update("analysis", f"Budget exceeded: {str(e)}")
        except Exception as e:
            logger.exception(f"Analysis node failed: {e}")
            return self._error_update("analysis", f"Analysis failed: {str(e)}")

    async def _writing_node(self, state: IntelligenceState) -> dict:
        """
        Writer agent node.

        Passes the research data and the analysis fields to the writer agent
        and returns the report fields plus the cost tracker's totals. Any
        exception is logged and converted into an ``errors`` update.
        """
        logger.info(f"Writing node: {state['company_name']}")
        try:
            report_results = await self.writer_agent.run(
                research_data=state["research_data"],
                analysis_data={
                    "company_name": state["company_name"],
                    "swot": state.get("swot", ""),
                    "competitive_matrix": state.get("competitive_matrix", ""),
                    "positioning": state.get("positioning", ""),
                    "strategic_recommendations": state.get(
                        "strategic_recommendations", ""
                    ),
                },
            )
            # Get cost summary; this is the only node that writes the cost
            # totals into the state.
            cost_summary = self.cost_tracker.get_summary()
            # Update state
            return {
                "current_agent": "writing",
                "executive_summary": report_results.get("executive_summary", ""),
                "full_report": report_results.get("full_report", ""),
                "report_metadata": report_results.get("metadata", {}),
                "total_cost": cost_summary["total_cost"],
                "total_tokens": cost_summary["total_tokens"],
            }
        except Exception as e:
            logger.exception(f"Writing node failed: {e}")
            return self._error_update("writing", f"Writing failed: {str(e)}")

    async def _human_review_node(self, state: IntelligenceState) -> dict:
        """Human review node (placeholder for now): always approves."""
        logger.info(f"Human review node: {state['company_name']}")
        # For now, auto-approve
        # In Phase 5, this will connect to the Gradio UI
        return {
            "current_agent": "human_review",
            "approved": True,  # Auto-approve for testing
            "human_feedback": None,
        }

    def _should_continue_to_analysis(self, state: IntelligenceState) -> str:
        """
        Decide whether to continue to analysis or end.

        Returns:
            "end" when research recorded errors or produced no data,
            otherwise "analysis".
        """
        if self._has_errors(state):
            logger.warning("Research had errors, ending workflow")
            return "end"
        if not state.get("research_data"):
            logger.warning("No research data, ending workflow")
            return "end"
        return "analysis"

    def _should_continue_to_writing(self, state: IntelligenceState) -> str:
        """
        Decide whether to continue to writing or end.

        Returns:
            "end" when the analysis step recorded errors (for example an
            exceeded budget), otherwise "writing".
        """
        if self._has_errors(state):
            logger.warning("Analysis had errors, ending workflow")
            return "end"
        return "writing"

    def _should_continue_to_review(self, state: IntelligenceState) -> str:
        """
        Decide whether to continue to human review or end.

        Returns:
            "end" when the writing step recorded errors, otherwise
            "human_review".
        """
        if self._has_errors(state):
            logger.warning("Writing had errors, ending workflow")
            return "end"
        return "human_review"

    def _check_approval(self, state: IntelligenceState) -> str:
        """
        Check if report is approved or needs revision.

        Returns:
            "max_revisions" once ``revision_count`` reaches the configured
            cap, "approved" when the report is approved or no feedback was
            given, and "revise" when unapproved feedback is present.
        """
        # NOTE(review): nothing in this class increments ``revision_count``;
        # whoever implements the real review step must do so or the cap
        # below can never trigger.
        revision_count = state.get("revision_count", 0)
        if revision_count >= self.max_revisions:
            logger.warning("Max revisions reached")
            return "max_revisions"

        if state.get("approved"):
            return "approved"

        # Revision requested
        if state.get("human_feedback"):
            return "revise"

        # Default to approved
        return "approved"

    async def run(
        self,
        company_name: str,
        industry: str | None = None,
        thread_id: str | None = None,
        research_depth: str = "comprehensive",
        research_type: ResearchType = ResearchType.COMPANY_ANALYSIS,
    ) -> dict:
        """
        Run the complete workflow.

        Args:
            company_name: Target company name
            industry: Optional industry context
            thread_id: Optional thread ID for checkpointing. When omitted,
                the shared thread "default" is used.
            research_depth: Depth setting forwarded to the research agent.
            research_type: Stored in the initial state under
                ``research_type``; no node in this class reads it.

        Returns:
            Final state dictionary

        Raises:
            Exception: Whatever compiling or invoking the graph raises is
                logged and re-raised unchanged.
        """
        logger.info(f"Starting workflow for: {company_name}")

        # Initial state
        initial_state: IntelligenceState = {
            "research_type": research_type,
            "company_name": company_name,
            "industry": industry,
            "research_depth": research_depth,
            "research_data": {
                "company_name": company_name,
                "industry": industry,
                "company_overview": "",
                "competitors": "",
                "market_trends": "",
                "raw_sources": [],
            },
            "competitors": "",
            "market_trends": "",
            "raw_sources": [],
            "swot": "",
            "competitive_matrix": "",
            "positioning": "",
            "strategic_recommendations": "",
            "executive_summary": "",
            "full_report": "",
            "report_metadata": {},
            "current_agent": "research",
            "iteration": 0,
            "total_cost": 0.0,
            "total_tokens": 0,
            "errors": [],
            "human_feedback": None,
            "approved": False,
            "revision_count": 0,
        }

        # Run workflow with async checkpointer
        config = {"configurable": {"thread_id": thread_id or "default"}}
        try:
            if self.checkpoint_path == ":memory:":
                memory_checkpointer = MemorySaver()
                workflow = self.graph_builder.compile(checkpointer=memory_checkpointer)
                final_state = await workflow.ainvoke(initial_state, config)  # type: ignore[arg-type]
            else:
                # The SQLite saver is an async context manager; the graph must
                # be compiled and invoked while the connection is open.
                async with AsyncSqliteSaver.from_conn_string(
                    self.checkpoint_path
                ) as checkpointer:
                    workflow = self.graph_builder.compile(checkpointer=checkpointer)
                    final_state = await workflow.ainvoke(initial_state, config)  # type: ignore[arg-type]

            # ``total_cost`` is only written by the writing node, so a run
            # that ended earlier reports the initial 0.0 here.
            logger.info(f"Workflow complete. Cost: ${final_state['total_cost']:.4f}")
            return final_state
        except Exception as e:
            logger.exception(f"Workflow failed: {e}")
            raise